simd: indexOf implementation using NEON

This commit is contained in:
Mitchell Hashimoto
2024-01-28 21:38:09 -08:00
parent 31d5785105
commit 7feba12eab
4 changed files with 174 additions and 1 deletions

View File

@ -307,6 +307,7 @@ test {
_ = @import("inspector/main.zig"); _ = @import("inspector/main.zig");
_ = @import("terminal/main.zig"); _ = @import("terminal/main.zig");
_ = @import("terminfo/main.zig"); _ = @import("terminfo/main.zig");
_ = @import("simd/main.zig");
// TODO // TODO
_ = @import("blocking_queue.zig"); _ = @import("blocking_queue.zig");

66
src/simd/aarch64.zig Normal file
View File

@ -0,0 +1,66 @@
// https://developer.arm.com/architectures/instruction-sets/intrinsics
// https://llvm.org/docs/LangRef.html#inline-assembler-expressions
const std = @import("std");
const assert = std.debug.assert;
/// Broadcasts the byte `v` into all 16 lanes of a 128-bit vector using the
/// AArch64 `dup` instruction.
pub inline fn vdupq_n_u8(v: u8) @Vector(16, u8) {
    const Bytes16 = @Vector(16, u8);
    return asm (
        \\ dup %[out].16b, %[byte:w]
        : [out] "=w" (-> Bytes16),
        : [byte] "r" (v),
    );
}
/// Loads the first 16 bytes of `v` into a 128-bit vector using the AArch64
/// `ld1` instruction.
///
/// `v` must contain at least 16 bytes; only `v.ptr` is handed to the
/// instruction, so a shorter slice would be read out of bounds.
pub inline fn vld1q_u8(v: []const u8) @Vector(16, u8) {
    // The instruction always reads a full 16 bytes regardless of `v.len`.
    assert(v.len >= 16);

    // NOTE(review): the asm declares neither a memory operand nor a memory
    // clobber for the load through the pointer — confirm the compiler cannot
    // reorder or drop earlier stores to the buffer across this statement.
    return asm (
        \\ ld1 { %[ret].16b }, [%[value]]
        : [ret] "=w" (-> @Vector(16, u8)),
        : [value] "r" (v.ptr),
    );
}
/// Compares `a` and `b` lane by lane using the AArch64 `cmeq` instruction.
/// Lanes that are equal come back with every bit set; all other lanes are
/// zero.
pub inline fn vceqq_u8(a: @Vector(16, u8), b: @Vector(16, u8)) @Vector(16, u8) {
    const Bytes16 = @Vector(16, u8);
    return asm (
        \\ cmeq %[out].16b, %[lhs].16b, %[rhs].16b
        : [out] "=w" (-> Bytes16),
        : [lhs] "w" (a),
          [rhs] "w" (b),
    );
}
/// Shifts each 16-bit lane of `a` right by `n` bits and narrows the result
/// to 8 bits per lane using the AArch64 `shrn` instruction.
///
/// `n` is passed with an immediate ("I") constraint, so it has to be known
/// at compile time at the call site. The valid range for this narrowing
/// shift is 1 through 8 inclusive.
pub inline fn vshrn_n_u16(a: @Vector(8, u16), n: u4) @Vector(8, u8) {
    // A shift of 0 cannot be encoded by `shrn`, so reject it together with
    // anything above the 8-bit destination lane width.
    assert(n >= 1 and n <= 8);
    return asm (
        \\ shrn %[ret].8b, %[a].8h, %[n]
        : [ret] "=w" (-> @Vector(8, u8)),
        : [a] "w" (a),
          [n] "I" (n),
    );
}
/// Moves the single 64-bit lane of `v` into a general-purpose register using
/// the AArch64 `umov` instruction and returns it.
pub inline fn vget_lane_u64(v: @Vector(1, u64)) u64 {
    const lane0: u64 = asm (
        \\ umov %[out], %[vec].d[0]
        : [out] "=r" (-> u64),
        : [vec] "w" (v),
    );
    return lane0;
}
/// Reverses the bit order of `v` using the AArch64 `rbit` instruction.
///
/// `T` must be `u32` or `u64`; any other type is rejected at compile time.
pub inline fn rbit(comptime T: type, v: T) T {
    return switch (T) {
        // The `w` modifier selects the 32-bit register name. Without it the
        // operand is printed as a 64-bit `x` register and the reversal would
        // be performed over 64 bits instead of 32.
        u32 => asm (
            \\ rbit %[ret:w], %[v:w]
            : [ret] "=r" (-> u32),
            : [v] "r" (v),
        ),
        u64 => asm (
            \\ rbit %[ret], %[v]
            : [ret] "=r" (-> u64),
            : [v] "r" (v),
        ),
        else => @compileError("rbit only supports u32 and u64, got " ++ @typeName(T)),
    };
}
/// Counts the leading zero bits of `v` using the AArch64 `clz` instruction.
///
/// `T` must be `u32` or `u64`; any other type is rejected at compile time.
pub inline fn clz(comptime T: type, v: T) T {
    return switch (T) {
        // The `w` modifier selects the 32-bit register name. Without it the
        // operand is printed as a 64-bit `x` register and the count would be
        // taken over 64 bits instead of 32.
        u32 => asm (
            \\ clz %[ret:w], %[v:w]
            : [ret] "=r" (-> u32),
            : [v] "r" (v),
        ),
        u64 => asm (
            \\ clz %[ret], %[v]
            : [ret] "=r" (-> u64),
            : [v] "r" (v),
        ),
        else => @compileError("clz only supports u32 and u64, got " ++ @typeName(T)),
    };
}

97
src/simd/index_of.zig Normal file
View File

@ -0,0 +1,97 @@
const std = @import("std");
const builtin = @import("builtin");
const aarch64 = @import("aarch64.zig");
// Note this is a reimplementation of std.mem.indexOfScalar. The Zig stdlib
// version is already SIMD-optimized but not using runtime ISA detection. This
// expands the stdlib version to use runtime ISA detection. This also, at the
// time of writing this comment, reimplements it using manual assembly. This is
// so I can compare to Zig's @Vector lowering.
const IndexOf = @TypeOf(indexOf);
/// Returns the first index of `needle` in `input` or `null` if `needle`
/// is not found.
///
/// The implementation is selected at compile time from the target
/// architecture: the NEON implementation on aarch64 and the stdlib-backed
/// implementation everywhere else.
pub fn indexOf(input: []const u8, needle: u8) ?usize {
    // The NEON path is built on aarch64 inline assembly, so it can only be
    // referenced when compiling for that architecture.
    return if (comptime builtin.cpu.arch == .aarch64)
        indexOfNeon(input, needle)
    else
        indexOfScalar(input, needle);
}
/// indexOf implementation using ARM NEON instructions.
///
/// The input is scanned in 16-byte chunks; whatever is left over after the
/// last full chunk is checked one byte at a time.
fn indexOfNeon(input: []const u8, needle: u8) ?usize {
    // Number of bytes processed per vector iteration (one 128-bit register).
    const chunk_len = 16;

    // Splat the needle across all 16 byte lanes once, up front, so every
    // chunk can be compared against it with a single instruction.
    const needle_x16 = aarch64.vdupq_n_u8(needle);

    // note(mitchellh): benchmark to see if we should align to 16 bytes here

    var offset: usize = 0;
    while (offset + chunk_len <= input.len) : (offset += chunk_len) {
        // Pull in the next 16 bytes and compare them to the needle. Each
        // matching byte lane becomes all ones, every other lane zero.
        const chunk = aarch64.vld1q_u8(input[offset..]);
        const eq = aarch64.vceqq_u8(chunk, needle_x16);

        // Collapse the 128-bit comparison result into a 64-bit mask by
        // viewing it as eight 16-bit lanes, shifting each right by 4 and
        // narrowing to 8 bits. Every input byte ends up represented by 4
        // bits of the mask. Background on the trick:
        // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
        const narrowed = aarch64.vshrn_n_u16(@bitCast(eq), 4);
        const mask = aarch64.vget_lane_u64(@bitCast(narrowed));

        if (mask != 0) {
            // Reversing the bits and counting leading zeros yields the
            // position of the lowest set bit. Dividing that by 4 (one
            // nibble per input byte) gives the byte index within the chunk.
            const low_bit = aarch64.clz(u64, aarch64.rbit(u64, mask));
            return offset + (low_bit >> 2);
        }
    }

    // Fewer than 16 bytes remain; check them individually.
    for (input[offset..], offset..) |byte, idx| {
        if (byte == needle) return idx;
    }

    return null;
}
/// indexOf implementation that defers to the Zig standard library.
fn indexOfScalar(input: []const u8, needle: u8) ?usize {
    // Despite the name, the stdlib routine may itself use vector operations
    // when the target supports them; see the comment at the top of the file.
    const start: usize = 0;
    return std.mem.indexOfScalarPos(u8, input, start, needle);
}
/// Shared assertions for any function with the same signature as `indexOf`,
/// so each implementation can be run through the same cases.
fn testIndexOf(func: *const IndexOf) !void {
    const testing = std.testing;
    const space: u8 = ' ';

    // Inputs longer than 16 bytes: one with the first space early on, one
    // with the first space only after the embedded newline.
    const early_match =
        \\XXXXX XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
        \\XXXXXXXXXXXX XXXXXXXXXXX XXXXXXXXXXXXXXX
    ;
    const late_match =
        \\XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
        \\XXXXXXXXXXXX XXXXXXXXXXX XXXXXXXXXXXXXXX
    ;

    // Short inputs: no match, then a match.
    try testing.expect(func("hello", space) == null);
    try testing.expectEqual(@as(usize, 2), func("hi lo", space).?);

    // Long inputs.
    try testing.expectEqual(@as(usize, 5), func(early_match, space).?);
    try testing.expectEqual(@as(usize, 53), func(late_match, space).?);
}
test "indexOf neon" {
    // TODO: use ISA detection here
    // The NEON implementation only exists for aarch64 targets; everywhere
    // else the test is reported as skipped.
    switch (comptime builtin.cpu.arch) {
        .aarch64 => try testIndexOf(&indexOfNeon),
        else => return error.SkipZigTest,
    }
}

View File

@ -1,8 +1,17 @@
const std = @import("std"); const std = @import("std");
const isa = @import("isa.zig"); const isa = @import("isa.zig");
const index_of = @import("index_of.zig");
pub usingnamespace isa; pub usingnamespace isa;
pub usingnamespace index_of;
pub fn main() !void { pub fn main() !void {
std.log.warn("ISA={}", .{isa.ISA.detect()}); //std.log.warn("ISA={}", .{isa.ISA.detect()});
const input = "1234567\x1b1234567\x1b";
//const input = "1234567812345678";
_ = index_of.indexOf(input, 0x1B);
}
test {
@import("std").testing.refAllDecls(@This());
} }