From 4f6c67fe9d7c5922eecb5a9fbfd1968552fe1bc3 Mon Sep 17 00:00:00 2001
From: Mitchell Hashimoto
Date: Mon, 12 Sep 2022 10:21:18 -0700
Subject: [PATCH] add LRU

---
 src/lru.zig  | 252 +++++++++++++++++++++++++++++++++++++++++++++++++++
 src/main.zig |   1 +
 2 files changed, 253 insertions(+)
 create mode 100644 src/lru.zig

diff --git a/src/lru.zig b/src/lru.zig
new file mode 100644
index 000000000..39b97006c
--- /dev/null
+++ b/src/lru.zig
@@ -0,0 +1,252 @@
+const std = @import("std");
+const Allocator = std.mem.Allocator;
+
+/// Create a HashMap for a key type that can be automatically hashed.
+/// If you want finer-grained control, use HashMap directly.
+pub fn AutoHashMap(comptime K: type, comptime V: type) type {
+    return HashMap(
+        K,
+        V,
+        std.hash_map.AutoContext(K),
+        std.hash_map.default_max_load_percentage,
+    );
+}
+
+/// HashMap implementation that supports least-recently-used eviction.
+///
+/// Note: This is a really elementary CS101 version of an LRU right now.
+/// This is done initially to get something working. Once we have it working,
+/// we can benchmark and improve if this ends up being a source of slowness.
+pub fn HashMap(
+    comptime K: type,
+    comptime V: type,
+    comptime Context: type,
+    comptime max_load_percentage: u64,
+) type {
+    return struct {
+        const Self = @This();
+        const Map = std.HashMapUnmanaged(K, *Queue.Node, Context, max_load_percentage);
+        const Queue = std.TailQueue(KV);
+
+        /// Map to maintain our entries.
+        map: Map,
+
+        /// Queue to maintain LRU order. The first node is the least
+        /// recently used entry and the last node is the most recently used.
+        queue: Queue,
+
+        /// The capacity of our map. If this capacity is reached, cache
+        /// misses will begin evicting entries.
+        capacity: Map.Size,
+
+        /// A key/value pair as stored in the LRU queue.
+        pub const KV = struct {
+            key: K,
+            value: V,
+        };
+
+        /// The result of a getOrPut operation.
+        pub const GetOrPutResult = struct {
+            /// The entry that was retrieved. If found_existing is false,
+            /// then this is a pointer to allocated space to store a V.
+            /// If found_existing is true, the pointer value is valid, but
+            /// can be overwritten.
+            value_ptr: *V,
+
+            /// Whether an existing value was found or not.
+            found_existing: bool,
+
+            /// If another entry had to be evicted to make space for this
+            /// put operation, then this is the key and value that were
+            /// evicted.
+            evicted: ?KV,
+        };
+
+        /// Create an empty LRU map that holds at most `capacity` entries.
+        /// Nothing is allocated here. `capacity` must be at least 1:
+        /// eviction reuses the least recently used node, and an empty
+        /// queue has none.
+        pub fn init(capacity: Map.Size) Self {
+            return .{
+                .map = .{},
+                .queue = .{},
+                .capacity = capacity,
+            };
+        }
+
+        /// Free every node and the backing map. `alloc` must be the
+        /// allocator that was passed to getOrPut.
+        pub fn deinit(self: *Self, alloc: Allocator) void {
+            // Important: use our queue as a source of truth for dealloc
+            // because we might keep items in the queue around that aren't
+            // present in our LRU anymore to prevent future allocations.
+            var it = self.queue.first;
+            while (it) |node| {
+                it = node.next;
+                alloc.destroy(node);
+            }
+
+            self.map.deinit(alloc);
+            self.* = undefined;
+        }
+
+        /// Get or put a value for a key. See GetOrPutResult on how to check
+        /// if an existing value was found, if an existing value was evicted,
+        /// etc.
+        pub fn getOrPut(self: *Self, allocator: Allocator, key: K) Allocator.Error!GetOrPutResult {
+            if (@sizeOf(Context) != 0)
+                @compileError("Cannot infer context " ++ @typeName(Context) ++ ", call getOrPutContext instead.");
+            return self.getOrPutContext(allocator, key, undefined);
+        }
+
+        /// See getOrPut. Takes the hash map context explicitly, for
+        /// contexts that are not zero-sized.
+        pub fn getOrPutContext(
+            self: *Self,
+            alloc: Allocator,
+            key: K,
+            ctx: Context,
+        ) Allocator.Error!GetOrPutResult {
+            const map_gop = try self.map.getOrPutContext(alloc, key, ctx);
+            if (map_gop.found_existing) {
+                // Move to end to mark as most recently used
+                self.queue.remove(map_gop.value_ptr.*);
+                self.queue.append(map_gop.value_ptr.*);
+
+                return GetOrPutResult{
+                    .found_existing = true,
+                    .value_ptr = &map_gop.value_ptr.*.data.value,
+                    .evicted = null,
+                };
+            }
+            errdefer _ = self.map.removeContext(key, ctx);
+
+            // We're evicting if our map insertion increased our capacity.
+            const evict = self.map.count() > self.capacity;
+
+            // Get our node. If we're not evicting then we allocate a new
+            // node. If we are evicting then we avoid allocation by just
+            // reusing the node we would've evicted.
+            var evicted: ?KV = null;
+            const node = if (!evict) try alloc.create(Queue.Node) else node: {
+                // Our first node is the least recently used.
+                const least_used = self.queue.first.?;
+
+                // Unlink it from the queue; it is appended again below
+                // as the most recently used node.
+                self.queue.remove(least_used);
+
+                // Remove the least used from the map
+                _ = self.map.removeContext(least_used.data.key, ctx);
+
+                // Copy out the evicted entry now: the node is reused below
+                // and its key is overwritten with the new key.
+                evicted = least_used.data;
+
+                break :node least_used;
+            };
+            errdefer if (!evict) alloc.destroy(node);
+
+            // Store our node in the map.
+            map_gop.value_ptr.* = node;
+
+            // Mark the node as most recently used
+            self.queue.append(node);
+
+            // Set our key
+            node.data.key = key;
+
+            return GetOrPutResult{
+                .found_existing = false,
+                .value_ptr = &node.data.value,
+                .evicted = evicted,
+            };
+        }
+
+        /// Get a value for a key. This only reads the map; it does not
+        /// mark the entry as recently used.
+        pub fn get(self: *Self, key: K) ?V {
+            if (@sizeOf(Context) != 0) {
+                @compileError("getContext must be used.");
+            }
+            return self.getContext(key, undefined);
+        }
+
+        /// See get
+        pub fn getContext(self: *Self, key: K, ctx: Context) ?V {
+            const node = self.map.getContext(key, ctx) orelse return null;
+            return node.data.value;
+        }
+    };
+}
+
+test "getOrPut" {
+    const testing = std.testing;
+    const alloc = testing.allocator;
+
+    const Map = AutoHashMap(u32, u8);
+    var m = Map.init(2);
+    defer m.deinit(alloc);
+
+    // Insert cap values, should be misses
+    {
+        const gop = try m.getOrPut(alloc, 1);
+        try testing.expect(!gop.found_existing);
+        try testing.expect(gop.evicted == null);
+        gop.value_ptr.* = 1;
+    }
+    {
+        const gop = try m.getOrPut(alloc, 2);
+        try testing.expect(!gop.found_existing);
+        try testing.expect(gop.evicted == null);
+        gop.value_ptr.* = 2;
+    }
+
+    // 1 is LRU
+    try testing.expect((try m.getOrPut(alloc, 1)).found_existing);
+    try testing.expect((try m.getOrPut(alloc, 2)).found_existing);
+
+    // Next should evict
+    {
+        const gop = try m.getOrPut(alloc, 3);
+        try testing.expect(!gop.found_existing);
+        try testing.expect(gop.evicted != null);
+        try testing.expect(gop.evicted.?.key == 1);
+        try testing.expect(gop.evicted.?.value == 1);
+        gop.value_ptr.* = 3;
+    }
+
+    // Currently: 2 is LRU, let's make 3 LRU
+    try testing.expect((try m.getOrPut(alloc, 2)).found_existing);
+
+    // Next should evict
+    {
+        const gop = try m.getOrPut(alloc, 4);
+        try testing.expect(!gop.found_existing);
+        try testing.expect(gop.evicted != null);
+        try testing.expect(gop.evicted.?.key == 3);
+        try testing.expect(gop.evicted.?.value == 3);
+        gop.value_ptr.* = 4;
+    }
+}
+
+test "get" {
+    const testing = std.testing;
+    const alloc = testing.allocator;
+
+    const Map = AutoHashMap(u32, u8);
+    var m = Map.init(2);
+    defer m.deinit(alloc);
+
+    // Insert cap values, should be misses
+    {
+        const gop = try m.getOrPut(alloc, 1);
+        try testing.expect(!gop.found_existing);
+        try testing.expect(gop.evicted == null);
+        gop.value_ptr.* = 1;
+    }
+
+    try testing.expect(m.get(1) != null);
+    try testing.expect(m.get(1).? == 1);
+    try testing.expect(m.get(2) == null);
+}
diff --git a/src/main.zig b/src/main.zig
index 188ff1fef..53c401d24 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -119,4 +119,5 @@ test {
     // TODO
     _ = @import("config.zig");
     _ = @import("cli_args.zig");
+    _ = @import("lru.zig");
 }