simd: convert indexOf, mess around with simdvt

This commit is contained in:
Mitchell Hashimoto
2024-02-02 21:20:30 -08:00
parent a66174678b
commit c042b052b2
5 changed files with 146 additions and 43 deletions

View File

@ -9,34 +9,57 @@ const aarch64 = @import("aarch64.zig");
// time of writing this comment, reimplements it using manual assembly. This is
// so I can compare to Zig's @Vector lowering.
const IndexOf = @TypeOf(indexOf);
pub const IndexOf = fn ([]const u8, u8) ?usize;
/// Returns the first index of `needle` in `input` or `null` if `needle`
/// is not found.
pub fn indexOf(input: []const u8, needle: u8) ?usize {
return indexOfNeon(input, needle);
//return indexOfScalar(input, needle);
/// Returns the indexOf function for the given ISA.
///
/// Dispatch is delegated to `isa.funcMap`, which selects the first entry
/// whose ISA tag matches `v`. NOTE(review): the `.avx2` entry currently
/// falls back to the scalar implementation (see the "todo" below) — only
/// NEON has a dedicated SIMD path at this point.
pub fn indexOfFunc(v: isa.ISA) *const IndexOf {
    return isa.funcMap(IndexOf, v, .{
        .{ .avx2, Scalar.indexOf }, // todo: replace with a real AVX2 implementation
        .{ .neon, Neon.indexOf },
        .{ .scalar, Scalar.indexOf },
    });
}
/// indexOf implementation using ARM NEON instructions.
fn indexOfNeon(input: []const u8, needle: u8) ?usize {
// This function is going to be commented in a lot of detail. SIMD is
// complicated and nonintuitive, so I want to make sure I understand what's
// going on. More importantly, I want to make sure when I look back on this
// code in the future, I understand what's going on.
pub const Scalar = struct {
    /// Portable fallback implementation: scan the haystack one byte at a
    /// time and return the index of the first occurrence of `needle`, or
    /// `null` when it does not appear. Note that the standard library may
    /// still auto-vectorize this under the hood; this variant exists as
    /// the baseline every SIMD path is compared against.
    pub fn indexOf(input: []const u8, needle: u8) ?usize {
        for (input, 0..) |byte, position| {
            if (byte == needle) return position;
        }
        return null;
    }
};
// Load our needle into a vector register. This duplicates the needle 16
// times, once for each byte in the 128-bit vector register.
const needle_vec = aarch64.vdupq_n_u8(needle);
pub const Neon = struct {
/// indexOf implementation using ARM NEON instructions.
pub fn indexOf(input: []const u8, needle: u8) ?usize {
// This function is going to be commented in a lot of detail. SIMD is
// complicated and nonintuitive, so I want to make sure I understand what's
// going on. More importantly, I want to make sure when I look back on this
// code in the future, I understand what's going on.
// note(mitchellh): benchmark to see if we should align to 16 bytes here
// Load our needle into a vector register. This duplicates the needle 16
// times, once for each byte in the 128-bit vector register.
const needle_vec = aarch64.vdupq_n_u8(needle);
// Iterate 16 bytes at a time, which is the max size of a vector register.
var i: usize = 0;
while (i + 16 <= input.len) : (i += 16) {
// Load the next 16 bytes into a vector register.
const input_vec = aarch64.vld1q_u8(input[i..]);
// note(mitchellh): benchmark to see if we should align to 16 bytes here
// Iterate 16 bytes at a time, which is the max size of a vector register.
var i: usize = 0;
while (i + 16 <= input.len) : (i += 16) {
const input_vec = aarch64.vld1q_u8(input[i..]);
if (indexOfVec(input_vec, needle_vec)) |index| {
return i + index;
}
}
// Handle the remaining bytes
if (i < input.len) {
while (i < input.len) : (i += 1) {
if (input[i] == needle) return i;
}
}
return null;
}
pub fn indexOfVec(input_vec: @Vector(16, u8), needle_vec: @Vector(16, u8)) ?usize {
// Compare the input vector to the needle vector. This will set
// all bits to "1" in the output vector for each matching byte.
const match_vec = aarch64.vceqq_u8(input_vec, needle_vec);
@ -48,7 +71,7 @@ fn indexOfNeon(input: []const u8, needle: u8) ?usize {
const shift_u64 = aarch64.vget_lane_u64(@bitCast(shift_vec));
if (shift_u64 == 0) {
// This means no matches were found.
continue;
return null;
}
// A match was found! Reverse the bits and divide by 4 to get the
@ -57,24 +80,9 @@ fn indexOfNeon(input: []const u8, needle: u8) ?usize {
// is due to all data being repeated 4 times by vceqq.
const reversed = aarch64.rbit(u64, shift_u64);
const index = aarch64.clz(u64, reversed) >> 2;
return i + index;
return index;
}
// Handle the remaining bytes
if (i < input.len) {
while (i < input.len) : (i += 1) {
if (input[i] == needle) return i;
}
}
return null;
}
fn indexOfScalar(input: []const u8, needle: u8) ?usize {
// Note this actually uses vector operations if supported. See
// our comment at the top of the file.
return std.mem.indexOfScalar(u8, input, needle);
}
};
/// Generic test function so we can test against multiple implementations.
fn testIndexOf(func: *const IndexOf) !void {
@ -91,8 +99,8 @@ fn testIndexOf(func: *const IndexOf) !void {
, ' ').?);
}
test "indexOf neon" {
if (comptime !isa.possible(.neon)) return error.SkipZigTest;
const set = isa.detect();
if (set.contains(.neon)) try testIndexOf(&indexOfNeon);
test "indexOf" {
const v = isa.detect();
var it = v.iterator();
while (it.next()) |isa_v| try testIndexOf(indexOfFunc(isa_v));
}

View File

@ -1,6 +1,9 @@
const std = @import("std");
pub const isa = @import("isa.zig");
pub const aarch64 = @import("aarch64.zig");
pub const utf8_count = @import("utf8_count.zig");
pub const utf8_decode = @import("utf8_decode.zig");
pub const utf8_validate = @import("utf8_validate.zig");

View File

@ -43,6 +43,11 @@ pub const EraseLine = csi.EraseLine;
pub const TabClear = csi.TabClear;
pub const Attribute = sgr.Attribute;
// TODO: we only have a hardcoded Neon implementation for now
pub usingnamespace if (builtin.target.cpu.arch == .aarch64) struct {
pub const simdvt = @import("simdvt.zig");
} else struct {};
/// If we're targeting wasm then we export some wasm APIs.
pub usingnamespace if (builtin.target.isWasm()) struct {
pub usingnamespace @import("wasm.zig");

5
src/terminal/simdvt.zig Normal file
View File

@ -0,0 +1,5 @@
//! Namespace for the SIMD-based terminal stream work; currently only
//! re-exports the in-progress Parser.
pub const Parser = @import("simdvt/Parser.zig");

test {
    // Reference all public decls so their nested `test` blocks are run.
    @import("std").testing.refAllDecls(@This());
}

View File

@ -0,0 +1,82 @@
const std = @import("std");
const Allocator = std.mem.Allocator;
const ArenaAllocator = std.heap.ArenaAllocator;
const terminal = @import("../main.zig");
const ScalarStream = terminal.Stream;
const simd = @import("../../simd/main.zig");
const aarch64 = simd.aarch64;
/// Builds a SIMD-assisted stream type that feeds decoded codepoints to
/// `Handler.print`. Only the plain-text fast path exists so far: any ESC
/// byte and any input whose length is not a multiple of 16 still panic.
pub fn Stream(comptime Handler: type) type {
    return struct {
        const Self = @This();

        handler: Handler,

        pub fn init(h: Handler) Self {
            return .{ .handler = h };
        }

        /// Consume `input`, printing each decoded codepoint to the handler.
        ///
        /// Panics on any ESC (0x1B) byte ("TODO") and on input whose length
        /// is not a multiple of 16 bytes — both are placeholders for future
        /// work, not permanent behavior.
        pub fn feed(self: *Self, input: []const u8) void {
            // TODO: I want to do the UTF-8 decoding as we stream the input,
            // but I don't want to deal with UTF-8 decode in SIMD right now.
            // So for now we just go back over the input and decode using
            // a scalar loop. Ugh.

            // ESC (0x1B) starts every terminal escape sequence, so splat it
            // across a vector register once up front; every chunk is
            // compared against this.
            const esc_splat = aarch64.vdupq_n_u8(0x1B);

            // Walk the input one full vector register (16 bytes) at a time.
            var offset: usize = 0;
            while (offset + 16 <= input.len) : (offset += 16) {
                const chunk = aarch64.vld1q_u8(input[offset..]);

                // Any ESC in this chunk means we'd have to switch states;
                // that path isn't implemented yet.
                if (simd.index_of.Neon.indexOfVec(chunk, esc_splat) != null) {
                    @panic("TODO");
                }

                // Plain text: decode the chunk as UTF-8 with a scalar loop.
                // TODO(mitchellh): I don't have a UTF-8 decoder in SIMD yet,
                // so this is slow.
                const view = std.unicode.Utf8View.initUnchecked(input[offset .. offset + 16]);
                var iter = view.iterator();
                while (iter.nextCodepoint()) |codepoint| {
                    self.handler.print(codepoint);
                }
            }

            // A partial trailing chunk is not handled yet.
            if (offset < input.len) {
                @panic("input must be a multiple of 16 bytes for now");
            }
        }
    };
}
test "ascii" {
    const testing = std.testing;

    // Arena-backed allocator: everything the collector appends is freed in
    // one shot when the test ends.
    var arena = ArenaAllocator.init(testing.allocator);
    defer arena.deinit();
    const allocator = arena.allocator();

    // Minimal handler that records every printed codepoint.
    const Collector = struct {
        const Self = @This();

        allocator: Allocator,
        buf: std.ArrayListUnmanaged(u21) = .{},

        pub fn print(self: *Self, c: u21) void {
            self.buf.append(self.allocator, c) catch unreachable;
        }
    };

    // 80 ASCII bytes — a multiple of 16, as feed() currently requires.
    const input = "hello" ** 16;
    var stream = Stream(Collector).init(.{ .allocator = allocator });
    stream.feed(input);

    // Every ASCII byte decodes to exactly one codepoint.
    try testing.expectEqual(input.len, stream.handler.buf.items.len);
}