From c042b052b2a0a12e68a69c1f9018e853e667d054 Mon Sep 17 00:00:00 2001
From: Mitchell Hashimoto
Date: Fri, 2 Feb 2024 21:20:30 -0800
Subject: [PATCH] simd: convert indexOf, mess around with simdvt

---
Reviewer note: src/terminal/simdvt.zig imports "simdvt/parser.zig" (lowercase)
so the import path matches the file created by this patch on case-sensitive
filesystems.

 src/simd/index_of.zig          | 94 ++++++++++++++++++----------------
 src/simd/main.zig              |  3 ++
 src/terminal/main.zig          |  5 ++
 src/terminal/simdvt.zig        |  5 ++
 src/terminal/simdvt/parser.zig | 82 +++++++++++++++++++++++++++++
 5 files changed, 146 insertions(+), 43 deletions(-)
 create mode 100644 src/terminal/simdvt.zig
 create mode 100644 src/terminal/simdvt/parser.zig

diff --git a/src/simd/index_of.zig b/src/simd/index_of.zig
index b66349847..4b4affe56 100644
--- a/src/simd/index_of.zig
+++ b/src/simd/index_of.zig
@@ -9,34 +9,57 @@ const aarch64 = @import("aarch64.zig");
 // time of writing this comment, reimplements it using manual assembly. This is
 // so I can compare to Zig's @Vector lowering.
 
-const IndexOf = @TypeOf(indexOf);
+pub const IndexOf = fn ([]const u8, u8) ?usize;
 
-/// Returns the first index of `needle` in `input` or `null` if `needle`
-/// is not found.
-pub fn indexOf(input: []const u8, needle: u8) ?usize {
-    return indexOfNeon(input, needle);
-    //return indexOfScalar(input, needle);
+/// Returns the indexOf function for the given ISA.
+pub fn indexOfFunc(v: isa.ISA) *const IndexOf {
+    return isa.funcMap(IndexOf, v, .{
+        .{ .avx2, Scalar.indexOf }, // todo
+        .{ .neon, Neon.indexOf },
+        .{ .scalar, Scalar.indexOf },
+    });
 }
 
-/// indexOf implementation using ARM NEON instructions.
-fn indexOfNeon(input: []const u8, needle: u8) ?usize {
-    // This function is going to be commented in a lot of detail. SIMD is
-    // complicated and nonintuitive, so I want to make sure I understand what's
-    // going on. More importantly, I want to make sure when I look back on this
-    // code in the future, I understand what's going on.
+pub const Scalar = struct {
+    pub fn indexOf(input: []const u8, needle: u8) ?usize {
+        return std.mem.indexOfScalar(u8, input, needle);
+    }
+};
 
-    // Load our needle into a vector register. This duplicates the needle 16
-    // times, once for each byte in the 128-bit vector register.
-    const needle_vec = aarch64.vdupq_n_u8(needle);
+pub const Neon = struct {
+    /// indexOf implementation using ARM NEON instructions.
+    pub fn indexOf(input: []const u8, needle: u8) ?usize {
+        // This function is going to be commented in a lot of detail. SIMD is
+        // complicated and nonintuitive, so I want to make sure I understand what's
+        // going on. More importantly, I want to make sure when I look back on this
+        // code in the future, I understand what's going on.
 
-    // note(mitchellh): benchmark to see if we should align to 16 bytes here
+        // Load our needle into a vector register. This duplicates the needle 16
+        // times, once for each byte in the 128-bit vector register.
+        const needle_vec = aarch64.vdupq_n_u8(needle);
 
-    // Iterate 16 bytes at a time, which is the max size of a vector register.
-    var i: usize = 0;
-    while (i + 16 <= input.len) : (i += 16) {
-        // Load the next 16 bytes into a vector register.
-        const input_vec = aarch64.vld1q_u8(input[i..]);
+        // note(mitchellh): benchmark to see if we should align to 16 bytes here
 
+        // Iterate 16 bytes at a time, which is the max size of a vector register.
+        var i: usize = 0;
+        while (i + 16 <= input.len) : (i += 16) {
+            const input_vec = aarch64.vld1q_u8(input[i..]);
+            if (indexOfVec(input_vec, needle_vec)) |index| {
+                return i + index;
+            }
+        }
+
+        // Handle the remaining bytes
+        if (i < input.len) {
+            while (i < input.len) : (i += 1) {
+                if (input[i] == needle) return i;
+            }
+        }
+
+        return null;
+    }
+
+    pub fn indexOfVec(input_vec: @Vector(16, u8), needle_vec: @Vector(16, u8)) ?usize {
         // Compare the input vector to the needle vector. This will set
         // all bits to "1" in the output vector for each matching byte.
         const match_vec = aarch64.vceqq_u8(input_vec, needle_vec);
@@ -48,7 +71,7 @@ fn indexOfNeon(input: []const u8, needle: u8) ?usize {
         const shift_u64 = aarch64.vget_lane_u64(@bitCast(shift_vec));
         if (shift_u64 == 0) {
             // This means no matches were found.
-            continue;
+            return null;
         }
 
         // A match was found! Reverse the bits and divide by 4 to get the
@@ -57,24 +80,9 @@ fn indexOfNeon(input: []const u8, needle: u8) ?usize {
         // is due to all data being repeated 4 times by vceqq.
         const reversed = aarch64.rbit(u64, shift_u64);
         const index = aarch64.clz(u64, reversed) >> 2;
-        return i + index;
+        return index;
     }
-
-    // Handle the remaining bytes
-    if (i < input.len) {
-        while (i < input.len) : (i += 1) {
-            if (input[i] == needle) return i;
-        }
-    }
-
-    return null;
-}
-
-fn indexOfScalar(input: []const u8, needle: u8) ?usize {
-    // Note this actually uses vector operations if supported. See
-    // our comment at the top of the file.
-    return std.mem.indexOfScalar(u8, input, needle);
-}
+};
 
 /// Generic test function so we can test against multiple implementations.
 fn testIndexOf(func: *const IndexOf) !void {
@@ -91,8 +99,8 @@ fn testIndexOf(func: *const IndexOf) !void {
     , ' ').?);
 }
 
-test "indexOf neon" {
-    if (comptime !isa.possible(.neon)) return error.SkipZigTest;
-    const set = isa.detect();
-    if (set.contains(.neon)) try testIndexOf(&indexOfNeon);
+test "indexOf" {
+    const v = isa.detect();
+    var it = v.iterator();
+    while (it.next()) |isa_v| try testIndexOf(indexOfFunc(isa_v));
 }
diff --git a/src/simd/main.zig b/src/simd/main.zig
index 9bc99fcf1..4523c1b40 100644
--- a/src/simd/main.zig
+++ b/src/simd/main.zig
@@ -1,6 +1,9 @@
 const std = @import("std");
 
 pub const isa = @import("isa.zig");
+
+pub const aarch64 = @import("aarch64.zig");
+
 pub const utf8_count = @import("utf8_count.zig");
 pub const utf8_decode = @import("utf8_decode.zig");
 pub const utf8_validate = @import("utf8_validate.zig");
diff --git a/src/terminal/main.zig b/src/terminal/main.zig
index a4224e63a..58854fa8c 100644
--- a/src/terminal/main.zig
+++ b/src/terminal/main.zig
@@ -43,6 +43,11 @@ pub const EraseLine = csi.EraseLine;
 pub const TabClear = csi.TabClear;
 pub const Attribute = sgr.Attribute;
 
+// TODO: we only have a hardcoded Neon implementation for now
+pub usingnamespace if (builtin.target.cpu.arch == .aarch64) struct {
+    pub const simdvt = @import("simdvt.zig");
+} else struct {};
+
 /// If we're targeting wasm then we export some wasm APIs.
 pub usingnamespace if (builtin.target.isWasm()) struct {
     pub usingnamespace @import("wasm.zig");
diff --git a/src/terminal/simdvt.zig b/src/terminal/simdvt.zig
new file mode 100644
index 000000000..d123ecc59
--- /dev/null
+++ b/src/terminal/simdvt.zig
@@ -0,0 +1,5 @@
+pub const Parser = @import("simdvt/parser.zig");
+
+test {
+    @import("std").testing.refAllDecls(@This());
+}
diff --git a/src/terminal/simdvt/parser.zig b/src/terminal/simdvt/parser.zig
new file mode 100644
index 000000000..5152a8394
--- /dev/null
+++ b/src/terminal/simdvt/parser.zig
@@ -0,0 +1,82 @@
+const std = @import("std");
+const Allocator = std.mem.Allocator;
+const ArenaAllocator = std.heap.ArenaAllocator;
+
+const terminal = @import("../main.zig");
+const ScalarStream = terminal.Stream;
+const simd = @import("../../simd/main.zig");
+const aarch64 = simd.aarch64;
+
+pub fn Stream(comptime Handler: type) type {
+    return struct {
+        const Self = @This();
+
+        handler: Handler,
+
+        pub fn init(h: Handler) Self {
+            return .{ .handler = h };
+        }
+
+        pub fn feed(self: *Self, input: []const u8) void {
+            // TODO: I want to do the UTF-8 decoding as we stream the input,
+            // but I don't want to deal with UTF-8 decode in SIMD right now.
+            // So for now we just go back over the input and decode using
+            // a scalar loop. Ugh.
+
+            // We search for ESC (0x1B) very frequently, since this is what triggers
+            // the start of a terminal escape sequence of any kind, so put this into
+            // a register immediately.
+            const esc_vec = aarch64.vdupq_n_u8(0x1B);
+
+            // Iterate 16 bytes at a time, which is the max size of a vector register.
+            var i: usize = 0;
+            while (i + 16 <= input.len) : (i += 16) {
+                // Load the next 16 bytes into a vector register.
+                const input_vec = aarch64.vld1q_u8(input[i..]);
+
+                // Check for ESC to determine if we should go to the next state.
+                if (simd.index_of.Neon.indexOfVec(input_vec, esc_vec)) |index| {
+                    _ = index;
+                    @panic("TODO");
+                }
+
+                // No ESC found, decode UTF-8.
+                // TODO(mitchellh): I don't have a UTF-8 decoder in SIMD yet, so
+                // for now we just use a scalar loop. This is slow.
+                const view = std.unicode.Utf8View.initUnchecked(input[i .. i + 16]);
+                var it = view.iterator();
+                while (it.nextCodepoint()) |cp| {
+                    self.handler.print(cp);
+                }
+            }
+
+            // Handle the remaining bytes
+            if (i < input.len) {
+                @panic("input must be a multiple of 16 bytes for now");
+            }
+        }
+    };
+}
+
+test "ascii" {
+    const testing = std.testing;
+    var arena = ArenaAllocator.init(testing.allocator);
+    defer arena.deinit();
+    const alloc = arena.allocator();
+
+    const H = struct {
+        const Self = @This();
+        alloc: Allocator,
+        buf: std.ArrayListUnmanaged(u21) = .{},
+
+        pub fn print(self: *Self, c: u21) void {
+            self.buf.append(self.alloc, c) catch unreachable;
+        }
+    };
+
+    const str = "hello" ** 16;
+    var s = Stream(H).init(.{ .alloc = alloc });
+    s.feed(str);
+
+    try testing.expectEqual(str.len, s.handler.buf.items.len);
+}