diff --git a/build.zig b/build.zig
index 371d423..e8585a1 100644
--- a/build.zig
+++ b/build.zig
@@ -163,6 +163,7 @@ fn addLibzimalloc(b: *std.Build, options: LibzimallocOptions) *std.Build.Step.Co
 
 const standalone_tests = [_][]const u8{
     "create-destroy-loop.zig",
+    "multi-threaded-loop.zig",
 };
 
 const TestBuildConfig = struct {
diff --git a/src/ThreadHeapMap.zig b/src/ThreadHeapMap.zig
new file mode 100644
index 0000000..d4c0d0d
--- /dev/null
+++ b/src/ThreadHeapMap.zig
@@ -0,0 +1,126 @@
+list: List = .{},
+lock: std.Thread.RwLock = .{},
+pool: Pool = Pool.init(std.heap.page_allocator),
+
+const ThreadHeapMap = @This();
+
+const List = std.DoublyLinkedList(Entry);
+const Pool = std.heap.MemoryPool(List.Node);
+
+pub const Entry = struct {
+    heap: Heap,
+    thread_id: std.Thread.Id,
+};
+
+pub fn deinit(self: *ThreadHeapMap) void {
+    self.lock.lock();
+
+    // release every heap still registered before freeing the nodes backing them
+    var heap_iter = self.list.first;
+    while (heap_iter) |node| : (heap_iter = node.next) {
+        node.data.heap.deinit();
+    }
+    self.pool.deinit();
+    self.* = undefined;
+}
+
+pub fn initThreadHeap(self: *ThreadHeapMap, thread_id: std.Thread.Id) ?*Entry {
+    log.debugVerbose("obtaining heap lock", .{});
+    self.lock.lock();
+    defer self.lock.unlock();
+
+    const node = self.pool.create() catch return null;
+    node.* = .{
+        .data = .{ .heap = Heap.init(), .thread_id = thread_id },
+    };
+
+    self.list.prepend(node);
+
+    return &node.data;
+}
+
+/// behaviour is undefined if `thread_id` is not present in the map
+pub fn deinitThread(self: *ThreadHeapMap, thread_id: std.Thread.Id) void {
+    var iter = self.iterator(.exclusive);
+    defer iter.unlock();
+    while (iter.next()) |entry| {
+        if (entry.thread_id == thread_id) {
+            entry.heap.deinit();
+            const node = @fieldParentPtr(List.Node, "data", entry);
+            self.list.remove(node);
+            return;
+        }
+    }
+}
+
+pub fn ownsHeap(self: *ThreadHeapMap, heap: *const Heap) bool {
+    var iter = self.constIterator(.shared);
+    defer iter.unlock();
+    while (iter.next()) |entry| {
+        if (&entry.heap == heap) return true;
+    }
+    return false;
+}
+
+pub const LockType = enum {
+    shared,
+    exclusive,
+};
+
+pub fn constIterator(self: *ThreadHeapMap, comptime kind: LockType) ConstIterator(kind) {
+    switch (kind) {
+        .shared => self.lock.lockShared(),
+        .exclusive => self.lock.lock(),
+    }
+    return .{
+        .current = self.list.first,
+        .lock = &self.lock,
+    };
+}
+
+pub fn iterator(self: *ThreadHeapMap, comptime kind: LockType) Iterator(kind) {
+    switch (kind) {
+        .shared => self.lock.lockShared(),
+        .exclusive => self.lock.lock(),
+    }
+    return .{
+        .current = self.list.first,
+        .lock = &self.lock,
+    };
+}
+
+pub fn ConstIterator(comptime kind: LockType) type {
+    return BaseIterator(*const List.Node, *const Entry, kind);
+}
+
+pub fn Iterator(comptime kind: LockType) type {
+    return BaseIterator(*List.Node, *Entry, kind);
+}
+
+fn BaseIterator(comptime NodeType: type, comptime EntryType: type, comptime kind: LockType) type {
+    return struct {
+        current: ?NodeType,
+        lock: *std.Thread.RwLock,
+
+        pub fn next(self: *@This()) ?EntryType {
+            const node = self.current orelse return null;
+            const result: EntryType = &node.data;
+            self.current = node.next;
+            return result;
+        }
+
+        pub fn unlock(self: @This()) void {
+            switch (kind) {
+                .shared => self.lock.unlockShared(),
+                .exclusive => self.lock.unlock(),
+            }
+        }
+    };
+}
+
+const std = @import("std");
+const Allocator = std.mem.Allocator;
+
+const Heap = @import("Heap.zig");
+const log = @import("log.zig");
+const list = @import("list.zig");
diff --git a/src/allocator.zig b/src/allocator.zig
index 983e062..2187991 100644
--- a/src/allocator.zig
+++ b/src/allocator.zig
@@ -8,18 +8,12 @@ pub const Config = struct {
 pub fn Allocator(comptime config: Config) type {
     return struct {
         backing_allocator: std.mem.Allocator = std.heap.page_allocator,
-        thread_heaps: std.SegmentedList(HeapData, config.thread_data_prealloc) = .{},
-        thread_heaps_lock: std.Thread.RwLock = .{},
+        thread_heaps: ThreadHeapMap = .{},
         huge_allocations: HugeAllocTable(config.store_huge_alloc_size) = .{},
         // TODO: atomic access
 
         const Self = @This();
 
-        const HeapData = struct {
-            heap: Heap,
-            thread_id: std.Thread.Id,
-        };
-
         pub fn init(backing_allocator: std.mem.Allocator) error{OutOfMemory}!Self {
             return .{
                 .backing_allocator = backing_allocator,
@@ -28,12 +22,7 @@ pub fn Allocator(comptime config: Config) type {
         }
 
         pub fn deinit(self: *Self) void {
-            self.thread_heaps_lock.lock();
-            var heap_iter = self.thread_heaps.iterator(0);
-            while (heap_iter.next()) |data| {
-                data.heap.deinit();
-            }
-            self.thread_heaps.deinit(self.backing_allocator);
+            self.thread_heaps.deinit();
             self.huge_allocations.deinit(std.heap.page_allocator);
             self.* = undefined;
         }
@@ -44,26 +33,12 @@ pub fn Allocator(comptime config: Config) type {
             const thread_id = std.Thread.getCurrentId();
             log.debug("initialising heap for thread {d}", .{thread_id});
 
-            log.debugVerbose("obtaining heap lock", .{});
-            self.thread_heaps_lock.lock();
-            defer self.thread_heaps_lock.unlock();
-
-            const new_ptr = self.thread_heaps.addOne(self.backing_allocator) catch return null;
-
-            new_ptr.* = .{ .heap = Heap.init(), .thread_id = thread_id };
-
-            log.debug("heap initialised: {*}", .{&new_ptr.heap});
-            return &new_ptr.heap;
-        }
-
-        fn ownsHeap(self: *Self, heap: *const Heap) bool {
-            self.thread_heaps_lock.lockShared();
-            defer self.thread_heaps_lock.unlockShared();
-            var iter = self.thread_heaps.constIterator(0);
-            while (iter.next()) |child_data| {
-                if (&child_data.heap == heap) return true;
+            if (self.thread_heaps.initThreadHeap(thread_id)) |entry| {
+                log.debug("heap added to thread map: {*}", .{&entry.heap});
+                return &entry.heap;
             }
-            return false;
+
+            return null;
         }
 
         pub fn getThreadHeap(
@@ -74,12 +49,21 @@ pub fn Allocator(comptime config: Config) type {
             const heap = segment.heap;
 
             if (config.safety_checks) {
-                if (!self.ownsHeap(heap)) return null;
+                if (!self.thread_heaps.ownsHeap(heap)) return null;
             }
 
             return heap;
         }
 
+        /// behaviour is undefined if `thread_id` is not used by the allocator
+        pub fn deinitThreadHeap(self: *Self, thread_id: std.Thread.Id) void {
+            self.thread_heaps.deinitThread(thread_id);
+        }
+
+        pub fn deinitCurrentThreadHeap(self: *Self) void {
+            self.deinitThreadHeap(std.Thread.getCurrentId());
+        }
+
         pub fn allocator(self: *Self) std.mem.Allocator {
             return .{
                 .ptr = self,
@@ -106,16 +90,15 @@ pub fn Allocator(comptime config: Config) type {
             const thread_id = std.Thread.getCurrentId();
 
             log.debugVerbose("obtaining shared thread heaps lock", .{});
-            self.thread_heaps_lock.lockShared();
-            var iter = self.thread_heaps.iterator(0);
+            var iter = self.thread_heaps.iterator(.shared);
             while (iter.next()) |data| {
                 if (data.thread_id == thread_id) {
-                    self.thread_heaps_lock.unlockShared();
+                    iter.unlock();
                     return self.allocInHeap(&data.heap, len, log2_align, ret_addr);
                 }
             } else {
-                self.thread_heaps_lock.unlockShared();
+                iter.unlock();
                 const heap = self.initHeapForThread() orelse return null;
                 return self.allocInHeap(heap, len, log2_align, ret_addr);
             }
         }
@@ -151,7 +134,7 @@ pub fn Allocator(comptime config: Config) type {
            _ = self;
            assert.withMessage(
                @src(),
-                @fieldParentPtr(HeapData, "heap", heap).thread_id == std.Thread.getCurrentId(),
+                @fieldParentPtr(ThreadHeapMap.Entry, "heap", heap).thread_id == std.Thread.getCurrentId(),
                "tried to allocated from wrong thread",
            );
 
@@ -185,7 +168,7 @@ pub fn Allocator(comptime config: Config) type {
         pub fn freeNonHugeFromHeap(self: *Self, heap: *Heap, ptr: [*]u8, log2_align: u8, ret_addr: usize) void {
             log.debug("freeing non-huge allocation", .{});
 
-            if (config.safety_checks) if (!self.ownsHeap(heap)) {
+            if (config.safety_checks) if (!self.thread_heaps.ownsHeap(heap)) {
                 log.err("invalid free: {*} is not part of an owned heap", .{ptr});
                 return;
             };
@@ -200,7 +183,7 @@ pub fn Allocator(comptime config: Config) type {
 
             const page = &page_node.data;
             const slot = page.containingSlotSegment(segment, ptr);
-            const thread_id = @fieldParentPtr(HeapData, "heap", heap).thread_id;
+            const thread_id = @fieldParentPtr(ThreadHeapMap.Entry, "heap", heap).thread_id;
 
             if (std.Thread.getCurrentId() == thread_id) {
                 log.debugVerbose("moving slot {*} to local freelist", .{slot.ptr});
@@ -230,14 +213,14 @@ pub fn Allocator(comptime config: Config) type {
                 assert.withMessage(@src(), self.huge_allocations.removeRaw(buf.ptr), "huge allocation table corrupt with deallocating");
             } else {
-                log.err("invalid free: {*} is not part of an owned heap", .{buf.ptr});
+                log.err("invalid huge free: {*} is not part of an owned heap", .{buf.ptr});
             }
         }
 
         pub fn usableSizeInSegment(self: *Self, ptr: *const anyopaque) usize {
             const segment = Segment.ofPtr(ptr);
 
-            if (config.safety_checks) if (!self.ownsHeap(segment.heap)) {
+            if (config.safety_checks) if (!self.thread_heaps.ownsHeap(segment.heap)) {
                 log.err("invalid pointer: {*} is not part of an owned heap", .{ptr});
                 return 0;
             };
 
@@ -397,3 +380,4 @@ const huge_alignment = @import("huge_alignment.zig");
 const Heap = @import("Heap.zig");
 const Segment = @import("Segment.zig");
 const HugeAllocTable = @import("HugeAllocTable.zig").HugeAllocTable;
+const ThreadHeapMap = @import("ThreadHeapMap.zig");
diff --git a/test/create-destroy-loop.zig b/test/create-destroy-loop.zig
index a0231b1..a4fa94d 100644
--- a/test/create-destroy-loop.zig
+++ b/test/create-destroy-loop.zig
@@ -9,8 +9,7 @@ pub fn main() !void {
     const allocator = zigpa.allocator();
 
     if (comptime build_options.pauses) {
-        std.debug.print("hit [enter] to enter loop\n", .{});
-        waitForInput();
+        waitForInput("enter loop");
     }
 
     inline for (.{ 1, 2, 3, 4 }) |_| {
@@ -24,8 +23,7 @@ pub fn main() !void {
 
         if (comptime build_options.pauses) {
             std.debug.print("memory allocated\n", .{});
-            std.debug.print("hit [enter] to free memory\n", .{});
-            waitForInput();
+            waitForInput("free memory");
             std.debug.print("freeing memory\n", .{});
         }
 
@@ -34,13 +32,13 @@ pub fn main() !void {
         }
         if (comptime build_options.pauses) {
             std.debug.print("memory freed\n", .{});
-            std.debug.print("hit [enter] to continue\n", .{});
-            waitForInput();
+            waitForInput("continue");
         }
     }
 }
 
-fn waitForInput() void {
+fn waitForInput(action: []const u8) void {
+    std.debug.print("hit [enter] to {s}\n", .{action});
     const stdin = std.io.getStdIn().reader();
     var buf: [64]u8 = undefined;
     _ = stdin.readUntilDelimiter(&buf, '\n') catch return;
diff --git a/test/multi-threaded-loop.zig b/test/multi-threaded-loop.zig
new file mode 100644
index 0000000..b05043d
--- /dev/null
+++ b/test/multi-threaded-loop.zig
@@ -0,0 +1,106 @@
+var zigpa = ZiAllocator(.{}){};
+
+pub fn main() !void {
+    defer zigpa.deinit();
+
+    const max_spawn_count = 128 * 5;
+    var threads: [max_spawn_count]std.Thread = undefined;
+
+    const concurrency_limit = try std.Thread.getCpuCount();
+    const spawn_count = 5 * concurrency_limit;
+
+    var semaphore = std.Thread.Semaphore{};
+
+    var wg = std.Thread.WaitGroup{};
+
+    var init_count = std.atomic.Value(usize).init(spawn_count);
+
+    for (threads[0..spawn_count], 0..) |*thread, i| {
+        thread.* = try std.Thread.spawn(.{}, run, .{ i, &wg, &semaphore, &init_count });
+    }
+
+    while (init_count.load(.Acquire) != 0) {
+        std.atomic.spinLoopHint();
+    }
+
+    std.log.debug("starting loops", .{});
+    {
+        semaphore.mutex.lock();
+        defer semaphore.mutex.unlock();
+        semaphore.permits = concurrency_limit;
+        semaphore.cond.broadcast();
+    }
+
+    wg.wait();
+
+    std.log.debug("joining threads", .{});
+    for (threads[0..spawn_count]) |thread| {
+        thread.join();
+    }
+}
+
+threadlocal var thread_index: ?usize = null;
+
+fn run(index: usize, wg: *std.Thread.WaitGroup, semaphore: *std.Thread.Semaphore, init: *std.atomic.Value(usize)) !void {
+    wg.start();
+    defer wg.finish();
+
+    defer zigpa.deinitCurrentThreadHeap();
+
+    thread_index = index;
+    std.log.debug("starting thread", .{});
+
+    const allocator = zigpa.allocator();
+
+    _ = init.fetchSub(1, .Release);
+
+    for (1..5) |i| {
+        semaphore.wait();
+        defer semaphore.post();
+
+        std.log.debug("running iteration {d}", .{i});
+
+        var buf: [50000]*[256]u8 = undefined; // pointers to 12 MiB of data
+
+        for (&buf) |*ptr| {
+            const b = try allocator.create([256]u8);
+            b.* = [1]u8{1} ** 256;
+            ptr.* = b;
+        }
+
+        for (buf) |ptr| {
+            allocator.destroy(ptr);
+        }
+
+        std.log.debug("finished iteration {d}", .{i});
+    }
+}
+
+pub const std_options: std.Options = .{
+    .logFn = logFn,
+};
+
+fn logFn(
+    comptime message_level: std.log.Level,
+    comptime scope: @TypeOf(.enum_literal),
+    comptime format: []const u8,
+    args: anytype,
+) void {
+    if (comptime !std.log.logEnabled(message_level, scope)) return;
+
+    const level_txt = comptime message_level.asText();
+    const prefix1 = "[Thread {?d}-{d}] ";
+    const prefix2 = if (scope == .default) ": " else "(" ++ @tagName(scope) ++ "): ";
+    const stderr = std.io.getStdErr().writer();
+    std.debug.getStderrMutex().lock();
+    defer std.debug.getStderrMutex().unlock();
+    nosuspend stderr.print(
+        prefix1 ++ level_txt ++ prefix2 ++ format ++ "\n",
+        .{ thread_index, std.Thread.getCurrentId() } ++ args,
+    ) catch return;
+}
+
+const std = @import("std");
+
+const build_options = @import("build_options");
+const ZiAllocator = @import("zimalloc").Allocator;