diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index 775f35b58..49449143d 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -673,3 +673,158 @@ for i, name in ipairs({"records", "arrayrecords"}) do ]])) end) end + +describe("embedding", function() + it("embed is not a reserved word", util.check [[ + local record A + embed: number + end + ]]) + it("should make members of a record available to another", util.check [[ + local record A + userdata + a_value: number + end + local record B + embed A + b_value: number + end + local record C + embed B + c_value: string + end + local b: B = {} + local c: C = {} + print(b.a_value, c.a_value + c.b_value, c.c_value) + local function f(a: A, b: B) end + f(b, b) + f(c, c) + local a: A = b + a = c + ]]) + it("should allow for generics", util.check [[ + local record A + property_of_a: T + end + local record B + embed A + property_of_b: T + end + local record C + embed B + property_of_c: string + end + local b: B = {} + local c: C = {} + print(b.property_of_a + 1, c.property_of_b + c.property_of_a, c.property_of_c) + local function f(a: A, b: B) end + f(b, b) + f(c, c) + local a: A = b + a = c + ]]) + it("embed multiple records", util.check [[ + local record A + x: number + end + local record B + y: string + end + local record C + embed A + embed B + z: boolean + end + local c: C = { x = 1, y = "hello", z = true } + print(c.x, c.y, c.z) + local function f(a: A, b: B, c: C) end + f(c, c, c) + ]]) + it("share the same element type with embeded arrayrecord", util.check [[ + local record A + { string } + x: number + end + local record B + embed A + y: string + end + local record C + embed B + z: boolean + end + local c: C = { x = 1, y = "hello", z = true } + local str: string = c[1] + print(c.x, c.y, c.z) + local strs: {string} = c + local b: B = c + strs = b + ]]) + it("subrecords should not be equal", util.check_type_error([[ + local record A + end + local record B + embed A + end + local record C + embed A + end + local a: A + local b: B + local c: C + a = b + a = c + b = c + ]], { + { y = 14, msg = "in assignment: C is not a B" } + })) +end) + +describe("const field", function() + it("const is not a reserved word", util.check [[ + local record A + const: boolean + end + ]]) + it("cannot assign to a const field", util.check_type_error([[ + local record A + const ["a field"]: string + const x: number + end + local record B + embed A + b: A + end + local r: B = {b = {["a field"] = "abc"}} + r.b[ [=[a field]=] ] = 456 + r.x = 123 + ]], { + { y = 10, msg = "cannot assign to const field \"a field\"" }, + { y = 10, msg = "in assignment: got integer, expected string" }, + { y = 11, msg = "cannot assign to const field \"x\"" } + })) + it("should work for generics", util.check_type_error([[ + local record A + const a: T + end + local record B + embed A + const b: T + end + local b: B = { a = 1, b = 2 } + print(b.a, b.b) + b.a, b.b = 2, "3" + ]], { + { y = 10, msg = "cannot assign to const field \"a\"" }, + { y = 10, msg = "cannot assign to const field \"b\"" }, + { y = 10, msg = "in assignment: got string \"3\", expected number" } + })) + it("it is OK to modify const field with rawset", util.check [[ + local record R + const x: boolean + end + local r: R = { x = true } + rawset(r, "x", false) + ]]) +end) + diff --git a/tl.lua b/tl.lua index d369d0818..d8be7fb81 100644 --- a/tl.lua +++ b/tl.lua @@ -1030,6 +1030,9 @@ local Type = {} + + + @@ -2563,6 +2566,16 @@ parse_record_body = function(ps, i, def, node) i = parse_nested_type(ps, i, def, "record", parse_record_body) elseif ps.tokens[i].tk == "enum" and ps.tokens[i + 1].tk ~= ":" then i = parse_nested_type(ps, i, def, "enum", parse_enum_body) + elseif ps.tokens[i].tk == "embed" and ps.tokens[i + 1].tk ~= ":" then + local t + i, t = parse_type(ps, i + 1) + if not t then + return fail(ps, i, "expected a type") + end + if not def.embeds then + def.embeds = {} + end + table.insert(def.embeds, t) else local is_metamethod = false if ps.tokens[i].tk == "metamethod" and ps.tokens[i + 1].tk ~= ":" then @@ -2570,6 +2583,12 @@ parse_record_body = function(ps, i, def, node) i = i + 1 end + local is_const = false + if ps.tokens[i].tk == "const" and ps.tokens[i + 1].tk ~= ":" then + is_const = true + i = i + 1 + end + local v if ps.tokens[i].tk == "[" then i, v = parse_literal(ps, i + 1) @@ -2607,6 +2626,12 @@ parse_record_body = function(ps, i, def, node) fail(ps, i - 1, "not a valid metamethod: " .. field_name) end end + if is_const then + if not def.readonlys then + def.readonlys = {} + end + def.readonlys[field_name] = true + end store_field_in_record(ps, iv, field_name, t, fields, field_order) elseif ps.tokens[i].tk == "=" then local next_word = ps.tokens[i + 1].tk @@ -3025,6 +3050,11 @@ local function recurse_type(ast, visit) table.insert(xs, recurse_type(child, visit)) end end + if ast.embeds then + for _, child in ipairs(ast.embeds) do + table.insert(xs, recurse_type(child, visit)) + end + end if ast.args then for i, child in ipairs(ast.args) do if i > 1 or not ast.is_method then @@ -4349,6 +4379,13 @@ local function show_type_base(t, short, seen) if t.elements then table.insert(out, "{" .. show(t.elements) .. "}") end + if t.embeds then + local es = {} + for _, k in ipairs(t.embeds) do + table.insert(es, show(k)) + end + table.insert(out, "(embed " .. table.concat(es, ", ") .. ")") + end local fs = {} for _, k in ipairs(t.field_order) do local v = t.fields[k] @@ -5508,6 +5545,13 @@ tl.type_check = function(ast, opts) end end + if t.embeds then + copy.embeds = {} + for i, tf in ipairs(t.embeds) do + copy.embeds[i] = resolve(tf) + end + end + if t.elements then copy.elements = resolve(t.elements) end @@ -5527,6 +5571,8 @@ tl.type_check = function(ast, opts) copy.meta_fields[k] = resolve(t.meta_fields[k]) end end + + copy.readonlys = t.readonlys elseif t.typename == "map" then copy.keys = resolve(t.keys) copy.values = resolve(t.values) @@ -5966,6 +6012,83 @@ tl.type_check = function(ast, opts) return typ end + local function are_disjoint(rec1, rec2) + local disjoint = true + local shared_fields = {} + for field, t in pairs(rec1.fields) do + if rec2.fields[field] and rec2.fields[field] ~= t then + shared_fields[field] = true + disjoint = false + end + end + for field, t in pairs(rec2.fields) do + if rec1.fields[field] and rec1.fields[field] ~= t then + shared_fields[field] = true + disjoint = false + end + end + return disjoint, shared_fields + end + + local function can_embed(t1, t2) + assert(is_record_type(t1), "only records can have embeds") + if is_record_type(t2) then + local compat, fields = are_disjoint(t1, t2) + if compat then + return compat + else + local str = {} + for f in pairs(fields) do + table.insert(str, f) + end + return nil, "records share fields: " .. table.concat(str, ", ") + end + end + return false, t2.typename .. " can't be embedded" + end + + local resolve_tuple_and_nominal = nil + + local function resolve_embeds(t) + if not t.embeds or t.embeds_resolved then + return t + end + t.embeds_resolved = true + for i = 1, #t.embeds do + local e = resolve_tuple_and_nominal(t.embeds[i]) + local compat, reason = can_embed(t, e) + if not compat then + type_error(t, "invalid embedding: " .. reason) + end + if e.is_userdata then + t.is_userdata = true + end + if t.elements == nil and e.elements ~= nil then + t.elements = e.elements + t.typename = "arrayrecord" + end + if t.elements and e.elements and not same_type(t.elements, e.elements) then + type_error(t, "Can not do embedding between arrayrecords with different element types") + end + t.embeds[i] = e + if e.fields then + for fname, f in pairs(e.fields) do + t.fields[fname] = f + table.insert(t.field_order, fname) + end + end + if e.readonlys then + if not t.readonlys then + t.readonlys = {} + end + for k, v in pairs(e.readonlys) do + t.readonlys[k] = v + end + end + end + return t + end + local resolve_nominal do local function match_typevals(t, def) @@ -6031,7 +6154,27 @@ tl.type_check = function(ast, opts) end end + local function is_embedded(t1, t2) + t2 = resolve_tuple_and_nominal(t2) + t1 = resolve_tuple_and_nominal(t1) + local embeds = t1.embeds + if embeds and #embeds > 0 then + for i = 1, #embeds do + if embeds[i].typeid == t2.typeid then + return true + elseif embeds[i].embeds then + return is_embedded(embeds[i], t2) + end + end + end + return false + end + local function are_same_nominals(t1, t2) + if is_embedded(t1, t2) then + return true + end + local same_names if t1.found and t2.found then same_names = t1.found.typeid == t2.found.typeid @@ -6084,7 +6227,6 @@ tl.type_check = function(ast, opts) end local is_known_table_type - local resolve_tuple_and_nominal = nil same_type = function(t1, t2) @@ -6102,6 +6244,11 @@ tl.type_check = function(ast, opts) if t1.typename ~= t2.typename then return false, terr(t1, "got %s, expected %s", t1, t2) end + + if is_embedded(t1, t2) then + return true + end + if t1.typename == "array" then return same_type(t1.elements, t2.elements) elseif t1.typename == "tupletable" then @@ -6354,6 +6501,8 @@ tl.type_check = function(ast, opts) end end return false, terr(t1, "cannot match against any alternatives of the polymorphic type") + elseif is_embedded(t1, t2) then + return true elseif t1.typename == "nominal" and t2.typename == "nominal" then local same, err = are_same_nominals(t1, t2) if same then @@ -6959,6 +7108,7 @@ tl.type_check = function(ast, opts) t = resolve_nominal(t) end assert(t.typename ~= "nominal") + t = resolve_embeds(t) return t end @@ -7991,6 +8141,15 @@ tl.type_check = function(ast, opts) if is_typetype(resolve_tuple_and_nominal(vartype)) then node_error(varnode, "cannot reassign a type") elseif val then + if varnode.kind == "op" and (varnode.op.op == "." or varnode.op.op == "@index") then + local t = resolve_tuple_and_nominal(varnode.e1.type) + if t.typename == "record" and t.readonlys then + local key = varnode.e2.conststr or varnode.e2.tk + if t.readonlys[key] then + node_error(varnode, "cannot assign to const field \"" .. key .. '"') + end + end + end assert_is_a(varnode, val, vartype, "in assignment") if varnode.kind == "variable" and vartype.typename == "union" then diff --git a/tl.tl b/tl.tl index d611353e7..df9fb2c7d 100644 --- a/tl.tl +++ b/tl.tl @@ -987,6 +987,9 @@ local record Type meta_fields: {string: Type} meta_field_order: {string} is_userdata: boolean + embeds: {Type} + embeds_resolved: boolean + readonlys: {string: boolean} -- array elements: Type @@ -2563,6 +2566,16 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node): i = parse_nested_type(ps, i, def, "record", parse_record_body) elseif ps.tokens[i].tk == "enum" and ps.tokens[i+1].tk ~= ":" then i = parse_nested_type(ps, i, def, "enum", parse_enum_body) + elseif ps.tokens[i].tk == "embed" and ps.tokens[i+1].tk ~= ":" then + local t: Type + i, t = parse_type(ps, i + 1) + if not t then + return fail(ps, i, "expected a type") + end + if not def.embeds then + def.embeds = {} + end + table.insert(def.embeds, t) else local is_metamethod = false if ps.tokens[i].tk == "metamethod" and ps.tokens[i+1].tk ~= ":" then @@ -2570,6 +2583,12 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node): i = i + 1 end + local is_const = false + if ps.tokens[i].tk == "const" and ps.tokens[i+1].tk ~= ":" then + is_const = true + i = i + 1 + end + local v: Node if ps.tokens[i].tk == "[" then i, v = parse_literal(ps, i+1) @@ -2607,6 +2626,12 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node): fail(ps, i - 1, "not a valid metamethod: " .. field_name) end end + if is_const then + if not def.readonlys then + def.readonlys = {} + end + def.readonlys[field_name] = true + end store_field_in_record(ps, iv, field_name, t, fields, field_order) elseif ps.tokens[i].tk == "=" then local next_word = ps.tokens[i + 1].tk @@ -3025,6 +3050,11 @@ local function recurse_type(ast: Type, visit: Visitor): T table.insert(xs, recurse_type(child, visit)) end end + if ast.embeds then + for _, child in ipairs(ast.embeds) do + table.insert(xs, recurse_type(child, visit)) + end + end if ast.args then for i, child in ipairs(ast.args) do if i > 1 or not ast.is_method then @@ -4349,6 +4379,13 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str if t.elements then table.insert(out, "{" .. show(t.elements) .. "}") end + if t.embeds then + local es = {} + for _, k in ipairs(t.embeds) do + table.insert(es, show(k)) + end + table.insert(out, "(embed " .. table.concat(es, ", ") .. ")") + end local fs = {} for _, k in ipairs(t.field_order) do local v = t.fields[k] @@ -5508,6 +5545,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result end end + if t.embeds then + copy.embeds = {} + for i, tf in ipairs(t.embeds) do + copy.embeds[i] = resolve(tf) + end + end + if t.elements then copy.elements = resolve(t.elements) end @@ -5527,6 +5571,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result copy.meta_fields[k] = resolve(t.meta_fields[k]) end end + + copy.readonlys = t.readonlys elseif t.typename == "map" then copy.keys = resolve(t.keys) copy.values = resolve(t.values) @@ -5966,6 +6012,83 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result return typ end + local function are_disjoint(rec1: Type, rec2: Type): boolean, {string:boolean} + local disjoint = true + local shared_fields: {string:boolean} = {} + for field, t in pairs(rec1.fields) do + if rec2.fields[field] and rec2.fields[field] ~= t then + shared_fields[field] = true + disjoint = false + end + end + for field, t in pairs(rec2.fields) do + if rec1.fields[field] and rec1.fields[field] ~= t then + shared_fields[field] = true + disjoint = false + end + end + return disjoint, shared_fields + end + + local function can_embed(t1: Type, t2: Type): boolean, string + assert(is_record_type(t1), "only records can have embeds") + if is_record_type(t2) then + local compat, fields = are_disjoint(t1, t2) + if compat then + return compat + else + local str: {string} = {} + for f in pairs(fields) do + table.insert(str, f) + end + return nil, "records share fields: " .. table.concat(str, ", ") + end + end + return false, t2.typename .. " can't be embedded" + end + + local resolve_tuple_and_nominal: function(t: Type): Type = nil + + local function resolve_embeds(t: Type): Type + if not t.embeds or t.embeds_resolved then + return t + end + t.embeds_resolved = true + for i = 1, #t.embeds do + local e = resolve_tuple_and_nominal(t.embeds[i]) + local compat, reason = can_embed(t, e) + if not compat then + type_error(t, "invalid embedding: " .. reason) + end + if e.is_userdata then + t.is_userdata = true + end + if t.elements == nil and e.elements ~= nil then + t.elements = e.elements + t.typename = "arrayrecord" + end + if t.elements and e.elements and not same_type(t.elements, e.elements) then + type_error(t, "Can not do embedding between arrayrecords with different element types") + end + t.embeds[i] = e + if e.fields then + for fname, f in pairs(e.fields) do + t.fields[fname] = f + table.insert(t.field_order, fname) + end + end + if e.readonlys then + if not t.readonlys then + t.readonlys = {} + end + for k, v in pairs(e.readonlys) do + t.readonlys[k] = v + end + end + end + return t + end + local resolve_nominal: function(t: Type): Type do local function match_typevals(t: Type, def: Type): Type @@ -6031,7 +6154,27 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result end end + local function is_embedded(t1: Type, t2: Type): boolean + t2 = resolve_tuple_and_nominal(t2) + t1 = resolve_tuple_and_nominal(t1) + local embeds = t1.embeds + if embeds and #embeds > 0 then + for i = 1, #embeds do + if embeds[i].typeid == t2.typeid then + return true + elseif embeds[i].embeds then + return is_embedded(embeds[i], t2) + end + end + end + return false + end + local function are_same_nominals(t1: Type, t2: Type): boolean, {Error} + if is_embedded(t1, t2) then + return true + end + local same_names: boolean if t1.found and t2.found then same_names = t1.found.typeid == t2.found.typeid @@ -6084,7 +6227,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result end local is_known_table_type: function(t: Type): boolean - local resolve_tuple_and_nominal: function(t: Type): Type = nil -- invariant type comparison same_type = function(t1: Type, t2: Type): boolean, {Error} @@ -6102,6 +6244,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result if t1.typename ~= t2.typename then return false, terr(t1, "got %s, expected %s", t1, t2) end + + if is_embedded(t1, t2) then + return true + end + if t1.typename == "array" then return same_type(t1.elements, t2.elements) elseif t1.typename == "tupletable" then @@ -6354,6 +6501,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result end end return false, terr(t1, "cannot match against any alternatives of the polymorphic type") + elseif is_embedded(t1, t2) then + return true elseif t1.typename == "nominal" and t2.typename == "nominal" then local same, err = are_same_nominals(t1, t2) if same then @@ -6959,6 +7108,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result t = resolve_nominal(t) end assert(t.typename ~= "nominal") + t = resolve_embeds(t) return t end @@ -7991,6 +8141,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result if is_typetype(resolve_tuple_and_nominal(vartype)) then node_error(varnode, "cannot reassign a type") elseif val then + if varnode.kind == "op" and (varnode.op.op == "." or varnode.op.op == "@index") then + local t = resolve_tuple_and_nominal(varnode.e1.type) + if t.typename == "record" and t.readonlys then + local key = varnode.e2.conststr or varnode.e2.tk + if t.readonlys[key] then + node_error(varnode, "cannot assign to const field \"" .. key .. '"') + end + end + end assert_is_a(varnode, val, vartype, "in assignment") if varnode.kind == "variable" and vartype.typename == "union" then -- narrow union