Skip to content

Commit

Permalink
AstGen: order declarations correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
alichraghi authored and emidoots committed Feb 25, 2024
1 parent 3e42ed9 commit ffddcf6
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 26 deletions.
1 change: 1 addition & 0 deletions src/shader/Air.zig
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub fn generate(
defer {
astgen.instructions.deinit(allocator);
astgen.scratch.deinit(allocator);
astgen.globals.deinit(allocator);
astgen.global_var_refs.deinit(allocator);
astgen.scope_pool.deinit();
astgen.inst_arena.deinit();
Expand Down
52 changes: 32 additions & 20 deletions src/shader/AstGen.zig
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ strings: std.ArrayListUnmanaged(u8) = .{},
values: std.ArrayListUnmanaged(u8) = .{},
scratch: std.ArrayListUnmanaged(InstIndex) = .{},
global_var_refs: std.AutoArrayHashMapUnmanaged(InstIndex, void) = .{},
globals: std.ArrayListUnmanaged(InstIndex) = .{},
has_array_length: bool = false,
compute_stage: InstIndex = .none,
vertex_stage: InstIndex = .none,
Expand Down Expand Up @@ -59,9 +60,6 @@ pub const Scope = struct {
};

pub fn genTranslationUnit(astgen: *AstGen) !RefIndex {
const scratch_top = astgen.scratch.items.len;
defer astgen.scratch.shrinkRetainingCapacity(scratch_top);

var root_scope = try astgen.scope_pool.create();
root_scope.* = .{ .tag = .root, .parent = undefined };

Expand All @@ -70,15 +68,24 @@ pub fn genTranslationUnit(astgen: *AstGen) !RefIndex {

for (global_nodes) |node| {
var global = root_scope.decls.get(node).? catch continue;
if (global == .none) {
// declaration has not analysed
global = astgen.genGlobalDecl(root_scope, node) catch |err| switch (err) {
error.AnalysisFail => continue,
error.OutOfMemory => return error.OutOfMemory,
};
}

try astgen.scratch.append(astgen.allocator, global);
global = switch (astgen.tree.nodeTag(node)) {
.@"fn" => blk: {
std.debug.assert(global == .none);
break :blk astgen.genFn(root_scope, node, false) catch |err| switch (err) {
error.Skiped => continue,
else => |e| e,
};
},
else => continue,
} catch |err| {
if (err == error.AnalysisFail) {
root_scope.decls.putAssumeCapacity(node, error.AnalysisFail);
continue;
}
return err;
};
root_scope.decls.putAssumeCapacity(node, global);
try astgen.globals.append(astgen.allocator, global);
}

if (astgen.errors.list.items.len > 0) return error.AnalysisFail;
Expand All @@ -91,7 +98,7 @@ pub fn genTranslationUnit(astgen: *AstGen) !RefIndex {
try astgen.errors.add(Loc{ .start = 0, .end = 1 }, "entry point not found", .{}, null);
}

return astgen.addRefList(astgen.scratch.items[scratch_top..]);
return astgen.addRefList(astgen.globals.items);
}

/// adds `nodes` to scope and checks for re-declarations
Expand Down Expand Up @@ -122,23 +129,26 @@ fn scanDecls(astgen: *AstGen, scope: *Scope, nodes: []const NodeIndex) !void {
}
}

fn genGlobalDecl(astgen: *AstGen, scope: *Scope, node: NodeIndex) !InstIndex {
fn genGlobalDecl(astgen: *AstGen, scope: *Scope, node: NodeIndex) error{ OutOfMemory, AnalysisFail }!InstIndex {
const decl = switch (astgen.tree.nodeTag(node)) {
.global_var => astgen.genGlobalVar(scope, node),
.override => astgen.genOverride(scope, node),
.@"const" => astgen.genConst(scope, node),
.@"struct" => astgen.genStruct(scope, node),
.@"fn" => astgen.genFn(scope, node),
.@"fn" => astgen.genFn(scope, node, false),
.type_alias => astgen.genTypeAlias(scope, node),
else => unreachable,
} catch |err| {
if (err == error.AnalysisFail) {
} catch |err| switch (err) {
error.AnalysisFail => {
scope.decls.putAssumeCapacity(node, error.AnalysisFail);
}
return err;
return error.AnalysisFail;
},
error.Skiped => unreachable,
else => |e| return e,
};

scope.decls.putAssumeCapacity(node, decl);
try astgen.globals.append(astgen.allocator, decl);
return decl;
}

Expand Down Expand Up @@ -402,7 +412,7 @@ fn genStructMembers(astgen: *AstGen, scope: *Scope, node: NodeIndex) !RefIndex {
return astgen.addRefList(astgen.scratch.items[scratch_top..]);
}

fn genFn(astgen: *AstGen, root_scope: *Scope, node: NodeIndex) !InstIndex {
fn genFn(astgen: *AstGen, root_scope: *Scope, node: NodeIndex, only_entry_point: bool) !InstIndex {
const scratch_top = astgen.global_var_refs.count();
defer astgen.global_var_refs.shrinkRetainingCapacity(scratch_top);

Expand Down Expand Up @@ -480,6 +490,8 @@ fn genFn(astgen: *AstGen, root_scope: *Scope, node: NodeIndex) !InstIndex {
}
}

if (only_entry_point and stage == .none) return error.Skiped;

if (stage == .compute) {
if (return_type != .none) {
try astgen.errors.add(
Expand Down
13 changes: 7 additions & 6 deletions src/shader/codegen/spirv.zig
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,7 @@ pub fn gen(allocator: std.mem.Allocator, air: *const Air, debug_info: DebugInfo)

for (air.refToList(air.globals_index)) |inst_idx| {
switch (spv.air.getInst(inst_idx)) {
.@"fn" => |@"fn"| if (@"fn".stage != .none) {
_ = try spv.emitFn(inst_idx);
},
.@"fn" => _ = try spv.emitFn(inst_idx),
.@"const" => _ = try spv.emitConst(&spv.global_section, inst_idx),
.@"var" => _ = try spv.emitVarProto(&spv.global_section, inst_idx),
.@"struct" => _ = try spv.emitStruct(inst_idx),
Expand Down Expand Up @@ -1358,8 +1356,11 @@ const PtrAccess = struct {

fn emitVarAccess(spv: *SpirV, section: *Section, inst: InstIndex) !PtrAccess {
const decl = spv.decl_map.get(inst) orelse blk: {
std.debug.assert(spv.air.getInst(inst) == .@"const");
_ = try spv.emitConst(section, inst);
switch (spv.air.getInst(inst)) {
.@"const" => _ = try spv.emitConst(&spv.global_section, inst),
.@"var" => _ = try spv.emitVarProto(&spv.global_section, inst),
else => unreachable,
}
break :blk spv.decl_map.get(inst).?;
};

Expand Down Expand Up @@ -1969,7 +1970,7 @@ fn emitTripleIntrinsic(spv: *SpirV, section: *Section, triple: Inst.TripleIntrin
.smoothstep => 49,
};

if (triple.op == .mix) {
if (triple.op == .mix and spv.air.getInst(triple.result_type) == .vector) {
const vec_type_inst = spv.air.getInst(triple.result_type).vector;
var constituents = std.BoundedArray(IdRef, 4){};
constituents.appendNTimesAssumeCapacity(a3, @intFromEnum(vec_type_inst.size));
Expand Down

0 comments on commit ffddcf6

Please sign in to comment.