diff --git a/src/compiler.pr b/src/compiler.pr index 5b4c2f0f..ace7a2d3 100644 --- a/src/compiler.pr +++ b/src/compiler.pr @@ -1522,7 +1522,7 @@ def convert_ref_to_ptr(tpe: &typechecking::Type, value: Value, loc: &Value, stat return bitcast_ret } -def convert_value_to_ref(tpe: &typechecking::Type, value: Value, loc: &Value, state: &State) -> Value { +def convert_value_to_ref(tpe: &typechecking::Type, value: Value, loc: &Value, state: &State, initial_ref_count: size_t = 0) -> Value { if tpe.tpe and value.tpe.kind != tpe.tpe.kind { value = convert_to(loc, value, tpe.tpe, state) } @@ -1564,7 +1564,7 @@ def convert_value_to_ref(tpe: &typechecking::Type, value: Value, loc: &Value, st let store1 = make_insn_dbg(InsnKind::STORE, loc) store1.value.store = [ loc = refcount, - value = [ kind = ValueKind::INT, tpe = builtins::int64_, i = 0 ] !Value + value = [ kind = ValueKind::INT, tpe = builtins::int64_, i = initial_ref_count ] !Value ] !InsnStore push_insn(store1, state) } @@ -2077,6 +2077,19 @@ def convert_to(kind: InsnKind, loc: &Value, value: Value, tpe: &typechecking::Ty return ret } +def get_embed_field(left: &typechecking::Type, right: &typechecking::Type, state: &State) -> &typechecking::StructMember { + if is_interface(left) and is_struct(right) { + if typechecking::implements(right, left, state.module, check_embed = false) { return null } + for var field in @right.fields { + if field.is_embed and typechecking::implements(field.tpe, left, state.module, check_embed = false) { + return field + } + } + } + + return null +} + // value gets loaded by this function def convert_to(loc: &Value, value: Value, tpe: &typechecking::Type, state: &State) -> Value { if not value.tpe or not tpe { return NO_VALUE } @@ -2114,6 +2127,42 @@ def convert_to(loc: &Value, value: Value, tpe: &typechecking::Type, state: &Stat return value } } + let left = tpe.tpe if is_ref(tpe) else tpe + let right = value.tpe.tpe if is_ref(value.tpe) else value.tpe + let embed_field = get_embed_field(left, right, state) + + // Try to convert to embedded struct / reference + if is_struct(right) and (is_struct(left) or embed_field) { + var is_embed = false + var field: StructMember + if embed_field { + is_embed = true + field = @embed_field + } else { + for var f in @right.fields { + if f.is_embed and (equals(f.tpe, tpe) or equals(f.tpe, tpe.tpe)) { + is_embed = true + field = f + break + } + } + } + if is_embed { + var unwrap = load_value(value, loc, state) + // Unwrap reference on the value side + if is_ref(value.tpe) { + let ref = state.extract_value(pointer(right), unwrap, [1], loc) + unwrap = state.load(right, ref, loc) + } + // Extract element + var elem = state.extract_value(field.tpe, unwrap, [field.index !int], loc) + // Wrap in reference if needed + if is_ref(tpe) and not is_ref(field.tpe) { + elem = convert_value_to_ref(tpe, elem, loc, state, 1) + } + return elem + } + } if tpe.kind == value.tpe.kind and value.tpe.is_anon and typechecking::is_struct(value.tpe) { return convert_anon_to_struct(tpe, value, loc, state) } @@ -2322,7 +2371,18 @@ def walk_StructLitUnion(node: &parser::Node, state: &State) -> Value { return load_ret } +// TODO add loc to this def locals_to_insert_value(value: &Value, state: &State) { + import_cstd_function("malloc", state) + + var aref = is_ref(value.tpe) + let reftpe = value.tpe + var new_value = value + if aref { + new_value = @value + new_value.tpe = value.tpe.tpe + } + let values = value.values for var i in 0..values.size { let val = values(i) @@ -2332,22 +2392,87 @@ def locals_to_insert_value(value: &Value, state: &State) { tpe = val.tpe ] !Value - let ret = make_local_value(value.tpe, null, state) + let ret = make_local_value(new_value.tpe, null, state) let index = allocate_ref(int, 1) index(0) = i let insert = make_insn(InsnKind::INSERTVALUE) (@insert).value.insert_value = [ ret = ret, - value = @value, + value = @new_value, element = val, index = index ] !InsnInsertValue push_insn(insert, state) - @value = ret + @new_value = ret } } + if aref { + @value = convert_value_to_ref(reftpe, @new_value, null, state, 1) + } +} + +def struct_lit_create_value(tpe: &typechecking::Type, kwargs: &Vector(&parser::Node), loc: &Value, state: &State) -> &[Value] { + if not tpe { return null } + if is_ref(tpe) { tpe = tpe.tpe } + + var types = vector::make(type &typechecking::Type) + if tpe.kind == typechecking::TypeKind::TUPLE { + types = tpe.return_t + } else { + for var field in @tpe.fields { + types.push(field.tpe) + } + } + + let values = allocate_ref(Value, types.length) + for var i in 0..values.size { + values(i) = [ + kind = ValueKind::ZEROINITIALIZER, + tpe = types(i) + ] !Value + } + + if is_struct(tpe) { + for var k in 0..tpe.fields.size { + let field = tpe.fields(k) + if field.is_embed { + let new_values = struct_lit_create_value(field.tpe, kwargs, loc, state) + let new_value = [ + kind = ValueKind::STRUCT, + values = new_values, + tpe = field.tpe + ] !&Value + + locals_to_insert_value(new_value, state) + + values(k) = @new_value + } + } + } + + for var i in 0..vector::length(kwargs) { + let kwarg = kwargs(i) + let name = typechecking::last_ident_to_str((@kwarg).value.named_arg.name) + let value = walk_expression((@kwarg).value.named_arg.value, state) + + for var j in 0..tpe.fields.size { + let field = tpe.fields(j) + if field.name == name { + values(j) = convert_to(kwarg, value, field.tpe, state) + if typechecking::is_ref(field.tpe) { + increase_ref_count_of_value(values(j), loc, state) + } else if typechecking::has_copy_constructor(field.tpe) { + let ret = state.alloca(values(j).tpe, loc, no_yield_capture = true) + insert_copy_constructor(ret, values(j), loc, state) + values(j) = state.load(values(j).tpe, ret, loc) + } + break + } + } + } + return values } def walk_StructLit(node: &parser::Node, state: &State) -> Value { @@ -2367,23 +2492,8 @@ def walk_StructLit(node: &parser::Node, state: &State) -> Value { } else if tpe.kind == typechecking::TypeKind::UNION { value = walk_StructLitUnion(node, state) } else { - var types = vector::make(type &typechecking::Type) - if tpe.kind == typechecking::TypeKind::TUPLE { - types = tpe.return_t - } else { - for var field in @tpe.fields { - types.push(field.tpe) - } - } - - let values = allocate_ref(Value, types.length) - for var i in 0..values.size { - values(i) = [ - kind = ValueKind::ZEROINITIALIZER, - tpe = types(i) - ] !Value - } - for var i in 0..vector::length(args) { + // args no longer valid + /*for var i in 0..vector::length(args) { let arg = args(i) let arg_tpe = types(i) let value = walk_expression(arg, state) @@ -2395,27 +2505,10 @@ def walk_StructLit(node: &parser::Node, state: &State) -> Value { insert_copy_constructor(ret, values(i), loc, state) values(i) = state.load(values(i).tpe, ret, loc) } - } - for var i in 0..vector::length(kwargs) { - let kwarg = kwargs(i) - let name = typechecking::last_ident_to_str((@kwarg).value.named_arg.name) - let value = walk_expression((@kwarg).value.named_arg.value, state) - - for var j in 0..tpe.fields.size { - let field = tpe.fields(j) - if field.name == name { - values(j) = convert_to(kwarg, value, field.tpe, state) - if typechecking::is_ref(field.tpe) { - increase_ref_count_of_value(values(j), loc, state) - } else if typechecking::has_copy_constructor(field.tpe) { - let ret = state.alloca(values(j).tpe, loc, no_yield_capture = true) - insert_copy_constructor(ret, values(j), loc, state) - values(j) = state.load(values(j).tpe, ret, loc) - } - break - } - } - } + }*/ + + let values = struct_lit_create_value(tpe, kwargs, loc, state) + value = [ kind = ValueKind::STRUCT, values = values, @@ -3987,6 +4080,8 @@ def walk_MemberAccess_gep(node: &parser::Node, tpe: &typechecking::Type, type Member = struct { index: int tpe: &typechecking::Type + is_embed: bool + value: &Value } // This list needs to be reversed to find the actual indices @@ -4005,22 +4100,38 @@ def resolve_member(vec: &Vector(Member), tpe: &typechecking::Type, name: Str) -> return true } } else { - let found = resolve_member(vec, field.tpe, name) + var tpe = field.tpe + if field.is_embed and is_ref(tpe) { tpe = tpe.tpe } + let found = resolve_member(vec, tpe, name) if found { let member = [ index = field.index !int, - tpe = field.tpe + tpe = field.tpe, + is_embed = field.is_embed ] !Member vec.push(member) return true } } } + for var field in @tpe.const_fields { + if field.name == name { + let member = [ + value = field.value + ] !Member + vec.push(member) + return true + } + } + return false } def walk_MemberAccess_struct(node: &parser::Node, tpe: &typechecking::Type, member: &Member, value: Value, state: &State) -> Value { let loc = make_location(node, state) + if member.value { + return @member.value + } var member_type = member.tpe if member.tpe.kind == typechecking::TypeKind::BOX { @@ -4028,7 +4139,6 @@ def walk_MemberAccess_struct(node: &parser::Node, tpe: &typechecking::Type, memb } if tpe.kind == typechecking::TypeKind::UNION { - let index = allocate_ref(Value, 2) index(0) = make_int_value(0) index(1) = make_int_value(0) @@ -4057,8 +4167,17 @@ def walk_MemberAccess_struct(node: &parser::Node, tpe: &typechecking::Type, memb let index = allocate_ref(Value, 2) index(0) = make_int_value(0) index(1) = make_int_value((@member).index) - return walk_MemberAccess_gep(node, tpe, member_type, value, index, state) + let res = walk_MemberAccess_gep(node, tpe, member_type, value, index, state) + + if member.is_embed and is_ref(member.tpe) { + let ref = state.load(member.tpe, @res.addr, loc) + let ptr = state.extract_value(pointer(member.tpe.tpe), ref, [1], loc) + return make_address_value(pointer(member.tpe.tpe), ptr, state) + } + + return res } + } def walk_MemberAccess(node: &parser::Node, state: &State) -> Value { @@ -9427,20 +9546,38 @@ def generate_vtable_function(function: &Function, tpe: &typechecking::Type, stat state.ret(NO_VALUE) } else { // Getter - var deref = state.extract_value(pointer(type_entry.tpe.tpe), reference, [1]) - var value = state.load(type_entry.tpe.tpe, deref) let name = function.unmangled - var findex: size_t = 0 - var ftpe: &typechecking::Type - for var field in @type_entry.tpe.tpe.fields { - if field.name == name { - findex = field.index - ftpe = field.tpe + + var const_field: typechecking::StructMember + var is_const = false + let const_fields = type_entry.tpe.tpe.const_fields + if const_fields { + for var field in @const_fields { + if field.name == name { + is_const = true + const_field = field + break + } } } - value = state.extract_value(ftpe, value, [findex !int]) - state.ret(value) + if is_const { + state.ret(@const_field.value) + } else { + var deref = state.extract_value(pointer(type_entry.tpe.tpe), reference, [1]) + var value = state.load(type_entry.tpe.tpe, deref) + var findex: size_t = 0 + var ftpe: &typechecking::Type + for var field in @type_entry.tpe.tpe.fields { + if field.name == name { + findex = field.index + ftpe = field.tpe + } + } + + value = state.extract_value(ftpe, value, [findex !int]) + state.ret(value) + } } } } diff --git a/src/consteval.pr b/src/consteval.pr index cdffde9d..9957e667 100644 --- a/src/consteval.pr +++ b/src/consteval.pr @@ -45,6 +45,7 @@ export def make_value(value: compiler::Value) -> &scope::Value { export def copy_state(state: &typechecking::State) -> &typechecking::State { let new_state: &typechecking::State = @state new_state.function_stack = vector::copy(state.function_stack) + new_state.context = null return new_state } diff --git a/src/debug.pr b/src/debug.pr index b4786055..3dc490f4 100644 --- a/src/debug.pr +++ b/src/debug.pr @@ -95,6 +95,15 @@ def id_decl_to_json(node: &parser::Node, types: bool) -> &Json { def id_decl_struct_to_json(node: &parser::Node, types: bool) -> &Json { let res = json::make_object() res("kind") = "IdDeclStruct" + if node.value.id_decl_struct.is_embed { + res("is_embed") = true + } + if node.value.id_decl_struct.is_bitfield { + res("bit_size") = node.value.id_decl_struct.bit_size + } + if node.value.id_decl_struct.is_const { + res("value") = node_to_json(node.value.id_decl_struct.value, types) + } res("ident") = node_to_json(node.value.id_decl_struct.ident, types) res("tpe") = node_to_json(node.value.id_decl_struct.tpe, types) return res diff --git a/src/parser.pr b/src/parser.pr index 2a3953c1..ec02c3f3 100644 --- a/src/parser.pr +++ b/src/parser.pr @@ -317,7 +317,10 @@ export type NodeIdDeclStruct = struct { ident: &Node tpe: &Node is_bitfield: bool + is_embed: bool bit_size: size_t + is_const: bool + value: &Node } export type NodeEnumT = struct { @@ -826,6 +829,7 @@ export def offset(node: &Node, changes: &[server::TextDocumentChangeEvent]) { case NodeKind::ID_DECL_STRUCT offset(node.value.id_decl_struct.ident, changes) offset(node.value.id_decl_struct.tpe, changes) + offset(node.value.id_decl_struct.value, changes) case NodeKind::ID_DECL_ENUM offset(node.value.id_decl_enum.ident, changes) offset(node.value.id_decl_enum.value, changes) @@ -960,6 +964,7 @@ export def clear(node: &Node) { case NodeKind::ID_DECL_STRUCT clear(node.value.id_decl_struct.ident) clear(node.value.id_decl_struct.tpe) + clear(node.value.id_decl_struct.value) case NodeKind::ID_DECL_ENUM clear(node.value.id_decl_enum.ident) clear(node.value.id_decl_enum.value) @@ -1155,6 +1160,8 @@ export def find(node: &Node, line: int, column: int) -> &Node { if n2 { return n2 } n2 = find(node.value.id_decl_struct.tpe, line, column) if n2 { return n2 } + n2 = find(node.value.id_decl_struct.value, line, column) + if n2 { return n2 } case NodeKind::ID_DECL_ENUM var n2 = find(node.value.id_decl_enum.ident, line, column) if n2 { return n2 } @@ -1379,6 +1386,7 @@ export def deep_copy_node(node: &Node, clear_svalue: bool = true) -> &Node { case NodeKind::ID_DECL_STRUCT copy.value.id_decl_struct.ident = deep_copy_node(node.value.id_decl_struct.ident, clear_svalue) copy.value.id_decl_struct.tpe = deep_copy_node(node.value.id_decl_struct.tpe, clear_svalue) + copy.value.id_decl_struct.value = deep_copy_node(node.value.id_decl_struct.value, clear_svalue) case NodeKind::ID_DECL_ENUM copy.value.id_decl_enum.ident = deep_copy_node(node.value.id_decl_enum.ident, clear_svalue) copy.value.id_decl_enum.value = deep_copy_node(node.value.id_decl_enum.value, clear_svalue) @@ -1955,22 +1963,43 @@ def parse_id_decl_struct(parse_state: &ParseState) -> &Node { skip_newline(parse_state) token = peek(parse_state) + var is_embed = false + var is_const = false var ident: &Node = null var tpe: &Node = null + var value: &Node = null if token.tpe == lexer::TokenType::COLON { pop(parse_state) skip_newline(parse_state) tpe = expect_type(parse_state) - } else { + } else if token.tpe == lexer::TokenType::K_CONST { + pop(parse_state) + is_const = true ident = expect_identifier(parse_state) - skip_newline(parse_state) expect(parse_state, lexer::TokenType::COLON, "Expected ':'") - skip_newline(parse_state) tpe = expect_type(parse_state) + expect(parse_state, lexer::TokenType::OP_ASSIGN, "Expected '='") + value = expect_expression(parse_state) + } else { + if token.tpe == lexer::TokenType::OP_BAND { + is_embed = true + tpe = expect_type(parse_state) + } else { + ident = expect_identifier(parse_state) + token = peek(parse_state) + if token.tpe == lexer::TokenType::COLON { + pop(parse_state) + tpe = expect_type(parse_state) + } else { + tpe = ident + ident = null + is_embed = true + } + } } - if not ident and not is_bitfield { + if not ident and not is_bitfield and not is_embed { errors::errort(token, parse_state, "Expected identifier") } @@ -1979,9 +2008,12 @@ def parse_id_decl_struct(parse_state: &ParseState) -> &Node { ident = ident, tpe = tpe, is_bitfield = is_bitfield, - bit_size = bit_size + is_embed = is_embed, + bit_size = bit_size, + is_const = is_const, + value = value ] !NodeIdDeclStruct - node._hash = combine_hashes(node.kind !uint64, is_bitfield !uint64, bit_size, hash(ident), hash(tpe)) + node._hash = combine_hashes(node.kind !uint64, is_bitfield !uint64, bit_size, is_embed !uint64, hash(ident), hash(tpe), hash(value)) } parse_t_term(parse_state) diff --git a/src/scope.pr b/src/scope.pr index 62f1f45c..69f80490 100644 --- a/src/scope.pr +++ b/src/scope.pr @@ -1712,16 +1712,16 @@ export def create_polymorph( scope = scope.module.scope // Global scope let module_name = module.filename - let v = create_function_value(scope, name_node, parser::ShareMarker::NONE, Phase::COMPILED, node, state, tpe, impl) if not scope.polymorphics { scope.polymorphics = map::make(type &Scope) } var poly_scope = scope.polymorphics.get_or_default(module_name, null) if not poly_scope { - poly_scope = enter_scope(null, scope.module) + poly_scope = enter_scope(null, null) + scope.polymorphics(module_name) = poly_scope } - scope.polymorphics(module_name) = poly_scope + let v = create_function_value(scope, name_node, parser::ShareMarker::NONE, Phase::COMPILED, node, null, tpe, impl) return create_function(poly_scope, name_node, v, impl) } diff --git a/src/serialize.pr b/src/serialize.pr index 0e599ba1..b3f797cd 100644 --- a/src/serialize.pr +++ b/src/serialize.pr @@ -533,7 +533,7 @@ def deserialize_type(deserialize: &Deserialize, fp: File, tpe: &typechecking::Ty field_types.push(type_member) tpe.field_types = field_types } else { - typechecking::make_struct_type(fields, tpe) + typechecking::make_struct_type(fields, null, tpe) } tpe.size = size tpe.align = align diff --git a/src/typechecking.pr b/src/typechecking.pr index 440a1cdb..c2b291bf 100644 --- a/src/typechecking.pr +++ b/src/typechecking.pr @@ -84,6 +84,10 @@ export type StructMember = struct { is_bitfield: bool bit_size: size_t bit_offset: size_t + + is_embed: bool + is_const: bool + value: &compiler::Value } export type StructuralTypeMember = struct { @@ -141,6 +145,8 @@ export type Type = struct { packed: bool // Fields for struct, array of StructMember fields: &[StructMember] + // Constants + const_fields: &[StructMember] // Vector of TypeMember field_types: &Vector(TypeMember) // Function and Tuple @@ -899,7 +905,7 @@ export def clear_type_cache { } } -export def iterate_member_functions(tpe: &Type, visited: &Set(&Type) = null) -> TypeEntryMember { +export def iterate_member_functions(tpe: &Type, visited: &Set(&Type) = null, check_embed: bool = true) -> TypeEntryMember { if not tpe { return } let name = debug::type_to_str(tpe, full_name = true) @@ -908,6 +914,13 @@ export def iterate_member_functions(tpe: &Type, visited: &Set(&Type) = null) -> for var i in 0..entry.functions.length { yield entry.functions(i) } + if check_embed and entry.tpe and is_struct(entry.tpe) { + for var field in @entry.tpe.fields { + if field.is_embed { + yield from iterate_member_functions(field.tpe, visited) + } + } + } } let module = get_module(tpe) @@ -1256,7 +1269,7 @@ def is_getter(mb: StructuralTypeMember) -> bool { return mb.parameter_t.length == 0 } -def has_function(entry: &TypeEntry, intf: &Type, mb: StructuralTypeMember, module: &toolchain::Module, visited: &Set(&Type) = null) -> bool { +def has_function(entry: &TypeEntry, intf: &Type, mb: StructuralTypeMember, module: &toolchain::Module, visited: &Set(&Type) = null, check_embed: bool = true) -> bool { var tpe = entry.tpe if tpe.kind == typechecking::TypeKind::STRUCT { @@ -1277,10 +1290,15 @@ def has_function(entry: &TypeEntry, intf: &Type, mb: StructuralTypeMember, modul for var field in @tpe.fields { if field.name == mb.name and equals(mb.return_t(0), field.tpe) { return true } } + if tpe.const_fields { + for var field in @tpe.const_fields { + if field.name == mb.name and equals(mb.return_t(0), field.tpe) { return true } + } + } } } - for var member in iterate_member_functions(entry.tpe, visited) { + for var member in iterate_member_functions(entry.tpe, visited, check_embed) { let function = member.function if function.name != mb.name { continue } if module != member.module and not member.exported { continue } @@ -1326,7 +1344,7 @@ def has_function(entry: &TypeEntry, intf: &Type, mb: StructuralTypeMember, modul // Returns true if a implements b // b needs to be a structural type -export def implements(a: &Type, b: &Type, module: &toolchain::Module, visited: &Set(&Type) = null) -> bool { +export def implements(a: &Type, b: &Type, module: &toolchain::Module, visited: &Set(&Type) = null, check_embed: bool = true) -> bool { if not a or not b { return false } assert b.kind == TypeKind::STRUCTURAL if a.kind == typechecking::TypeKind::REFERENCE and equals(b, a.tpe) { @@ -1336,7 +1354,7 @@ export def implements(a: &Type, b: &Type, module: &toolchain::Module, visited: & var type_entry = create_type_entry(a) if not type_entry { return false } let nameb = debug::type_to_str(b, full_name = true) - if type_entry.cached.contains(nameb) { + if type_entry.cached.contains(nameb) and check_embed { return type_entry.cached(nameb) == CacheEntry::CONTAINS } @@ -1378,7 +1396,7 @@ export def implements(a: &Type, b: &Type, module: &toolchain::Module, visited: & found = true } if not found { - found = has_function(type_entry, b, mb, module, visited.copy()) + found = has_function(type_entry, b, mb, module, visited.copy(), check_embed) } if not found { type_entry.cached(nameb) = CacheEntry::NOT_CONTAINS @@ -1391,7 +1409,7 @@ export def implements(a: &Type, b: &Type, module: &toolchain::Module, visited: & for var i in 0..vector::length(b.members) { let mb = b.members(i) - if not has_function(type_entry, b, mb, module, visited.copy()) { + if not has_function(type_entry, b, mb, module, visited.copy(), check_embed) { type_entry.cached(nameb) = CacheEntry::NOT_CONTAINS return false } @@ -1562,6 +1580,20 @@ export def convert_type_score(a: &Type, b: &Type, module: &toolchain::Module, is } } + // Embedded structs + if b.kind == TypeKind::STRUCT or is_ref(b) and b.tpe and b.tpe.kind == TypeKind::STRUCT { + var nb = b.tpe if is_ref(b) else b + for var field in @nb.fields { + if field.is_embed { + if equals(field.tpe, a) { + return 7 + } else if is_ref(a) and equals(field.tpe, a.tpe) { + return 8 + } + } + } + } + // We need to check if they are actually compatible elsewhere if a.kind == TypeKind::TYPE_DEF { return convert_type_score(a.tpe, b, module, true) @@ -2220,7 +2252,7 @@ export def make_stub_type(ident: &parser::Node, state: &State, prefix: bool = tr return tpe } -export def make_union_type(fields: &[StructMember], current_type: &Type = null) -> &Type { +export def make_union_type(fields: &[StructMember], const_fields: &[StructMember], current_type: &Type = null) -> &Type { let tpe = make_type_raw(TypeKind::UNION) if not current_type else current_type let field_types = vector::make(TypeMember) @@ -2249,6 +2281,7 @@ export def make_union_type(fields: &[StructMember], current_type: &Type = null) tpe.size = size tpe.align = align tpe.fields = fields + tpe.const_fields = const_fields if const_fields else allocate_ref(StructMember, 0) tpe.field_types = field_types return tpe @@ -2269,7 +2302,7 @@ export def make_tuple_type(types: &Vector(&Type)) -> &Type { return tpe } -export def make_struct_type(fields: &[StructMember], current_type: &Type = null) -> &Type { +export def make_struct_type(fields: &[StructMember], const_fields: &[StructMember] = null, current_type: &Type = null) -> &Type { let struct_tpe = make_type_raw(TypeKind::STRUCT) if not current_type else current_type let field_types = vector::make(TypeMember) @@ -2366,6 +2399,7 @@ export def make_struct_type(fields: &[StructMember], current_type: &Type = null) struct_tpe.size = offset struct_tpe.align = align struct_tpe.fields = fields + struct_tpe.const_fields = const_fields if const_fields else allocate_ref(StructMember, 0) struct_tpe.field_types = field_types return struct_tpe @@ -2624,7 +2658,9 @@ export def do_type_lookup(node: &parser::Node, state: &State, current_type: &Typ } else if node.kind == parser::NodeKind::STRUCT_T or node.kind == parser::NodeKind::UNION_T { let length = vector::length(node.value.body) - let fields = allocate_ref(StructMember, length) + let fields = vector::make(StructMember) + let const_fields = vector::make(StructMember) + for var i in 0..length { let field = node.value.body(i) if not field { continue } @@ -2634,38 +2670,56 @@ export def do_type_lookup(node: &parser::Node, state: &State, current_type: &Typ var field_tpe: &Type = null var name: Str var is_bitfield = false + var is_embed = field.value.id_decl_struct.is_embed var bit_size = 0 !size_t + var is_const = field.value.id_decl_struct.is_const + var value: &compiler::Value var field_type: &Type = null - if current_type and current_type.kind != TypeKind::STUB { field_type = current_type.fields(i).tpe } + if not is_embed and current_type and current_type.kind != TypeKind::STUB { field_type = current_type.fields(i).tpe } if (@field).kind == parser::NodeKind::ID_DECL_STRUCT { - let ident = (@field).value.id_decl_struct.ident - name = last_ident_to_str(ident) - if field.value.id_decl_struct.tpe { field.value.id_decl_struct.tpe.parent = field } - field_tpe = lookup_field_type((@field).value.id_decl_struct.tpe, state, field_type, cache) - if field.value.id_decl_struct.is_bitfield { - is_bitfield = true - bit_size = field.value.id_decl_struct.bit_size + if field.value.id_decl_struct.is_embed { + let ftpe = field.value.id_decl_struct.tpe + field_tpe = type_lookup(ftpe, state, ftpe.tpe, false, cache) + } else { + let ident = (@field).value.id_decl_struct.ident + name = last_ident_to_str(ident) + if field.value.id_decl_struct.tpe { field.value.id_decl_struct.tpe.parent = field } + field_tpe = lookup_field_type((@field).value.id_decl_struct.tpe, state, field_type, cache) + if field.value.id_decl_struct.is_bitfield { + is_bitfield = true + bit_size = field.value.id_decl_struct.bit_size + } + if is_const { + // Load value for const + value = consteval::expr(field.value.id_decl_struct.value, state) + } } } else if (@field).kind == parser::NodeKind::STRUCT_T or (@field).kind == parser::NodeKind::UNION_T { field_tpe = type_lookup(field, state, field_type, false, cache) } - fields(i) = [ + let member = [ node = field, line = line, name = name, tpe = field_tpe, - is_bitfield = is_bitfield, - bit_size = bit_size + is_bitfield = is_bitfield, + is_embed = is_embed, + bit_size = bit_size, + is_const = is_const, + value = value ] !StructMember + + if is_const { const_fields.push(member) } + else { fields.push(member) } } - var tpe = (make_struct_type(fields, current_type) + var tpe = (make_struct_type(fields.to_array(), const_fields.to_array(), current_type) if node.kind == parser::NodeKind::STRUCT_T - else make_union_type(fields, current_type)) + else make_union_type(fields.to_array(), const_fields.to_array(), current_type)) tpe.line = node.loc.line tpe.type_name = make_unique_name("", state) @@ -3052,7 +3106,7 @@ export def do_type_lookup(node: &parser::Node, state: &State, current_type: &Typ tpe.size = (ceil(offset / biggest_type.align !double) * biggest_type.align) !int tpe.align = util::lcm(align !int, biggest_type.align !int) - tpe._tpe = make_union_type(fields, tpe._tpe) + tpe._tpe = make_union_type(fields, null, tpe._tpe) tpe.variants = set::make(variants) tpe.line = node.loc.line @@ -5579,19 +5633,27 @@ export def lookup_struct_member(member: StructMember, resolved: &SSet = null) { } } -def resolve_member(fields: &[StructMember], name: String) -> &Type { - for var i in 0..fields.size { - let member = fields(i) +def resolve_member(tpe: &Type, name: String) -> &Type { + for var member in @tpe.fields { if member.name { if member.name == name { lookup_struct_member(member) return member.tpe } } else { - let member = resolve_member((@member.tpe).fields, name) + var tpe = member.tpe + if member.is_embed and is_ref(member.tpe) { + tpe = member.tpe.tpe + } + let member = resolve_member(tpe, name) if member { return member } } } + for var member in @tpe.const_fields { + if member.name == name { + return member.tpe + } + } return null } @@ -5651,7 +5713,7 @@ def walk_MemberAccess_aggregate(node: &parser::Node, ucs: bool, state: &State) - if (@tpe).kind == TypeKind::STRUCT or (@tpe).kind == TypeKind::UNION { let name = last_ident_to_str(right) - var rtpe = resolve_member((@tpe).fields, name) + var rtpe = resolve_member(tpe, name) if not rtpe { if ucs { if walk_MemberAccess_ucs(node, state) { return true } @@ -5687,6 +5749,26 @@ def walk_MemberAccess_aggregate(node: &parser::Node, ucs: bool, state: &State) - return true } +// TODO Make this a generator, doesn't work for some reason +def flatten_fields(tpe: &Type, state: &State, v: &Vector(StructMember) = null) -> &Vector(StructMember) { + if not v { v = vector::make(StructMember) } + if not tpe { return v } + if not is_struct(tpe) { + errors::errorn(tpe.node, "Invalid embed, needs to be a struct or a reference to a struct") + return v + } + + for var field in @tpe.fields { + if field.is_embed { + let tpe = field.tpe.tpe if is_box(field.tpe) else field.tpe + flatten_fields(tpe, state, v) + } else { + v.push(field) + } + } + return v +} + def walk_StructLit(node: &parser::Node, state: &State) { let prev_tpe = node.tpe var tpe = prev_tpe @@ -5841,8 +5923,7 @@ def walk_StructLit(node: &parser::Node, state: &State) { let name = last_ident_to_str((@kwarg).value.named_arg.name) var found = false - for var j in 0..(@tpe).fields.size { - let field = (@tpe).fields(j) + for var field in flatten_fields(tpe, state) { if field.name == name { found = true diff --git a/test/test_compiler.pr b/test/test_compiler.pr index 5e6b4594..7ba80c25 100644 --- a/test/test_compiler.pr +++ b/test/test_compiler.pr @@ -603,6 +603,22 @@ def #test test_struct { assert D("align").as_double() == 8 } +def #test test_struct_const { + var src = """ + type A = struct { + const x: int = 20 + const y: int = 30 + a: int + b: int + } + + let a = [ a = 20, b = 30 ] !A + let x = a.x + let y = a.y + """ + assert compile(src) != null +} + def #test test_struct_lit { let src = """ type T = struct { @@ -1055,4 +1071,26 @@ def #test test_lambda { make_adder(10)(20)(30) """ assert compile(str) != null +} + +def #test test_embed { + var str = """ + type E = struct { + a: int + b: int + } + + type T = struct { + E + x: int + y: int + } + + def takes_e(e: E) {} + + let t = [a = 10, b = 20, x = 1, y = 2] !T + + takes_e(t) + """ + assert compile(str) != null } \ No newline at end of file diff --git a/test/test_parser.pr b/test/test_parser.pr index 8d24e093..ec261437 100644 --- a/test/test_parser.pr +++ b/test/test_parser.pr @@ -3001,6 +3001,181 @@ def #test test_struct { } ] }""")) + + assert parse(""" + type T = struct { + const i: int = 20 + } + """) == program(json::parse("""{ + "kind": "TypeDecl", + "share": "NONE", + "left": [ + { + "kind": "Identifier", + "path": [ + "T" + ], + "prefixed": false, + "args": null + } + ], + "right": [ + { + "kind": "Struct", + "body": [ + { + "kind": "IdDeclStruct", + "value": { + "kind": "Integer", + "value": 20.000000 + }, + "ident": { + "kind": "Identifier", + "path": [ + "i" + ], + "prefixed": false, + "args": null + }, + "tpe": { + "kind": "Identifier", + "path": [ + "int" + ], + "prefixed": false, + "args": null + } + } + ] + } + ] + }""")) +} + +def #test test_embed { + assert parse(""" + type T = struct { + X + a: int + } + """) == program(json::parse("""{ + "kind": "TypeDecl", + "share": "NONE", + "left": [ + { + "kind": "Identifier", + "path": [ + "T" + ], + "prefixed": false, + "args": null + } + ], + "right": [ + { + "kind": "Struct", + "body": [ + { + "kind": "IdDeclStruct", + "is_embed": true, + "ident": null, + "tpe": { + "kind": "Identifier", + "path": [ + "X" + ], + "prefixed": false, + "args": null + } + }, + { + "kind": "IdDeclStruct", + "ident": { + "kind": "Identifier", + "path": [ + "a" + ], + "prefixed": false, + "args": null + }, + "tpe": { + "kind": "Identifier", + "path": [ + "int" + ], + "prefixed": false, + "args": null + } + } + ] + } + ] + } + """)) + + assert parse(""" + type T = struct { + &X + a: int + } + """) == program(json::parse("""{ + "kind": "TypeDecl", + "share": "NONE", + "left": [ + { + "kind": "Identifier", + "path": [ + "T" + ], + "prefixed": false, + "args": null + } + ], + "right": [ + { + "kind": "Struct", + "body": [ + { + "kind": "IdDeclStruct", + "is_embed": true, + "ident": null, + "tpe": { + "kind": "RefT", + "kw": "VAR", + "tpe": { + "kind": "Identifier", + "path": [ + "X" + ], + "prefixed": false, + "args": null + } + } + }, + { + "kind": "IdDeclStruct", + "ident": { + "kind": "Identifier", + "path": [ + "a" + ], + "prefixed": false, + "args": null + }, + "tpe": { + "kind": "Identifier", + "path": [ + "int" + ], + "prefixed": false, + "args": null + } + } + ] + } + ] + } + """)) } def #test test_interface { diff --git a/test/test_runtime.pr b/test/test_runtime.pr index 29206e7f..a1de6e3f 100644 --- a/test/test_runtime.pr +++ b/test/test_runtime.pr @@ -869,5 +869,81 @@ def #test test_lambda { assert make_adder(10)(20)(30) == 60 """ + assert run_source(str) == 0 +} + +def #test test_embed { + var str = """ + type Base = struct { a: int } + type T = struct { + Base + b: int + } + type U = struct { + &Base + b: int + } + + def takes_1(b: &Base) -> int { return b.a } + def takes_2(b: Base) -> int { return b.a } + + let t = [a = 10, b = 20] !T + let u = [a = 20, b = 30] !U + + assert takes_1(t) == 10 + assert takes_1(u) == 20 + assert takes_2(t) == 10 + + assert t.a == 10 + assert u.a == 20 + assert t.b == 20 + assert u.b == 30 + """ + + assert run_source(str) == 0 + + str = """ + type I = interface { def method } + type Base = struct { a: int } + type E = struct { + &Base + b: int + } + + def method(b: &Base) {} + def takes(i: &I) { + i.method() + } + + let e = [ a = 20, b = 30 ] !E + takes(e) + """ + + assert run_source(str) == 0 +} + +def #test test_struct_const { + var str = """ + type Point = interface { + let x: int + let y: int + } + + type X = struct { + const x: int = 20 + const y: int = 30 + } + + def takes(p: &Point) { + assert p.x == 20 + assert p.y == 30 + } + + assert (size_of X) == 0 + + let x = [] !&X + takes(x) + """ + assert run_source(str) == 0 } \ No newline at end of file