diff --git a/src/tcp/Server.zig b/src/tcp/Server.zig index 7ad147c9a..7b48bfb7f 100644 --- a/src/tcp/Server.zig +++ b/src/tcp/Server.zig @@ -1,9 +1,7 @@ const std = @import("std"); const xev = @import("xev"); const Config = @import("../config/Config.zig"); - -const reject_client = @import("./handlers/reject.zig").reject_client; -const read_client = @import("./handlers/reader.zig").read_client; +const connections = @import("./handlers/connections.zig"); const Allocator = std.mem.Allocator; const CompletionPool = std.heap.MemoryPool(xev.Completion); @@ -103,16 +101,11 @@ pub fn deinit(self: *Server) void { pub fn start(self: *Server) !void { try self.socket.bind(self.addr); try self.socket.listen(self.max_clients); + try connections.startAccepting(self); log.info("bound server to socket={any}", .{self.socket}); - // Each acceptor borrows a completion from the pool - // We do this because the completion is passed to the client TCP handlers - const c = self.comp_pool.create() catch { - log.err("couldn't allocate completion in pool", .{}); - return error.OutOfMemory; - }; - - self.socket.accept(&self.loop, c, Server, self, acceptHandler); + // TODO: Stop flag? Only necessary if we support signaling the server + // from the main thread on an event, ie. configuration reloading. while (true) { try self.loop.run(.until_done); } @@ -125,81 +118,40 @@ pub fn destroyBuffer(self: *Server, buf: []const u8) void { )); } -/// Accepts a new client connection and starts reading from it until EOF. -/// Once an accept handler enters, it queues for a new client connection. -/// It essentially recursively calls itself until shutdown. -fn acceptHandler( - self_: ?*Server, - _: *xev.Loop, - c: *xev.Completion, - e: xev.TCP.AcceptError!xev.TCP, -) xev.CallbackAction { - const self = self_.?; - const new_c = self.comp_pool.create() catch { - log.err("couldn't allocate completion in pool", .{}); - return .disarm; +const BindError = error{ + NoAddress, + InvalidAddress, +}; + +/// Tries to generate a valid address to bind to, expects that the address will +/// start with tcp:// when binding to an IP and unix:// when binding to a file +/// based socket. +pub fn parseAddress(raw_addr: ?[:0]const u8) BindError!std.net.Address { + const addr = raw_addr orelse { + return BindError.NoAddress; }; - // Accept a new client connection now that we have a new completion - self.socket.accept(&self.loop, new_c, Server, self, acceptHandler); - - const sock = self.sock_pool.create() catch { - log.err("couldn't allocate socket in pool", .{}); - return .disarm; - }; - - sock.* = e catch { - log.err("accept error", .{}); - self.sock_pool.destroy(sock); - return .disarm; - }; - - if (self.clients_count == self.max_clients) { - log.warn("max clients reached, rejecting fd={d}", .{sock.fd}); - reject_client(self, sock) catch return .rearm; - return .disarm; + if (addr.len == 0) { + return BindError.NoAddress; } - log.info("accepted connection fd={d}", .{sock.fd}); - self.clients_count += 1; + const uri = std.Uri.parse(addr) catch return BindError.InvalidAddress; + if (std.mem.eql(u8, uri.scheme, "tcp")) { + const host = uri.host orelse return BindError.InvalidAddress; + const port = uri.port orelse return BindError.InvalidAddress; - read_client(self, sock, c) catch { - log.err("couldn't read from client", .{}); - }; + return std.net.Address.parseIp4(host.percent_encoded, port) catch { + return BindError.InvalidAddress; + }; + } - return .disarm; -} - -fn shutdownHandler( - self_: ?*Server, - loop: *xev.Loop, - comp: *xev.Completion, - sock: xev.TCP, - e: xev.TCP.ShutdownError!void, -) xev.CallbackAction { - e catch { - // Is this even possible? - log.err("couldn't shutdown socket", .{}); - }; - - const self = self_.?; - - sock.close(loop, comp, Server, self, closeHandler); - return .disarm; -} - -pub fn closeHandler( - self_: ?*Server, - _: *xev.Loop, - comp: *xev.Completion, - _: xev.TCP, - e: xev.TCP.CloseError!void, -) xev.CallbackAction { - e catch { - log.err("couldn't close socket", .{}); - }; - - const self = self_.?; - self.comp_pool.destroy(comp); - return .disarm; + // TODO: Should we check for valid file paths or just rely on the initUnix + // function to return an error? + if (std.mem.eql(u8, uri.scheme, "unix")) { + return std.net.Address.initUnix(uri.path.percent_encoded) catch { + return BindError.InvalidAddress; + }; + } + + return BindError.InvalidAddress; } diff --git a/src/tcp/handlers/connections.zig b/src/tcp/handlers/connections.zig new file mode 100644 index 000000000..93c728e12 --- /dev/null +++ b/src/tcp/handlers/connections.zig @@ -0,0 +1,96 @@ +const xev = @import("xev"); +const std = @import("std"); +const Server = @import("../Server.zig").Server; +const Command = @import("../Command.zig").Command; +const reject_client = @import("./reject.zig").reject_client; +const read_client = @import("./reader.zig").read_client; + +const log = std.log.scoped(.tcp_thread); + +/// Starts accepting client connections on the server's socket. +/// Note: This first xev.Completion is not destroyed here because it gets used +/// for an entire client connection lifecycle. +pub fn startAccepting(self: *Server) !void { + const c = try self.comp_pool.create(); + self.socket.accept(&self.loop, c, Server, self, aHandler); +} + +/// Accepts a new client connection and starts reading from it until EOF. +/// Once an accept handler enters, it queues for a new client connection. +/// It essentially recursively calls itself until shutdown. +fn aHandler( + self_: ?*Server, + _: *xev.Loop, + c: *xev.Completion, + e: xev.TCP.AcceptError!xev.TCP, +) xev.CallbackAction { + const self = self_.?; + const new_c = self.comp_pool.create() catch { + log.err("couldn't allocate completion in pool", .{}); + return .disarm; + }; + + // Accept a new client connection now that we have a new completion + self.socket.accept(&self.loop, new_c, Server, self, aHandler); + + const sock = self.sock_pool.create() catch { + log.err("couldn't allocate socket in pool", .{}); + return .disarm; + }; + + sock.* = e catch { + log.err("accept error", .{}); + self.sock_pool.destroy(sock); + return .disarm; + }; + + if (self.clients_count == self.max_clients) { + log.warn("max clients reached, rejecting fd={d}", .{sock.fd}); + reject_client(self, sock) catch return .rearm; + return .disarm; + } + + log.info("accepted connection fd={d}", .{sock.fd}); + self.clients_count += 1; + + read_client(self, sock, c) catch { + log.err("couldn't read from client", .{}); + }; + + return .disarm; +} + +fn sHandler( + self_: ?*Server, + loop: *xev.Loop, + comp: *xev.Completion, + sock: xev.TCP, + e: xev.TCP.ShutdownError!void, +) xev.CallbackAction { + e catch { + // Is this even possible? + log.err("couldn't shutdown socket", .{}); + }; + + const self = self_.?; + + sock.close(loop, comp, Server, self, cHandler); + return .disarm; +} + +/// Closes the client connection and cleans up the completion. +pub fn cHandler( + self_: ?*Server, + _: *xev.Loop, + comp: *xev.Completion, + _: xev.TCP, + e: xev.TCP.CloseError!void, +) xev.CallbackAction { + e catch { + log.err("couldn't close socket", .{}); + }; + + const self = self_.?; + self.comp_pool.destroy(comp); + return .disarm; +} diff --git a/src/tcp/handlers/reader.zig b/src/tcp/handlers/reader.zig index 8d79f33a1..8596c2d92 100644 --- a/src/tcp/handlers/reader.zig +++ b/src/tcp/handlers/reader.zig @@ -2,6 +2,7 @@ const xev = @import("xev"); const std = @import("std"); const Server = @import("../Server.zig").Server; const Command = @import("../Command.zig").Command; +const connections = @import("./connections.zig"); const log = std.log.scoped(.tcp_thread); @@ -29,7 +30,7 @@ fn rHandler( error.EOF => { log.info("client disconnected fd={d}", .{s.fd}); self.clients_count -= 1; - s.close(l, c, Server, self, Server.closeHandler); + s.close(l, c, Server, self, connections.cHandler); return .disarm; }, diff --git a/src/tcp/handlers/reject.zig b/src/tcp/handlers/reject.zig index 07c698f86..c1952d354 100644 --- a/src/tcp/handlers/reject.zig +++ b/src/tcp/handlers/reject.zig @@ -1,6 +1,7 @@ const xev = @import("xev"); const std = @import("std"); const Server = @import("../Server.zig").Server; +const connections = @import("./connections.zig"); const log = std.log.scoped(.tcp_thread); @@ -36,6 +37,6 @@ fn wHandler( return .disarm; }; - client.close(l, c, Server, self, Server.closeHandler); + client.close(l, c, Server, self, connections.cHandler); return .disarm; }