summaryrefslogtreecommitdiff
path: root/src/simd
diff options
context:
space:
mode:
authorMitchell Hashimoto <mitchell.hashimoto@gmail.com>2024-01-28 21:38:09 -0800
committerMitchell Hashimoto <mitchell.hashimoto@gmail.com>2024-02-05 21:22:03 -0800
commit7feba12eab2c4ac9ddf75c029f2a894ee14019ae (patch)
treef23822d24488f5032239a96922950e4bbfac5853 /src/simd
parent31d5785105c6095e23a1fd0c63bbb93e7c4431d4 (diff)
simd: indexOf implementation using NEON
Diffstat (limited to 'src/simd')
-rw-r--r--src/simd/aarch64.zig66
-rw-r--r--src/simd/index_of.zig97
-rw-r--r--src/simd/main.zig11
3 files changed, 173 insertions, 1 deletions
diff --git a/src/simd/aarch64.zig b/src/simd/aarch64.zig
new file mode 100644
index 000000000..1c02a68cc
--- /dev/null
+++ b/src/simd/aarch64.zig
@@ -0,0 +1,66 @@
+// https://developer.arm.com/architectures/instruction-sets/intrinsics
+// https://llvm.org/docs/LangRef.html#inline-assembler-expressions
+
+const std = @import("std");
+const assert = std.debug.assert;
+
+pub inline fn vdupq_n_u8(v: u8) @Vector(16, u8) {
+ return asm (
+ \\ dup %[ret].16b, %[value:w]
+ : [ret] "=w" (-> @Vector(16, u8)),
+ : [value] "r" (v),
+ );
+}
+
+pub inline fn vld1q_u8(v: []const u8) @Vector(16, u8) {
+ return asm (
+ \\ ld1 { %[ret].16b }, [%[value]]
+ : [ret] "=w" (-> @Vector(16, u8)),
+ : [value] "r" (v.ptr),
+ );
+}
+
+pub inline fn vceqq_u8(a: @Vector(16, u8), b: @Vector(16, u8)) @Vector(16, u8) {
+ return asm (
+ \\ cmeq %[ret].16b, %[a].16b, %[b].16b
+ : [ret] "=w" (-> @Vector(16, u8)),
+ : [a] "w" (a),
+ [b] "w" (b),
+ );
+}
+
+pub inline fn vshrn_n_u16(a: @Vector(8, u16), n: u4) @Vector(8, u8) {
+ assert(n <= 8);
+ return asm (
+ \\ shrn %[ret].8b, %[a].8h, %[n]
+ : [ret] "=w" (-> @Vector(8, u8)),
+ : [a] "w" (a),
+ [n] "I" (n),
+ );
+}
+
+pub inline fn vget_lane_u64(v: @Vector(1, u64)) u64 {
+ return asm (
+ \\ umov %[ret], %[v].d[0]
+ : [ret] "=r" (-> u64),
+ : [v] "w" (v),
+ );
+}
+
+pub inline fn rbit(comptime T: type, v: T) T {
+ assert(T == u32 or T == u64);
+ return asm (
+ \\ rbit %[ret], %[v]
+ : [ret] "=r" (-> T),
+ : [v] "r" (v),
+ );
+}
+
+pub inline fn clz(comptime T: type, v: T) T {
+ assert(T == u32 or T == u64);
+ return asm (
+ \\ clz %[ret], %[v]
+ : [ret] "=r" (-> T),
+ : [v] "r" (v),
+ );
+}
diff --git a/src/simd/index_of.zig b/src/simd/index_of.zig
new file mode 100644
index 000000000..6630615a4
--- /dev/null
+++ b/src/simd/index_of.zig
@@ -0,0 +1,97 @@
+const std = @import("std");
+const builtin = @import("builtin");
+const aarch64 = @import("aarch64.zig");
+
+// Note this is a reimplementation of std.mem.indexOfScalar. The Zig stdlib
+// version is already SIMD-optimized but not using runtime ISA detection. This
+// expands the stdlib version to use runtime ISA detection. This also, at the
+// time of writing this comment, reimplements it using manual assembly. This is
+// so I can compare to Zig's @Vector lowering.
+
+const IndexOf = @TypeOf(indexOf);
+
+/// Returns the first index of `needle` in `input` or `null` if `needle`
+/// is not found.
+pub fn indexOf(input: []const u8, needle: u8) ?usize {
+ return indexOfNeon(input, needle);
+ //return indexOfScalar(input, needle);
+}
+
+/// indexOf implementation using ARM NEON instructions.
+fn indexOfNeon(input: []const u8, needle: u8) ?usize {
+ // This function is going to be commented in a lot of detail. SIMD is
+ // complicated and nonintuitive, so I want to make sure I understand what's
+ // going on. More importantly, I want to make sure when I look back on this
+ // code in the future, I understand what's going on.
+
+ // Load our needle into a vector register. This duplicates the needle 16
+ // times, once for each byte in the 128-bit vector register.
+ const needle_vec = aarch64.vdupq_n_u8(needle);
+
+ // note(mitchellh): benchmark to see if we should align to 16 bytes here
+
+ // Iterate 16 bytes at a time, which is the max size of a vector register.
+ var i: usize = 0;
+ while (i + 16 <= input.len) : (i += 16) {
+ // Load the next 16 bytes into a vector register.
+ const input_vec = aarch64.vld1q_u8(input[i..]);
+
+ // Compare the input vector to the needle vector. This will set
+ // all bits to "1" in the output vector for each matching byte.
+ const match_vec = aarch64.vceqq_u8(input_vec, needle_vec);
+
+ // This is a neat trick in order to efficiently find the index of
+ // the first matching byte. Details for this can be found here:
+ // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
+ const shift_vec = aarch64.vshrn_n_u16(@bitCast(match_vec), 4);
+ const shift_u64 = aarch64.vget_lane_u64(@bitCast(shift_vec));
+ if (shift_u64 == 0) {
+ // This means no matches were found.
+ continue;
+ }
+
+ // A match was found! Reverse the bits and divide by 4 to get the
+ // index of the first matching byte. The reversal is due to the
+ // bits being reversed in the shift operation, the division by 4
+ // is due to all data being repeated 4 times by vceqq.
+ const reversed = aarch64.rbit(u64, shift_u64);
+ const index = aarch64.clz(u64, reversed) >> 2;
+ return i + index;
+ }
+
+ // Handle the remaining bytes
+ if (i < input.len) {
+ while (i < input.len) : (i += 1) {
+ if (input[i] == needle) return i;
+ }
+ }
+
+ return null;
+}
+
+fn indexOfScalar(input: []const u8, needle: u8) ?usize {
+ // Note this actually uses vector operations if supported. See
+ // our comment at the top of the file.
+ return std.mem.indexOfScalar(u8, input, needle);
+}
+
+/// Generic test function so we can test against multiple implementations.
+fn testIndexOf(func: *const IndexOf) !void {
+ const testing = std.testing;
+ try testing.expect(func("hello", ' ') == null);
+ try testing.expectEqual(@as(usize, 2), func("hi lo", ' ').?);
+ try testing.expectEqual(@as(usize, 5), func(
+ \\XXXXX XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
+ \\XXXXXXXXXXXX XXXXXXXXXXX XXXXXXXXXXXXXXX
+ , ' ').?);
+ try testing.expectEqual(@as(usize, 53), func(
+ \\XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
+ \\XXXXXXXXXXXX XXXXXXXXXXX XXXXXXXXXXXXXXX
+ , ' ').?);
+}
+
+test "indexOf neon" {
+ // TODO: use ISA detection here
+ if (comptime builtin.cpu.arch != .aarch64) return error.SkipZigTest;
+ try testIndexOf(&indexOfNeon);
+}
diff --git a/src/simd/main.zig b/src/simd/main.zig
index d87d306ed..dabae2950 100644
--- a/src/simd/main.zig
+++ b/src/simd/main.zig
@@ -1,8 +1,17 @@
const std = @import("std");
const isa = @import("isa.zig");
+const index_of = @import("index_of.zig");
pub usingnamespace isa;
+pub usingnamespace index_of;
pub fn main() !void {
- std.log.warn("ISA={}", .{isa.ISA.detect()});
+ //std.log.warn("ISA={}", .{isa.ISA.detect()});
+ const input = "1234567\x1b1234567\x1b";
+ //const input = "1234567812345678";
+ _ = index_of.indexOf(input, 0x1B);
+}
+
+test {
+ @import("std").testing.refAllDecls(@This());
}