diff --git a/src/config.zig b/src/config.zig index ddd1e0690..859e21d8b 100644 --- a/src/config.zig +++ b/src/config.zig @@ -3,16 +3,10 @@ const builtin = @import("builtin"); const Allocator = std.mem.Allocator; const ArenaAllocator = std.heap.ArenaAllocator; const inputpkg = @import("input.zig"); +const passwd = @import("passwd.zig"); const log = std.log.scoped(.config); -/// Used to determine the default shell and directory on Unixes. -const c = @cImport({ - @cInclude("sys/types.h"); - @cInclude("unistd.h"); - @cInclude("pwd.h"); -}); - /// Config is the main config struct. These fields map directly to the /// CLI flag names hence we use a lot of `@""` syntax to support hyphens. pub const Config = struct { @@ -172,27 +166,9 @@ pub const Config = struct { break :command; } else |_| {} - var buf: [1024]u8 = undefined; - var pw: c.struct_passwd = undefined; - var pw_ptr: ?*c.struct_passwd = null; - const res = c.getpwuid_r(c.getuid(), &pw, &buf, buf.len, &pw_ptr); - if (res != 0) { - log.warn("error retrieving pw entry code={d}", .{res}); - break :command; - } - - if (pw_ptr == null) { - // Future: let's check if a better shell is available like zsh - log.warn("no pw entry to detect default shell, will default to 'sh'", .{}); - self.command = "sh"; - break :command; - } - - if (pw.pw_shell) |ptr| { - const source = std.mem.sliceTo(ptr, 0); - const sh = try alloc.alloc(u8, source.len); - std.mem.copy(u8, sh, source); - + // Get the shell from the passwd entry + const pw = try passwd.get(alloc); + if (pw.shell) |sh| { log.debug("default shell src=passwd value={s}", .{sh}); self.command = sh; } diff --git a/src/main.zig b/src/main.zig index ab85343f6..9dd388440 100644 --- a/src/main.zig +++ b/src/main.zig @@ -177,6 +177,7 @@ test { // TODO _ = @import("config.zig"); + _ = @import("passwd.zig"); _ = @import("cli_args.zig"); _ = @import("lru.zig"); } diff --git a/src/passwd.zig b/src/passwd.zig new file mode 100644 index 000000000..6d647a86a --- /dev/null +++ b/src/passwd.zig @@ -0,0 +1,68 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const Allocator = std.mem.Allocator; +const ArenaAllocator = std.heap.ArenaAllocator; + +const log = std.log.scoped(.passwd); + +/// Used to determine the default shell and directory on Unixes. +const c = @cImport({ + @cInclude("sys/types.h"); + @cInclude("unistd.h"); + @cInclude("pwd.h"); +}); + +// Entry that is retrieved from the passwd API. This only contains the fields +// we care about. +pub const Entry = struct { + shell: ?[]const u8 = null, + home: ?[]const u8 = null, +}; + +/// Get the passwd entry for the currently executing user. +pub fn get(alloc: Allocator) !Entry { + var buf: [1024]u8 = undefined; + var pw: c.struct_passwd = undefined; + var pw_ptr: ?*c.struct_passwd = null; + const res = c.getpwuid_r(c.getuid(), &pw, &buf, buf.len, &pw_ptr); + if (res != 0) { + log.warn("error retrieving pw entry code={d}", .{res}); + return Entry{}; + } + + if (pw_ptr == null) { + // Future: let's check if a better shell is available like zsh + log.warn("no pw entry to detect default shell, will default to 'sh'", .{}); + return Entry{}; + } + + var result: Entry = .{}; + + if (pw.pw_shell) |ptr| { + const source = std.mem.sliceTo(ptr, 0); + const sh = try alloc.alloc(u8, source.len); + std.mem.copy(u8, sh, source); + result.shell = sh; + } + + if (pw.pw_dir) |ptr| { + const source = std.mem.sliceTo(ptr, 0); + const dir = try alloc.alloc(u8, source.len); + std.mem.copy(u8, dir, source); + result.home = dir; + } + + return result; +} + +test { + const testing = std.testing; + var arena = ArenaAllocator.init(testing.allocator); + defer arena.deinit(); + const alloc = arena.allocator(); + + // We should be able to get an entry + const entry = try get(alloc); + try testing.expect(entry.shell != null); + try testing.expect(entry.home != null); +}