From dba78b20ca3ccb11d793866d1359928aa1ae44be Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Thu, 16 Nov 2023 10:06:24 -0800 Subject: [PATCH] renderer: shadertoy convert to MSL --- pkg/glslang/program.zig | 4 +- src/renderer/shadertoy.zig | 79 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 78 insertions(+), 5 deletions(-) diff --git a/pkg/glslang/program.zig b/pkg/glslang/program.zig index 28fc8ab9b..70d3c88cd 100644 --- a/pkg/glslang/program.zig +++ b/pkg/glslang/program.zig @@ -30,11 +30,11 @@ pub const Program = opaque { return @intCast(c.glslang_program_SPIRV_get_size(@ptrCast(self))); } - pub fn spirvGet(self: *Program, buf: []u8) void { + pub fn spirvGet(self: *Program, buf: []u32) void { c.glslang_program_SPIRV_get(@ptrCast(self), buf.ptr); } - pub fn spirvGetPtr(self: *Program) ![*]u8 { + pub fn spirvGetPtr(self: *Program) ![*]u32 { return @ptrCast(c.glslang_program_SPIRV_get_ptr(@ptrCast(self))); } diff --git a/src/renderer/shadertoy.zig b/src/renderer/shadertoy.zig index 00a2443e5..81430932e 100644 --- a/src/renderer/shadertoy.zig +++ b/src/renderer/shadertoy.zig @@ -1,7 +1,9 @@ const std = @import("std"); const builtin = @import("builtin"); +const assert = std.debug.assert; const Allocator = std.mem.Allocator; const glslang = @import("glslang"); +const spvcross = @import("spirv_cross"); /// Convert a ShaderToy shader into valid GLSL. /// @@ -67,7 +69,9 @@ pub fn spirvFromGlsl( program.spirvGenerate(c.GLSLANG_STAGE_FRAGMENT); const size = program.spirvGetSize(); const ptr = try program.spirvGetPtr(); - try writer.writeAll(ptr[0..size]); + const ptr_u8: [*]u8 = @ptrCast(ptr); + const slice_u8: []u8 = ptr_u8[0 .. size * 4]; + try writer.writeAll(slice_u8); } /// Retrieve errors from spirv compilation. @@ -100,6 +104,59 @@ pub const SpirvLog = struct { } }; +/// Convert SPIR-V binary to MSL. +pub fn mslFromSpv(alloc: Allocator, spv: []const u8) ![:0]const u8 { + // Spir-V is always a multiple of 4 because it is written as a series of words + if (@mod(spv.len, 4) != 0) return error.SpirvInvalid; + + // Compiler context + const c = spvcross.c; + var ctx: c.spvc_context = undefined; + if (c.spvc_context_create(&ctx) != c.SPVC_SUCCESS) return error.SpvcFailed; + defer c.spvc_context_destroy(ctx); + + // It would be better to get this out into an output parameter to + // show users but for now we can just log it. + c.spvc_context_set_error_callback(ctx, @ptrCast(&(struct { + fn callback(_: ?*anyopaque, msg_ptr: [*c]const u8) callconv(.C) void { + const msg = std.mem.sliceTo(msg_ptr, 0); + std.log.warn("spirv-cross error message={s}", .{msg}); + } + }).callback), null); + + // Parse the Spir-V binary to an IR + var ir: c.spvc_parsed_ir = undefined; + if (c.spvc_context_parse_spirv( + ctx, + @ptrCast(@alignCast(spv.ptr)), + spv.len / 4, + &ir, + ) != c.SPVC_SUCCESS) { + return error.SpvcFailed; + } + + // Build our compiler to MSL + var compiler: c.spvc_compiler = undefined; + if (c.spvc_context_create_compiler( + ctx, + c.SPVC_BACKEND_MSL, + ir, + c.SPVC_CAPTURE_MODE_TAKE_OWNERSHIP, + &compiler, + ) != c.SPVC_SUCCESS) { + return error.SpvcFailed; + } + + // Compile the resulting string. This string pointer is owned by the + // context so we don't need to free it. + var result: [*:0]const u8 = undefined; + if (c.spvc_compiler_compile(compiler, @ptrCast(&result)) != c.SPVC_SUCCESS) { + return error.SpvcFailed; + } + + return try alloc.dupeZ(u8, std.mem.sliceTo(result, 0)); +} + /// Convert ShaderToy shader to null-terminated glsl for testing. fn testGlslZ(alloc: Allocator, src: []const u8) ![:0]const u8 { var list = std.ArrayList(u8).init(alloc); @@ -115,7 +172,7 @@ test "spirv" { const src = try testGlslZ(alloc, test_crt); defer alloc.free(src); - var buf: [4096]u8 = undefined; + var buf: [4096 * 4]u8 = undefined; var buf_stream = std.io.fixedBufferStream(&buf); const writer = buf_stream.writer(); try spirvFromGlsl(writer, null, src); @@ -128,7 +185,7 @@ test "spirv invalid" { const src = try testGlslZ(alloc, test_invalid); defer alloc.free(src); - var buf: [4096]u8 = undefined; + var buf: [4096 * 4]u8 = undefined; var buf_stream = std.io.fixedBufferStream(&buf); const writer = buf_stream.writer(); @@ -138,5 +195,21 @@ test "spirv invalid" { try testing.expect(errlog.info.len > 0); } +test "shadertoy to msl" { + const testing = std.testing; + const alloc = testing.allocator; + + const src = try testGlslZ(alloc, test_crt); + defer alloc.free(src); + + var spvlist = std.ArrayList(u8).init(alloc); + defer spvlist.deinit(); + try spirvFromGlsl(spvlist.writer(), null, src); + while (@mod(spvlist.items.len, 4) != 0) try spvlist.append(0); + + const msl = try mslFromSpv(alloc, spvlist.items); + defer alloc.free(msl); +} + const test_crt = @embedFile("shaders/test_shadertoy_crt.glsl"); const test_invalid = @embedFile("shaders/test_shadertoy_invalid.glsl");