diff --git a/src/simd/aarch64.zig b/src/simd/aarch64.zig
index 1aa39e1dd..0f28ae339 100644
--- a/src/simd/aarch64.zig
+++ b/src/simd/aarch64.zig
@@ -4,6 +4,36 @@
 const std = @import("std");
 const assert = std.debug.assert;
 
+pub inline fn vaddlvq_u8(v: @Vector(16, u8)) u16 { // UADDLV: widening sum of all 16 u8 lanes, returned as u16.
+    const result = asm (
+        \\ uaddlv %[ret:h], %[v].16b
+        : [ret] "=w" (-> @Vector(8, u16)),
+        : [v] "w" (v),
+    );
+
+    return result[0]; // `%[ret:h]` is the low 16-bit view of the output register, i.e. lane 0.
+}
+
+pub inline fn vaddvq_u8(v: @Vector(16, u8)) u8 { // ADDV: sum of all 16 u8 lanes, truncated to 8 bits.
+    const result = asm (
+        \\ addv %[ret:b], %[v].16b
+        : [ret] "=w" (-> @Vector(16, u8)),
+        : [v] "w" (v),
+    );
+
+    return result[0]; // `%[ret:b]` is the low 8-bit view of the output register, i.e. lane 0.
+}
+
+pub inline fn vaddv_u8(v: @Vector(8, u8)) u8 { // ADDV: sum of all 8 u8 lanes, truncated to 8 bits.
+    const result = asm (
+        \\ addv %[ret:b], %[v].8b
+        : [ret] "=w" (-> @Vector(8, u8)),
+        : [v] "w" (v),
+    );
+
+    return result[0]; // `%[ret:b]` is the low 8-bit view of the output register, i.e. lane 0.
+}
+
 pub inline fn vandq_u8(a: @Vector(16, u8), b: @Vector(16, u8)) @Vector(16, u8) {
     return asm (
         \\ and %[ret].16b, %[a].16b, %[b].16b
@@ -31,6 +61,15 @@ pub inline fn vcgeq_u8(a: @Vector(16, u8), b: @Vector(16, u8)) @Vector(16, u8) {
     );
 }
 
+pub inline fn vcgtq_s8(a: @Vector(16, i8), b: @Vector(16, i8)) @Vector(16, u8) { // CMGT: signed per-lane a > b; 0xFF where true, 0x00 where false.
+    return asm (
+        \\ cmgt %[ret].16b, %[a].16b, %[b].16b
+        : [ret] "=w" (-> @Vector(16, u8)),
+        : [a] "w" (a),
+          [b] "w" (b),
+    );
+}
+
 pub inline fn vcgtq_u8(a: @Vector(16, u8), b: @Vector(16, u8)) @Vector(16, u8) {
     return asm (
         \\ cmhi %[ret].16b, %[a].16b, %[b].16b
@@ -40,6 +79,30 @@ pub inline fn vcgtq_u8(a: @Vector(16, u8), b: @Vector(16, u8)) @Vector(16, u8) {
     );
 }
 
+pub inline fn vcnt_u8(v: @Vector(8, u8)) @Vector(8, u8) { // CNT: per-lane population count (set bits in each byte).
+    return asm (
+        \\ cnt %[ret].8b, %[v].8b
+        : [ret] "=w" (-> @Vector(8, u8)),
+        : [v] "w" (v),
+    );
+}
+
+pub inline fn vcreate_u8(v: u64) @Vector(8, u8) { // INS: moves the 64 bits of `v` into 64-bit lane 0, viewed as 8 byte lanes.
+    return asm (
+        \\ ins %[ret].D[0], %[value]
+        : [ret] "=w" (-> @Vector(8, u8)),
+        : [value] "r" (v),
+    );
+}
+
+pub inline fn vdupq_n_s8(v: i8) @Vector(16, i8) { // DUP: broadcasts `v` into all 16 lanes.
+    return asm (
+        \\ dup %[ret].16b, %[value:w]
+        : [ret] "=w" (-> @Vector(16, i8)),
+        : [value] "r" (v),
+    );
+}
+
 pub inline fn 
vdupq_n_u8(v: u8) @Vector(16, u8) {
     return asm (
         \\ dup %[ret].16b, %[value:w]
@@ -76,6 +139,15 @@ pub inline fn vget_lane_u64(v: @Vector(1, u64)) u64 {
     );
 }
 
+pub inline fn vgetq_lane_u64(v: @Vector(2, u64), n: u1) u64 { // UMOV: copies 64-bit lane `n` of `v` into a general-purpose register.
+    return asm (
+        \\ umov %[ret], %[v].d[%[n]]
+        : [ret] "=r" (-> u64),
+        : [v] "w" (v),
+          [n] "I" (n), // NOTE(review): "I" asks for an immediate, so `n` presumably must be comptime-known at the call site — confirm.
+    );
+}
+
 pub inline fn vld1q_u8(v: []const u8) @Vector(16, u8) {
     return asm (
         \\ ld1 { %[ret].16b }, [%[value]]
diff --git a/src/simd/main.zig b/src/simd/main.zig
index 3493f133d..c29f5e201 100644
--- a/src/simd/main.zig
+++ b/src/simd/main.zig
@@ -1,7 +1,8 @@
 const std = @import("std");
 
 pub const isa = @import("isa.zig");
-pub const utf8 = @import("utf8.zig");
+pub const utf8_count = @import("utf8_count.zig");
+pub const utf8_validate = @import("utf8_validate.zig");
 pub const index_of = @import("index_of.zig");
 
 // TODO: temporary, only for zig build simd to inspect disasm easily
diff --git a/src/simd/utf8_count.zig b/src/simd/utf8_count.zig
new file mode 100644
index 000000000..98e4c77d1
--- /dev/null
+++ b/src/simd/utf8_count.zig
@@ -0,0 +1,72 @@
+const std = @import("std");
+const isa = @import("isa.zig");
+const aarch64 = @import("aarch64.zig");
+
+/// Count the number of UTF-8 codepoints in the given string. The string
+/// is assumed to be valid UTF-8. Invalid UTF-8 will result in undefined
+/// (and probably incorrect) behaviour.
+pub const Count = fn ([]const u8) usize;
+
+/// Returns the count function for the given ISA.
+pub fn countFunc(v: isa.ISA) *const Count {
+    return switch (v) {
+        .avx2 => &Scalar.count, // TODO: no AVX2 implementation yet; uses the scalar fallback.
+        .neon => &Neon.count,
+        .scalar => &Scalar.count,
+    };
+}
+
+pub const Scalar = struct { // Fallback implementation built on the standard library.
+    pub fn count(input: []const u8) usize {
+        return std.unicode.utf8CountCodepoints(input) catch unreachable; // Any decode error is declared unreachable: input must be valid UTF-8 that starts on a codepoint boundary.
+    }
+};
+
+/// Arm NEON implementation of the count function.
+pub const Neon = struct {
+    /// Count the codepoints in `input`, which is assumed to be valid UTF-8.
+    /// Whole 16-byte chunks go through `process`; the tail is counted per byte.
+    pub fn count(input: []const u8) usize {
+        var result: usize = 0;
+        var i: usize = 0;
+        while (i + 16 <= input.len) : (i += 16) {
+            const input_vec = aarch64.vld1q_u8(input[i..]);
+            result += process(input_vec);
+        }
+
+        // A codepoint may straddle the last 16-byte boundary, so the tail can
+        // start with continuation bytes; count its leading bytes directly.
+        for (input[i..]) |byte| result += @intFromBool((byte & 0xC0) != 0x80);
+        return result;
+    }
+
+    /// Count the bytes in `v` that start a codepoint (0 to 16).
+    pub fn process(v: @Vector(16, u8)) u8 {
+        // Continuation bytes (0x80...0xBF) are -128...-65 as i8, so "> -65"
+        // yields 0xFF in every leading-byte lane and 0x00 everywhere else.
+        const mask = aarch64.vcgtq_s8(@bitCast(v), aarch64.vdupq_n_s8(-65));
+
+        // Shift to turn 0xFF into 0x01.
+        const mask_shift = aarch64.vshrq_n_u8(mask, 7);
+
+        // Sum across the vector; at most 16, so it always fits in a u8.
+        return aarch64.vaddvq_u8(mask_shift);
+    }
+};
+
+/// Generic test function so we can test against multiple implementations.
+/// This is initially copied from the Zig stdlib but may be expanded.
+fn testCount(func: *const Count) !void {
+    const testing = std.testing;
+    try testing.expectEqual(@as(usize, 16), func("hello friends!!!"));
+    try testing.expectEqual(@as(usize, 10), func("abcdefghij"));
+    try testing.expectEqual(@as(usize, 10), func("äåéëþüúíóö"));
+    try testing.expectEqual(@as(usize, 5), func("こんにちは"));
+    try testing.expectEqual(@as(usize, 16), func("0123456789abcde\u{e9}")); // U+00E9 is 2 bytes; its second byte falls past the first 16-byte chunk.
+}
+
+test "count" { // Runs the shared cases against the function chosen for every ISA that isa.detect() reports.
+    const v = isa.detect();
+    var it = v.iterator();
+    while (it.next()) |isa_v| try testCount(countFunc(isa_v));
+}
diff --git a/src/simd/utf8.zig b/src/simd/utf8_validate.zig
similarity index 93%
rename from src/simd/utf8.zig
rename to src/simd/utf8_validate.zig
index 0f324024b..af86b5230 100644
--- a/src/simd/utf8.zig
+++ b/src/simd/utf8_validate.zig
@@ -4,26 +4,26 @@ const assert = std.debug.assert;
 const isa = @import("isa.zig");
 const aarch64 = @import("aarch64.zig");
 
-const Validate = @TypeOf(utf8Validate);
+const Validate = fn ([]const u8) bool; // Signature of the functions handed out by `validateFunc`.
 
 // All of the work in this file is based heavily on the work of
 // Daniel Lemire and John Keiser. Their original work can be found here:
 // - https://arxiv.org/pdf/2010.03090.pdf
 // - https://simdutf.github.io/simdutf/ (MIT License)
 
-pub fn utf8Validate(input: []const u8) bool {
-    return utf8ValidateNeon(input);
+pub fn validateFunc(v: isa.ISA) *const Validate { // Returns the validate function for the given ISA.
+    return switch (v) {
+        .avx2 => &Scalar.validate, // TODO: no AVX2 implementation yet; uses the scalar fallback.
+        .neon => &Neon.validate,
+        .scalar => &Scalar.validate,
+    };
 }
 
-pub fn utf8ValidateNeon(input: []const u8) bool {
-    var neon = Neon.init();
-    neon.validate(input);
-    return !neon.hasErrors();
-}
-
-pub fn utf8ValidateScalar(input: []const u8) bool {
-    return std.unicode.utf8ValidateSlice(input);
-}
+pub const Scalar = struct { // Fallback implementation built on the standard library.
+    pub fn validate(input: []const u8) bool {
+        return std.unicode.utf8ValidateSlice(input);
+    }
+};
 
 pub const Neon = struct {
     /// The previous input in a vector. 
This is required because to check
@@ -46,11 +46,17 @@ pub const Neon = struct {
         };
     }
 
+    pub fn validate(input: []const u8) bool { // One-shot helper: fresh state, feed all of `input`, true when no error was recorded.
+        var neon = Neon.init();
+        neon.feed(input);
+        return !neon.hasErrors(); // NOTE(review): `feed`'s docs ask for `finalize` once data is complete — confirm `hasErrors` covers that.
+    }
+
     /// Validate a chunk of UTF-8 data. This function is designed to be
     /// called multiple times with successive chunks of data. When the
     /// data is complete, you must call `finalize` to check for any
     /// remaining errors.
-    pub fn validate(self: *Neon, input: []const u8) void {
+    pub fn feed(self: *Neon, input: []const u8) void {
         // Break up our input into 16 byte chunks, and process each chunk
         // separately. The size of a Neon register is 16 bytes.
         var i: usize = 0;
@@ -288,8 +294,8 @@ fn testValidate(func: *const Validate) !void {
     try testing.expect(!func("\xed\xbf\xbf"));
 }
 
-test "utf8Validate neon" {
-    if (comptime !isa.possible(.neon)) return error.SkipZigTest;
-    const set = isa.detect();
-    if (set.contains(.neon)) try testValidate(&utf8ValidateNeon);
+test "validate" { // Runs the shared checks against the function chosen for every ISA that isa.detect() reports.
+    const v = isa.detect();
+    var it = v.iterator();
+    while (it.next()) |isa_v| try testValidate(validateFunc(isa_v));
 }