Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add thread deinit API #22

Merged
merged 6 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.zig
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ fn addLibzimalloc(b: *std.Build, options: LibzimallocOptions) *std.Build.Step.Co

const standalone_tests = [_][]const u8{
"create-destroy-loop.zig",
"multi-threaded-loop.zig",
};

const TestBuildConfig = struct {
Expand Down
121 changes: 121 additions & 0 deletions src/ThreadHeapMap.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
list: List = .{},
lock: std.Thread.RwLock = .{},
pool: Pool = Pool.init(std.heap.page_allocator),

const ThreadHeapMap = @This();

const List = std.DoublyLinkedList(Entry);
const Pool = std.heap.MemoryPool(List.Node);

pub const Entry = struct {
heap: Heap,
thread_id: std.Thread.Id,
};

pub fn deinit(self: *ThreadHeapMap) void {
self.lock.lock();

self.pool.deinit();
self.* = undefined;
}

pub fn initThreadHeap(self: *ThreadHeapMap, thread_id: std.Thread.Id) ?*Entry {
log.debugVerbose("obtaining heap lock", .{});
self.lock.lock();
defer self.lock.unlock();

const node = self.pool.create() catch return null;
node.* = .{
.data = .{ .heap = Heap.init(), .thread_id = thread_id },
};

self.list.prepend(node);

return &node.data;
}

/// behaviour is undefined if `thread_id` is not present in the map
pub fn deinitThread(self: *ThreadHeapMap, thread_id: std.Thread.Id) void {
var iter = self.iterator(.exclusive);
defer iter.unlock();
while (iter.next()) |entry| {
if (entry.thread_id == thread_id) {
entry.heap.deinit();
const node = @fieldParentPtr(List.Node, "data", entry);
self.list.remove(node);
return;
}
}
}

pub fn ownsHeap(self: *ThreadHeapMap, heap: *const Heap) bool {
var iter = self.constIterator(.shared);
defer iter.unlock();
while (iter.next()) |entry| {
if (&entry.heap == heap) return true;
}
return false;
}

pub const LockType = enum {
shared,
exclusive,
};

pub fn constIterator(self: *ThreadHeapMap, comptime kind: LockType) ConstIterator(kind) {
switch (kind) {
.shared => self.lock.lockShared(),
.exclusive => self.lock.lock(),
}
return .{
.current = self.list.first,
.lock = &self.lock,
};
}

pub fn iterator(self: *ThreadHeapMap, comptime kind: LockType) Iterator(kind) {
switch (kind) {
.shared => self.lock.lockShared(),
.exclusive => self.lock.lock(),
}
return .{
.current = self.list.first,
.lock = &self.lock,
};
}

pub fn ConstIterator(comptime kind: LockType) type {
return BaseIterator(*const List.Node, *const Entry, kind);
}

pub fn Iterator(comptime kind: LockType) type {
return BaseIterator(*List.Node, *Entry, kind);
}

fn BaseIterator(comptime NodeType: type, comptime EntryType: type, comptime kind: LockType) type {
return struct {
current: ?NodeType,
lock: *std.Thread.RwLock,

pub fn next(self: *@This()) ?EntryType {
const node = self.current orelse return null;
const result: EntryType = &node.data;
self.current = node.next;
return result;
}

pub fn unlock(self: @This()) void {
switch (kind) {
.shared => self.lock.unlockShared(),
.exclusive => self.lock.unlock(),
}
}
};
}

const std = @import("std");
const Allocator = std.mem.Allocator;

const Heap = @import("Heap.zig");
const log = @import("log.zig");
const list = @import("list.zig");
68 changes: 26 additions & 42 deletions src/allocator.zig
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,12 @@ pub const Config = struct {
pub fn Allocator(comptime config: Config) type {
return struct {
backing_allocator: std.mem.Allocator = std.heap.page_allocator,
thread_heaps: std.SegmentedList(HeapData, config.thread_data_prealloc) = .{},
thread_heaps_lock: std.Thread.RwLock = .{},
thread_heaps: ThreadHeapMap = .{},
huge_allocations: HugeAllocTable(config.store_huge_alloc_size) = .{},
// TODO: atomic access

const Self = @This();

const HeapData = struct {
heap: Heap,
thread_id: std.Thread.Id,
};

pub fn init(backing_allocator: std.mem.Allocator) error{OutOfMemory}!Self {
return .{
.backing_allocator = backing_allocator,
Expand All @@ -28,12 +22,7 @@ pub fn Allocator(comptime config: Config) type {
}

pub fn deinit(self: *Self) void {
self.thread_heaps_lock.lock();
var heap_iter = self.thread_heaps.iterator(0);
while (heap_iter.next()) |data| {
data.heap.deinit();
}
self.thread_heaps.deinit(self.backing_allocator);
self.thread_heaps.deinit();
self.huge_allocations.deinit(std.heap.page_allocator);
self.* = undefined;
}
Expand All @@ -44,26 +33,12 @@ pub fn Allocator(comptime config: Config) type {
const thread_id = std.Thread.getCurrentId();
log.debug("initialising heap for thread {d}", .{thread_id});

log.debugVerbose("obtaining heap lock", .{});
self.thread_heaps_lock.lock();
defer self.thread_heaps_lock.unlock();

const new_ptr = self.thread_heaps.addOne(self.backing_allocator) catch return null;

new_ptr.* = .{ .heap = Heap.init(), .thread_id = thread_id };

log.debug("heap initialised: {*}", .{&new_ptr.heap});
return &new_ptr.heap;
}

fn ownsHeap(self: *Self, heap: *const Heap) bool {
self.thread_heaps_lock.lockShared();
defer self.thread_heaps_lock.unlockShared();
var iter = self.thread_heaps.constIterator(0);
while (iter.next()) |child_data| {
if (&child_data.heap == heap) return true;
if (self.thread_heaps.initThreadHeap(thread_id)) |entry| {
log.debug("heap added to thread map: {*}", .{&entry.heap});
return &entry.heap;
}
return false;

return null;
}

pub fn getThreadHeap(
Expand All @@ -74,12 +49,21 @@ pub fn Allocator(comptime config: Config) type {
const heap = segment.heap;

if (config.safety_checks) {
if (!self.ownsHeap(heap)) return null;
if (!self.thread_heaps.ownsHeap(heap)) return null;
}

return heap;
}

/// behaviour is undefined if `thread_id` is not used by the allocator
pub fn deinitThreadHeap(self: *Self, thread_id: std.Thread.Id) void {
self.thread_heaps.deinitThread(thread_id);
}

pub fn deinitCurrentThreadHeap(self: *Self) void {
self.deinitThreadHeap(std.Thread.getCurrentId());
}

pub fn allocator(self: *Self) std.mem.Allocator {
return .{
.ptr = self,
Expand All @@ -106,16 +90,15 @@ pub fn Allocator(comptime config: Config) type {
const thread_id = std.Thread.getCurrentId();

log.debugVerbose("obtaining shared thread heaps lock", .{});
self.thread_heaps_lock.lockShared();

var iter = self.thread_heaps.iterator(0);
var iter = self.thread_heaps.iterator(.shared);
while (iter.next()) |data| {
if (data.thread_id == thread_id) {
self.thread_heaps_lock.unlockShared();
iter.unlock();
return self.allocInHeap(&data.heap, len, log2_align, ret_addr);
}
} else {
self.thread_heaps_lock.unlockShared();
iter.unlock();
const heap = self.initHeapForThread() orelse return null;
return self.allocInHeap(heap, len, log2_align, ret_addr);
}
Expand Down Expand Up @@ -151,7 +134,7 @@ pub fn Allocator(comptime config: Config) type {
_ = self;
assert.withMessage(
@src(),
@fieldParentPtr(HeapData, "heap", heap).thread_id == std.Thread.getCurrentId(),
@fieldParentPtr(ThreadHeapMap.Entry, "heap", heap).thread_id == std.Thread.getCurrentId(),
"tried to allocated from wrong thread",
);

Expand Down Expand Up @@ -185,7 +168,7 @@ pub fn Allocator(comptime config: Config) type {

pub fn freeNonHugeFromHeap(self: *Self, heap: *Heap, ptr: [*]u8, log2_align: u8, ret_addr: usize) void {
log.debug("freeing non-huge allocation", .{});
if (config.safety_checks) if (!self.ownsHeap(heap)) {
if (config.safety_checks) if (!self.thread_heaps.ownsHeap(heap)) {
log.err("invalid free: {*} is not part of an owned heap", .{ptr});
return;
};
Expand All @@ -200,7 +183,7 @@ pub fn Allocator(comptime config: Config) type {
const page = &page_node.data;
const slot = page.containingSlotSegment(segment, ptr);

const thread_id = @fieldParentPtr(HeapData, "heap", heap).thread_id;
const thread_id = @fieldParentPtr(ThreadHeapMap.Entry, "heap", heap).thread_id;

if (std.Thread.getCurrentId() == thread_id) {
log.debugVerbose("moving slot {*} to local freelist", .{slot.ptr});
Expand Down Expand Up @@ -230,14 +213,14 @@ pub fn Allocator(comptime config: Config) type {

assert.withMessage(@src(), self.huge_allocations.removeRaw(buf.ptr), "huge allocation table corrupt with deallocating");
} else {
log.err("invalid free: {*} is not part of an owned heap", .{buf.ptr});
log.err("invalid huge free: {*} is not part of an owned heap", .{buf.ptr});
}
}

pub fn usableSizeInSegment(self: *Self, ptr: *const anyopaque) usize {
const segment = Segment.ofPtr(ptr);

if (config.safety_checks) if (!self.ownsHeap(segment.heap)) {
if (config.safety_checks) if (!self.thread_heaps.ownsHeap(segment.heap)) {
log.err("invalid pointer: {*} is not part of an owned heap", .{ptr});
return 0;
};
Expand Down Expand Up @@ -397,3 +380,4 @@ const huge_alignment = @import("huge_alignment.zig");
const Heap = @import("Heap.zig");
const Segment = @import("Segment.zig");
const HugeAllocTable = @import("HugeAllocTable.zig").HugeAllocTable;
const ThreadHeapMap = @import("ThreadHeapMap.zig");
12 changes: 5 additions & 7 deletions test/create-destroy-loop.zig
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ pub fn main() !void {
const allocator = zigpa.allocator();

if (comptime build_options.pauses) {
std.debug.print("hit [enter] to enter loop\n", .{});
waitForInput();
waitForInput("enter loop");
}

inline for (.{ 1, 2, 3, 4 }) |_| {
Expand All @@ -24,8 +23,7 @@ pub fn main() !void {

if (comptime build_options.pauses) {
std.debug.print("memory allocated\n", .{});
std.debug.print("hit [enter] to free memory\n", .{});
waitForInput();
waitForInput("free memory");
std.debug.print("freeing memory\n", .{});
}

Expand All @@ -34,13 +32,13 @@ pub fn main() !void {
}
if (comptime build_options.pauses) {
std.debug.print("memory freed\n", .{});
std.debug.print("hit [enter] to continue\n", .{});
waitForInput();
waitForInput("continue");
}
}
}

fn waitForInput() void {
fn waitForInput(action: []const u8) void {
std.debug.print("hit [enter] to {s}\n", .{action});
const stdin = std.io.getStdIn().reader();
var buf: [64]u8 = undefined;
_ = stdin.readUntilDelimiter(&buf, '\n') catch return;
Expand Down
Loading
Loading