simd: utf8 count

This commit is contained in:
Mitchell Hashimoto
2024-01-30 12:38:14 -08:00
parent 5b295cf6e2
commit 6523721846
4 changed files with 169 additions and 18 deletions

View File

@ -4,6 +4,36 @@
const std = @import("std");
const assert = std.debug.assert;
pub inline fn vaddlvq_u8(v: @Vector(16, u8)) u16 {
const result = asm (
\\ uaddlv %[ret:h], %[v].16b
: [ret] "=w" (-> @Vector(8, u16)),
: [v] "w" (v),
);
return result[0];
}
pub inline fn vaddvq_u8(v: @Vector(16, u8)) u8 {
const result = asm (
\\ addv %[ret:b], %[v].16b
: [ret] "=w" (-> @Vector(16, u8)),
: [v] "w" (v),
);
return result[0];
}
pub inline fn vaddv_u8(v: @Vector(8, u8)) u8 {
const result = asm (
\\ addv %[ret:b], %[v].8b
: [ret] "=w" (-> @Vector(8, u8)),
: [v] "w" (v),
);
return result[0];
}
pub inline fn vandq_u8(a: @Vector(16, u8), b: @Vector(16, u8)) @Vector(16, u8) {
return asm (
\\ and %[ret].16b, %[a].16b, %[b].16b
@ -31,6 +61,15 @@ pub inline fn vcgeq_u8(a: @Vector(16, u8), b: @Vector(16, u8)) @Vector(16, u8) {
);
}
pub inline fn vcgtq_s8(a: @Vector(16, i8), b: @Vector(16, i8)) @Vector(16, u8) {
return asm (
\\ cmgt %[ret].16b, %[a].16b, %[b].16b
: [ret] "=w" (-> @Vector(16, u8)),
: [a] "w" (a),
[b] "w" (b),
);
}
pub inline fn vcgtq_u8(a: @Vector(16, u8), b: @Vector(16, u8)) @Vector(16, u8) {
return asm (
\\ cmhi %[ret].16b, %[a].16b, %[b].16b
@ -40,6 +79,30 @@ pub inline fn vcgtq_u8(a: @Vector(16, u8), b: @Vector(16, u8)) @Vector(16, u8) {
);
}
pub inline fn vcnt_u8(v: @Vector(8, u8)) @Vector(8, u8) {
return asm (
\\ cnt %[ret].8b, %[v].8b
: [ret] "=w" (-> @Vector(8, u8)),
: [v] "w" (v),
);
}
pub inline fn vcreate_u8(v: u64) @Vector(8, u8) {
return asm (
\\ ins %[ret].D[0], %[value]
: [ret] "=w" (-> @Vector(8, u8)),
: [value] "r" (v),
);
}
pub inline fn vdupq_n_s8(v: i8) @Vector(16, i8) {
return asm (
\\ dup %[ret].16b, %[value:w]
: [ret] "=w" (-> @Vector(16, i8)),
: [value] "r" (v),
);
}
pub inline fn vdupq_n_u8(v: u8) @Vector(16, u8) {
return asm (
\\ dup %[ret].16b, %[value:w]
@ -76,6 +139,15 @@ pub inline fn vget_lane_u64(v: @Vector(1, u64)) u64 {
);
}
pub inline fn vgetq_lane_u64(v: @Vector(2, u64), n: u1) u64 {
return asm (
\\ umov %[ret], %[v].d[%[n]]
: [ret] "=r" (-> u64),
: [v] "w" (v),
[n] "I" (n),
);
}
pub inline fn vld1q_u8(v: []const u8) @Vector(16, u8) {
return asm (
\\ ld1 { %[ret].16b }, [%[value]]

View File

@ -1,7 +1,8 @@
const std = @import("std");
pub const isa = @import("isa.zig");
pub const utf8 = @import("utf8.zig");
pub const utf8_count = @import("utf8_count.zig");
pub const utf8_validate = @import("utf8_validate.zig");
pub const index_of = @import("index_of.zig");
// TODO: temporary, only for zig build simd to inspect disasm easily

72
src/simd/utf8_count.zig Normal file
View File

@ -0,0 +1,72 @@
const std = @import("std");
const isa = @import("isa.zig");
const aarch64 = @import("aarch64.zig");
/// Count the number of UTF-8 codepoints in the given string. The string
/// is assumed to be valid UTF-8. Invalid UTF-8 will result in undefined
/// (and probably incorrect) behaviour.
pub const Count = fn ([]const u8) usize;
/// Returns the count function for the given ISA.
pub fn countFunc(v: isa.ISA) *const Count {
return switch (v) {
.avx2 => &Scalar.count, // todo
.neon => &Neon.count,
.scalar => &Scalar.count,
};
}
pub const Scalar = struct {
pub fn count(input: []const u8) usize {
return std.unicode.utf8CountCodepoints(input) catch unreachable;
}
};
/// Arm NEON implementation of the count function.
pub const Neon = struct {
pub fn count(input: []const u8) usize {
var result: usize = 0;
var i: usize = 0;
while (i + 16 <= input.len) : (i += 16) {
const input_vec = aarch64.vld1q_u8(input[i..]);
result += @intCast(process(input_vec));
}
if (i < input.len) result += Scalar.count(input[i..]);
return result;
}
pub fn process(v: @Vector(16, u8)) u8 {
// Find all the bits greater than -65 in binary (0b10000001) which
// are a leading byte of a UTF-8 codepoint. This will set the resulting
// vector to 0xFF for all leading bytes and 0x00 for all non-leading.
const mask = aarch64.vcgtq_s8(@bitCast(v), aarch64.vdupq_n_s8(-65));
// Shift to turn 0xFF to 0x01.
const mask_shift = aarch64.vshrq_n_u8(mask, 7);
// Sum across the vector
const sum = aarch64.vaddvq_u8(mask_shift);
// std.log.warn("mask={}", .{mask});
// std.log.warn("mask_shift={}", .{mask_shift});
// std.log.warn("sum={}", .{sum});
return sum;
}
};
/// Generic test function so we can test against multiple implementations.
/// This is initially copied from the Zig stdlib but may be expanded.
fn testCount(func: *const Count) !void {
const testing = std.testing;
try testing.expectEqual(@as(usize, 16), func("hello friends!!!"));
try testing.expectEqual(@as(usize, 10), func("abcdefghij"));
try testing.expectEqual(@as(usize, 10), func("äåéëþüúíóö"));
try testing.expectEqual(@as(usize, 5), func("こんにちは"));
}
test "count" {
const v = isa.detect();
var it = v.iterator();
while (it.next()) |isa_v| try testCount(countFunc(isa_v));
}

View File

@ -4,26 +4,26 @@ const assert = std.debug.assert;
const isa = @import("isa.zig");
const aarch64 = @import("aarch64.zig");
const Validate = @TypeOf(utf8Validate);
const Validate = fn ([]const u8) bool;
// All of the work in this file is based heavily on the work of
// Daniel Lemire and John Keiser. Their original work can be found here:
// - https://arxiv.org/pdf/2010.03090.pdf
// - https://simdutf.github.io/simdutf/ (MIT License)
pub fn utf8Validate(input: []const u8) bool {
return utf8ValidateNeon(input);
pub fn validateFunc(v: isa.ISA) *const Validate {
return switch (v) {
.avx2 => &Scalar.validate, // todo
.neon => &Neon.validate,
.scalar => &Scalar.validate,
};
}
pub fn utf8ValidateNeon(input: []const u8) bool {
var neon = Neon.init();
neon.validate(input);
return !neon.hasErrors();
}
pub fn utf8ValidateScalar(input: []const u8) bool {
return std.unicode.utf8ValidateSlice(input);
}
pub const Scalar = struct {
pub fn validate(input: []const u8) bool {
return std.unicode.utf8ValidateSlice(input);
}
};
pub const Neon = struct {
/// The previous input in a vector. This is required because to check
@ -46,11 +46,17 @@ pub const Neon = struct {
};
}
pub fn validate(input: []const u8) bool {
var neon = Neon.init();
neon.feed(input);
return !neon.hasErrors();
}
/// Validate a chunk of UTF-8 data. This function is designed to be
/// called multiple times with successive chunks of data. When the
/// data is complete, you must call `finalize` to check for any
/// remaining errors.
pub fn validate(self: *Neon, input: []const u8) void {
pub fn feed(self: *Neon, input: []const u8) void {
// Break up our input into 16 byte chunks, and process each chunk
// separately. The size of a Neon register is 16 bytes.
var i: usize = 0;
@ -288,8 +294,8 @@ fn testValidate(func: *const Validate) !void {
try testing.expect(!func("\xed\xbf\xbf"));
}
test "utf8Validate neon" {
if (comptime !isa.possible(.neon)) return error.SkipZigTest;
const set = isa.detect();
if (set.contains(.neon)) try testValidate(&utf8ValidateNeon);
test "validate" {
const v = isa.detect();
var it = v.iterator();
while (it.next()) |isa_v| try testValidate(validateFunc(isa_v));
}