diff --git a/build.zig b/build.zig index 976ace79d..25b903977 100644 --- a/build.zig +++ b/build.zig @@ -1054,6 +1054,7 @@ fn addDeps( step.addCSourceFiles(.{ .files = &.{ + "src/simd/base64.cpp", "src/simd/codepoint_width.cpp", "src/simd/index_of.cpp", "src/simd/vt.cpp", diff --git a/src/simd/base64.cpp b/src/simd/base64.cpp new file mode 100644 index 000000000..63f26b505 --- /dev/null +++ b/src/simd/base64.cpp @@ -0,0 +1,20 @@ +#include + +extern "C" { + +size_t ghostty_simd_base64_max_length(const char* input, size_t length) { + return simdutf::maximal_binary_length_from_base64(input, length); +} + +size_t ghostty_simd_base64_decode(const char* input, + size_t length, + char* output) { + simdutf::result r = simdutf::base64_to_binary(input, length, output); + if (r.error) { + return -1; + } + + return r.count; +} + +} // extern "C" diff --git a/src/simd/base64.zig b/src/simd/base64.zig new file mode 100644 index 000000000..778fbfe3e --- /dev/null +++ b/src/simd/base64.zig @@ -0,0 +1,39 @@ +const std = @import("std"); + +// base64.cpp +extern "c" fn ghostty_simd_base64_max_length( + input: [*]const u8, + len: usize, +) usize; +extern "c" fn ghostty_simd_base64_decode( + input: [*]const u8, + len: usize, + output: [*]u8, +) isize; + +pub fn maxLen(input: []const u8) usize { + return ghostty_simd_base64_max_length(input.ptr, input.len); +} + +pub fn decode(input: []const u8, output: []u8) error{Base64Invalid}![]const u8 { + const res = ghostty_simd_base64_decode(input.ptr, input.len, output.ptr); + if (res < 0) return error.Base64Invalid; + return output[0..@intCast(res)]; +} + +test "base64 maxLen" { + const testing = std.testing; + const len = maxLen("aGVsbG8gd29ybGQ="); + try testing.expectEqual(11, len); +} + +test "base64 decode" { + const testing = std.testing; + const alloc = testing.allocator; + const input = "aGVsbG8gd29ybGQ="; + const len = maxLen(input); + const output = try alloc.alloc(u8, len); + defer alloc.free(output); + const str = try decode(input, output); + try testing.expectEqualStrings("hello world", str); +} diff --git a/src/simd/main.zig b/src/simd/main.zig index c7ced250d..d3ada5708 100644 --- a/src/simd/main.zig +++ b/src/simd/main.zig @@ -1,6 +1,7 @@ const std = @import("std"); pub usingnamespace @import("codepoint_width.zig"); +pub const base64 = @import("base64.zig"); pub const index_of = @import("index_of.zig"); pub const vt = @import("vt.zig"); diff --git a/src/terminal/kitty/graphics_command.zig b/src/terminal/kitty/graphics_command.zig index b8d0b0b7d..fe9a4520f 100644 --- a/src/terminal/kitty/graphics_command.zig +++ b/src/terminal/kitty/graphics_command.zig @@ -2,6 +2,7 @@ const std = @import("std"); const assert = std.debug.assert; const Allocator = std.mem.Allocator; const ArenaAllocator = std.heap.ArenaAllocator; +const simd = @import("../../simd/main.zig"); const log = std.log.scoped(.kitty_gfx); @@ -179,29 +180,24 @@ pub const CommandParser = struct { return ""; } - const Base64Decoder = std.base64.standard_no_pad.Decoder; - - // We remove any padding, since it's optional, and decode without it. - while (self.data.items[self.data.items.len - 1] == '=') { - self.data.items.len -= 1; - } - - const size = Base64Decoder.calcSizeForSlice(self.data.items) catch |err| { - log.warn("failed to calculate base64 size for payload: {}", .{err}); - return error.InvalidData; - }; + const max_len = simd.base64.maxLen(self.data.items); + assert(max_len <= self.data.items.len); // This is kinda cursed, but we can decode the base64 on top of // itself, since it's guaranteed that the encoded size is larger, // and any bytes in areas that are written to will have already // been used (assuming scalar decoding). - Base64Decoder.decode(self.data.items[0..size], self.data.items) catch |err| { + const decoded = simd.base64.decode( + self.data.items, + self.data.items[0..max_len], + ) catch |err| { log.warn("failed to decode base64 payload data: {}", .{err}); return error.InvalidData; }; + assert(decoded.len <= max_len); // Remove the extra bytes. - self.data.items.len = size; + self.data.items.len = decoded.len; return try self.data.toOwnedSlice(); }