Skip to content

Commit

Permalink
feat: add merkleize and hashTreeRoot method
Browse files Browse the repository at this point in the history
  • Loading branch information
PengZhen committed Oct 11, 2024
1 parent cbe1371 commit c39a943
Show file tree
Hide file tree
Showing 3 changed files with 631 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/root.zig
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub const signing_root_helper = @import("consensus/helpers/signing_root.zig");
pub const block_root_helper = @import("consensus/helpers/block_root.zig");
pub const seed_helper = @import("consensus/helpers/seed.zig");
pub const committee_helper = @import("consensus/helpers/committee.zig");
pub const ssz = @import("./common/ssz.zig");
pub const ssz = @import("./ssz/ssz.zig");

test {
@import("std").testing.refAllDeclsRecursive(@This());
Expand Down
360 changes: 359 additions & 1 deletion src/common/ssz.zig → src/ssz/ssz.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
const std = @import("std");
const testing = std.testing;
const Allocator = std.mem.Allocator;
const ArrayList = std.ArrayList;
const sha256 = std.crypto.hash.sha2.Sha256;
const types = @import("../consensus/types.zig");

/// Set of possible errors while performing ssz decoding.
pub const SSZDecodeErrors = Allocator.Error || error{ InvalidEnumType, IndexOutOfBounds };
Expand All @@ -31,7 +34,11 @@ pub fn decodeSSZ(comptime T: type, serialized: []const u8) SSZDecodeErrors!T {
},
.array => |arr_info| {
if (arr_info.child == u8) {
return serialized[0..];
if (serialized.len >= arr_info.len) {
return serialized[0..arr_info.len].*;
} else {
return error.IndexOutOfBounds;
}
}

var result: T = undefined;
Expand Down Expand Up @@ -616,6 +623,18 @@ test "Decode Struct" {
try testing.expectEqualDeep(pastry, decoded);
}

test "Decode Fork" {
const fork = types.Fork{
.current_version = [4]u8{ 1, 2, 3, 4 },
.previous_version = [4]u8{ 1, 2, 3, 4 },
.epoch = 10,
};
const encoded = try encodeSSZ(testing.allocator, fork);
defer testing.allocator.free(encoded);
const decoded = try decodeSSZ(types.Fork, encoded);
try testing.expectEqualDeep(fork, decoded);
}

test "Decode Union" {
const Union = union(enum) {
foo: u32,
Expand Down Expand Up @@ -671,3 +690,342 @@ test "Decode Vector" {
try testing.expectEqualDeep(encoded, decoded);
}
}

/// Number of bytes per chunk.
const BYTES_PER_CHUNK = 32;

/// Number of bytes per serialized length offset.
const BYTES_PER_LENGTH_OFFSET = 4;

fn mixInLength(root: [32]u8, length: [32]u8, out: *[32]u8) void {
var hasher = sha256.init(sha256.Options{});
hasher.update(root[0..]);
hasher.update(length[0..]);
hasher.final(out);
}

test "mixInLength" {
var root: [32]u8 = undefined;
var length: [32]u8 = undefined;
var expected: [32]u8 = undefined;
var mixin: [32]u8 = undefined;
_ = try std.fmt.hexToBytes(root[0..], "2279cf111c15f2d594e7a0055e8735e7409e56ed4250735d6d2f2b0d1bcf8297");
_ = try std.fmt.hexToBytes(length[0..], "deadbeef00000000000000000000000000000000000000000000000000000000");
_ = try std.fmt.hexToBytes(expected[0..], "0b665dda6e4c269730bc4bbe3e990a69d37fa82892bac5fe055ca4f02a98c900");
mixInLength(root, length, &mixin);

try std.testing.expect(std.mem.eql(u8, mixin[0..], expected[0..]));
}

fn mixInSelector(root: [32]u8, comptime selector: usize, out: *[32]u8) void {
var hasher = sha256.init(sha256.Options{});
hasher.update(root[0..]);
var tmp = [_]u8{0} ** 32;
std.mem.writeInt(@TypeOf(selector), tmp[0..@sizeOf(@TypeOf(selector))], selector, .little);
hasher.update(tmp[0..]);
hasher.final(out);
}

test "mixInSelector" {
var root: [32]u8 = undefined;
var expected: [32]u8 = undefined;
var mixin: [32]u8 = undefined;
_ = try std.fmt.hexToBytes(root[0..], "2279cf111c15f2d594e7a0055e8735e7409e56ed4250735d6d2f2b0d1bcf8297");
_ = try std.fmt.hexToBytes(expected[0..], "c483cb731afcfe9f2c596698eaca1c4e0dcb4a1136297adef74c31c268966eb5");
mixInSelector(root, 25, &mixin);

try std.testing.expect(std.mem.eql(u8, mixin[0..], expected[0..]));
}

pub fn chunkCount(comptime T: type) usize {
const info = @typeInfo(T);
switch (info) {
.int, .bool => return 1,
.pointer => return chunkCount(info.pointer.child),
.array => switch (@typeInfo(info.array.child)) {
// Bitvector[N]
.bool => return (info.array.len + 255) / 256,
// Vector[B,N]
.int => return (info.Array.len * @sizeOf(info.Array.child) + 31) / 32,
// Vector[C,N]
else => return info.array.len,
},
.@"struct" => return info.@"struct".fields.len,
else => return error.NotSupported,
}
}

const chunk = [BYTES_PER_CHUNK]u8;
const zero_chunk: chunk = [_]u8{0} ** BYTES_PER_CHUNK;

fn pack(value: anytype, l: *ArrayList(u8)) ![]chunk {
const encoded = try encodeSSZ(l.allocator, value);
try l.appendSlice(encoded);
l.allocator.free(encoded);

const padding_size = (BYTES_PER_CHUNK - l.items.len % BYTES_PER_CHUNK) % BYTES_PER_CHUNK;
try l.appendSlice(zero_chunk[0..padding_size]);

return std.mem.bytesAsSlice(chunk, l.items);
}

test "pack u32" {
var expected: [32]u8 = undefined;
var list = ArrayList(u8).init(std.testing.allocator);
defer list.deinit();
const data: u32 = 0xdeadbeef;
const out = try pack(data, &list);

_ = try std.fmt.hexToBytes(expected[0..], "efbeadde00000000000000000000000000000000000000000000000000000000");

try std.testing.expect(std.mem.eql(u8, out[0][0..], expected[0..]));
}

test "pack bool" {
var expected: [32]u8 = undefined;
var list = ArrayList(u8).init(std.testing.allocator);
defer list.deinit();
const out = try pack(true, &list);

_ = try std.fmt.hexToBytes(expected[0..], "0100000000000000000000000000000000000000000000000000000000000000");

try std.testing.expect(std.mem.eql(u8, out[0][0..], expected[0..]));
}

test "pack string" {
var expected: [128]u8 = undefined;
var list = ArrayList(u8).init(std.testing.allocator);
defer list.deinit();
const data: []const u8 = "a" ** 100;
const out = try pack(data, &list);

_ = try std.fmt.hexToBytes(expected[0..], "6161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616161616100000000000000000000000000000000000000000000000000000000");

try std.testing.expect(expected.len == out.len * out[0].len);
try std.testing.expect(std.mem.eql(u8, out[0][0..], expected[0..32]));
try std.testing.expect(std.mem.eql(u8, out[1][0..], expected[32..64]));
try std.testing.expect(std.mem.eql(u8, out[2][0..], expected[64..96]));
try std.testing.expect(std.mem.eql(u8, out[3][0..], expected[96..]));
}

fn nextPowOfTwo(len: usize) !usize {
if (len == 0) {
return @as(usize, 0);
}

// check that the msb isn't set and
// return an error if it is, as it
// would overflow.
if (@clz(len) == 0) {
return error.OverflowsUSize;
}

const n = std.math.log2(std.math.shl(usize, len, 1) - 1);
return std.math.powi(usize, 2, n);
}

test "next power of 2" {
var out = try nextPowOfTwo(0b1);
try std.testing.expect(out == 1);
out = try nextPowOfTwo(0b10);
try std.testing.expect(out == 2);
out = try nextPowOfTwo(0b11);
try std.testing.expect(out == 4);

// special cases
out = try nextPowOfTwo(0);
try std.testing.expect(out == 0);
try std.testing.expectError(error.OverflowsUSize, nextPowOfTwo(std.math.maxInt(usize)));
}

const hashes_of_zero = @import("./zeros.zig").hashes_of_zero;

pub fn merkleize(chunks: []chunk, limit: ?usize, out: *[32]u8) anyerror!void {
if (limit != null and chunks.len > limit.?) {
return error.ChunkSizeExceedsLimit;
}

const size = try nextPowOfTwo(limit orelse chunks.len);

// Perform the merkleization.
switch (size) {
0 => std.mem.copyForwards(u8, out, &zero_chunk),
1 => std.mem.copyForwards(u8, out, chunks[0][0..]),
else => {
var hasher = sha256.init(sha256.Options{});
var buf: [32]u8 = undefined;
const split = if (size / 2 < chunks.len) size / 2 else chunks.len;
try merkleize(chunks[0..split], size / 2, &buf);
hasher.update(buf[0..]);

if (size / 2 < chunks.len) {
try merkleize(chunks[size / 2 ..], size / 2, &buf);
hasher.update(buf[0..]);
} else hasher.update(hashes_of_zero[size / 2 - 1][0..]);
hasher.final(out);
},
}
}

test "merkleize a string" {
var list = ArrayList(u8).init(std.testing.allocator);
defer list.deinit();
const data: []const u8 = "a" ** 100;
const chunks = try pack(data, &list);
var out: [32]u8 = undefined;
try merkleize(chunks, null, &out);
// Build the expected tree
const leaf1 = [_]u8{0x61} ** 32; // "0xaaaaa....aa" 32 times
var leaf2: [32]u8 = [_]u8{0x61} ** 4 ++ [_]u8{0} ** 28;
var root: [32]u8 = undefined;
var internal_left: [32]u8 = undefined;
var internal_right: [32]u8 = undefined;
var hasher = sha256.init(sha256.Options{});
hasher.update(leaf1[0..]);
hasher.update(leaf1[0..]);
hasher.final(&internal_left);
hasher = sha256.init(sha256.Options{});
hasher.update(leaf1[0..]);
hasher.update(leaf2[0..]);
hasher.final(&internal_right);
hasher = sha256.init(sha256.Options{});
hasher.update(internal_left[0..]);
hasher.update(internal_right[0..]);
hasher.final(&root);

try std.testing.expect(std.mem.eql(u8, out[0..], root[0..]));
}

test "merkleize a boolean" {
var list = ArrayList(u8).init(std.testing.allocator);
defer list.deinit();

var chunks = try pack(false, &list);
var expected = [_]u8{0} ** BYTES_PER_CHUNK;
var out: [BYTES_PER_CHUNK]u8 = undefined;
try merkleize(chunks, null, &out);

try std.testing.expect(std.mem.eql(u8, out[0..], expected[0..]));

var list2 = ArrayList(u8).init(std.testing.allocator);
defer list2.deinit();

chunks = try pack(true, &list2);
expected[0] = 1;
try merkleize(chunks, null, &out);
try std.testing.expect(std.mem.eql(u8, out[0..], expected[0..]));
}

test "merkleize a bytes16 vector with one element" {
var list = ArrayList(u8).init(std.testing.allocator);
defer list.deinit();
const chunks = try pack([_]u8{0xaa} ** 16, &list);
var expected: [32]u8 = [_]u8{0xaa} ** 16 ++ [_]u8{0x00} ** 16;
var out: [32]u8 = undefined;
try merkleize(chunks, null, &out);
try std.testing.expect(std.mem.eql(u8, out[0..], expected[0..]));
}

fn packBits(bits: []const bool, l: ArrayList(u8)) ![]chunk {
var byte: u8 = 0;
for (bits, 0..) |bit, bitIdx| {
if (bit) {
byte |= @as(u8, 1) << @as(u3, @truncate(7 - bitIdx % 8));
}
if (bitIdx % 8 == 7 or bitIdx == bits.len - 1) {
try l.append(byte);
byte = 0;
}
}
// pad the last chunk with 0s
const padding_size = (BYTES_PER_CHUNK - l.items.len % BYTES_PER_CHUNK) % BYTES_PER_CHUNK;
_ = try l.writer().write(zero_chunk[0..padding_size]);

return std.mem.bytesAsSlice(chunk, l.items);
}

pub fn hashTreeRoot(value: anytype, out: *[32]u8, allocator: Allocator) !void {
const type_info = @typeInfo(@TypeOf(value));
switch (type_info) {
.int, .bool => {
var list = ArrayList(u8).init(allocator);
defer list.deinit();
const chunks = try pack(value, &list);
try merkleize(chunks, null, out);
},
.array => {
switch (@typeInfo(type_info.array.child)) {
.int => {
var list = ArrayList(u8).init(allocator);
defer list.deinit();
const chunks = try pack(value, &list);
try merkleize(chunks, null, out);
},
.bool => {
var list = ArrayList(u8).init(allocator);
defer list.deinit();
const chunks = try packBits(value, list);
try merkleize(chunks, null, out);
},
.array => {
var chunks = ArrayList(chunk).init(allocator);
defer chunks.deinit();
var tmp: chunk = undefined;
for (value) |item| {
try hashTreeRoot(item, &tmp, allocator);
try chunks.append(tmp);
}
try merkleize(chunks.items, null, out);
},
else => return error.NotSupported,
}
},
.pointer => {
switch (type_info.Pointer.size) {
.One => hashTreeRoot(value.*, out, allocator),
.Slice => {
switch (@typeInfo(type_info.Pointer.child)) {
.int => {
var list = ArrayList(u8).init(allocator);
defer list.deinit();
const chunks = try pack(value, &list);
merkleize(chunks, null, out);
},
else => return error.UnSupportedPointerType,
}
},
else => return error.UnSupportedPointerType,
}
},
.@"struct" => {
var chunks = ArrayList(chunk).init(allocator);
defer chunks.deinit();
var tmp: chunk = undefined;
inline for (type_info.Struct.fields) |f| {
try hashTreeRoot(@field(value, f.name), &tmp, allocator);
try chunks.append(tmp);
}
try merkleize(chunks.items, null, out);
},
.optional => if (value != null) {
var tmp: chunk = undefined;
try hashTreeRoot(value.?, &tmp, allocator);
mixInSelector(tmp, 1, out);
} else {
mixInSelector(zero_chunk, 0, out);
},
.@"union" => {
if (type_info.@"union".tag_type == null) {
return error.UnionIsNotTagged;
}
inline for (type_info.@"union".fields, 0..) |f, index| {
if (@intFromEnum(value) == index) {
var tmp: chunk = undefined;
try hashTreeRoot(@field(value, f.name), &tmp, allocator);
mixInSelector(tmp, index, out);
}
}
},
else => return error.NotSupported,
}
}
Loading

0 comments on commit c39a943

Please sign in to comment.