From e69d91bb2243a1b633567cc7c7b733c31d03e9f0 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 23 Jul 2024 11:31:11 -0300 Subject: [PATCH] we can now reexport nested types Fixes #765. --- spec/declaration/record_spec.lua | 34 +++++++ spec/util.lua | 6 ++ tl.lua | 130 +++++++++++++++++-------- tl.tl | 158 ++++++++++++++++++++----------- 4 files changed, 236 insertions(+), 92 deletions(-) diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index a188a482c..2af252bcc 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -555,6 +555,40 @@ for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces ]])() end) + it("can reexport types as nested " .. name, function() + util.mock_io(finally, { + ["inner.tl"] = [[ + local record inner + ]]..statement..[[ SubType ]]..array(i, "{integer}")..[[ + item: K + end + end + + return inner + ]], + ["outer.tl"] = [[ + local core = require("inner") + + local record mod + f: string + type SubType = core.SubType + end + + return mod + ]], + }) + util.run_check_type_error([[ + local mod = require("outer") + + print(mod.f) + local v: mod.SubType = { + item = "hello" + } + ]], { + { msg = 'in record field: item: got string "hello", expected integer' } + }) + end) + it("resolves aliasing of nested " .. name .. " (see #400)", util.check([[ local ]]..statement..[[ Foo ]]..array(i, "{Foo}")..[[ ]]..statement..[[ Bar ]]..array(i, "{Bar}")..[[ diff --git a/spec/util.lua b/spec/util.lua index 25cbdae50..bb6b9af4b 100644 --- a/spec/util.lua +++ b/spec/util.lua @@ -436,6 +436,12 @@ local function check(lax, code, unknowns, gen_target) gen_compat = "off" end local result = tl.type_check(ast, "foo.lua", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) + + for _, mname in pairs(result.env.loaded_order) do + local mresult = result.env.loaded[mname] + batch:add(assert.same, {}, mresult.syntax_errors or {}, "Code was not expected to have syntax errors") + end + batch:add(assert.same, {}, result.type_errors) if unknowns then diff --git a/tl.lua b/tl.lua index f48911d36..68b6d7a69 100644 --- a/tl.lua +++ b/tl.lua @@ -1889,7 +1889,6 @@ end - local TruthyFact = {} @@ -2216,6 +2215,7 @@ do local parse_argument_list local parse_argument_type_list local parse_type + local parse_type_declaration local parse_newtype local parse_interface_name @@ -3771,14 +3771,19 @@ do elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then i = i + 1 local iv = i - local v - i, v = verify_kind(ps, i, "identifier", "type_identifier") + + local lt + i, lt = parse_type_declaration(ps, i, "local_type") + if not lt then + return fail(ps, i, "expected a type definition") + end + + local v = lt.var if not v then return fail(ps, i, "expected a variable name") end - i = verify_tk(ps, i, "=") - local nt - i, nt = parse_newtype(ps, i) + + local nt = lt.value if not nt or not nt.newtype then return fail(ps, i, "expected a type definition") end @@ -4000,9 +4005,7 @@ do return i, asgn end - local function parse_type_declaration(ps, i, node_name) - i = i + 2 - + parse_type_declaration = function(ps, i, node_name) local asgn = new_node(ps, i, node_name) local var @@ -4048,6 +4051,10 @@ do def.declname = asgn.var.tk end end + elseif nt.typename == "typealias" then + if typeargs then + nt.typeargs = typeargs + end end return i, asgn @@ -4078,7 +4085,7 @@ do end local function skip_type_declaration(ps, i) - return parse_type_declaration(ps, i - 1, "local_type") + return parse_type_declaration(ps, i + 1, "local_type") end local function parse_local_macroexp(ps, i) @@ -4097,7 +4104,7 @@ do if ntk == "function" then return parse_local_function(ps, i) elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then - return parse_type_declaration(ps, i, "local_type") + return parse_type_declaration(ps, i + 2, "local_type") elseif ntk == "macroexp" and ps.tokens[i + 2].kind == "identifier" then return parse_local_macroexp(ps, i) elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then @@ -4112,7 +4119,7 @@ do if ntk == "function" then return parse_function(ps, i + 1, "global") elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then - return parse_type_declaration(ps, i, "global_type") + return parse_type_declaration(ps, i + 2, "global_type") elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then return parse_type_constructor(ps, i, "global_type", tn, parse_type_body_fns[tn]) elseif ps.tokens[i + 1].kind == "identifier" then @@ -4434,6 +4441,11 @@ local function recurse_type(s, ast, visit) table.insert(xs, recurse_type(s, ast.vtype, visit)) end elseif ast.typename == "typealias" then + if ast.typeargs then + for _, child in ipairs(ast.typeargs) do + table.insert(xs, recurse_type(s, child, visit)) + end + end table.insert(xs, recurse_type(s, ast.alias_to, visit)) elseif ast.typename == "typedecl" then if ast.typeargs then @@ -6998,7 +7010,7 @@ do fresh_typevar_ctr = fresh_typevar_ctr + 1 local ok ok, t = typevar_resolver(nil, t, fresh_typevar, fresh_typearg) - assert(ok, "Internal Compiler Error: error creating fresh type variables") + assert(ok and t, "Internal Compiler Error: error creating fresh type variables") return t end @@ -7711,6 +7723,10 @@ do found = found.alias_to.found end + if not found then + return self.errs:invalid_at(t, table.concat(t.names, ".") .. " is not a resolved type") + end + if not (found.typename == "typedecl") then return self.errs:invalid_at(t, table.concat(t.names, ".") .. " is not a type") end @@ -7731,7 +7747,7 @@ do local function resolve_decl_into_nominal(self, t, found) local def = found.def local resolved - if def.typename == "record" or def.typename == "function" then + if def.fields or def.typename == "function" then resolved = match_typevals(self, t, def) if not resolved then return self.errs:invalid_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") @@ -7757,7 +7773,7 @@ do local t = typealias.alias_to local immediate, found = find_nominal_type_decl(self, t) - if immediate then + if type(immediate) == "table" then return immediate end @@ -8132,6 +8148,16 @@ do end + function TypeChecker:forall_are_subtype_of(xs, t) + for _, x in ipairs(xs.types) do + if not self:is_a(x, t) then + return false + end + end + return true + end + + local emptytable_relations = { ["array"] = compare_true, ["map"] = compare_true, @@ -8263,6 +8289,15 @@ do ["*"] = compare_true, }, ["union"] = { + ["nominal"] = function(self, a, b) + + local rb = self:resolve_nominal(b) + if rb.typename == "union" then + return self:is_a(a, rb) + end + + return self:forall_are_subtype_of(a, b) + end, ["union"] = function(self, a, b) local used = {} for _, t in ipairs(a.types) do @@ -8281,14 +8316,7 @@ do end return true end, - ["*"] = function(self, a, b) - for _, t in ipairs(a.types) do - if not self:is_a(t, b) then - return false - end - end - return true - end, + ["*"] = TypeChecker.forall_are_subtype_of, }, ["poly"] = { ["*"] = function(self, a, b) @@ -8305,21 +8333,36 @@ do return true end - local rb = self:resolve_nominal(b) - if rb.typename == "interface" then + local ra = self:resolve_nominal(a) + local rb = self:resolve_nominal(b) + if ra.typename == "union" and rb.typename == "union" then + return self:is_a(ra, rb) + end + if ra.typename == "union" then + return self:is_a(ra, b) + end + if rb.typename == "union" then return self:is_a(a, rb) end - local ra = self:resolve_nominal(a) - if ra.typename == "union" or rb.typename == "union" then - return self:is_a(ra, rb) + if rb.typename == "interface" then + return self:is_a(a, rb) end return ok, errs end, + ["union"] = function(self, a, b) + + local ra = self:resolve_nominal(a) + if ra.typename == "union" then + return self:is_a(ra, b) + end + + return not not self:exists_supertype_in(a, b) + end, ["*"] = TypeChecker.subtype_nominal, }, ["enum"] = { @@ -10518,7 +10561,10 @@ self:expand_type(node, values, elements) }) value.e1.tk == "require" then local t = special_functions["require"](self, value, self:find_var_type("require"), a_type(value.e2, "tuple", { tuple = { a_type(value.e2[1], "string", {}) } }), 0) + + local ty = t.typename == "tuple" and t.tuple[1] or t + ty = (ty.typename == "typealias") and self:resolve_typealias(ty) or ty local td = (ty.typename == "typedecl") and ty or a_type(value, "typedecl", { def = ty }) return td @@ -10660,6 +10706,9 @@ self:expand_type(node, values, elements) }) if node.value then local resolved, aliasing = self:get_typedecl(node.value) local added = self:add_global(node.var, name, resolved) + if resolved.typename == "invalid" then + return + end node.value.newtype = resolved if aliasing then added.aliasing = aliasing @@ -12143,18 +12192,19 @@ self:expand_type(node, values, elements) }) end end + local visit_type_with_typeargs = { + before = function(self, _typ) + self:begin_scope() + end, + after = function(self, typ, _children) + self:end_scope() + return self:ensure_fresh_typeargs(typ) + end, + } + local visit_type visit_type = { cbs = { - ["function"] = { - before = function(self, _typ) - self:begin_scope() - end, - after = function(self, typ, _children) - self:end_scope() - return self:ensure_fresh_typeargs(typ) - end, - }, ["record"] = { before = function(self, typ) self:begin_scope() @@ -12332,11 +12382,13 @@ self:expand_type(node, values, elements) }) } visit_type.cbs["interface"] = visit_type.cbs["record"] - visit_type.cbs["typedecl"] = visit_type.cbs["function"] + + visit_type.cbs["function"] = visit_type_with_typeargs + visit_type.cbs["typedecl"] = visit_type_with_typeargs + visit_type.cbs["typealias"] = visit_type_with_typeargs visit_type.cbs["string"] = default_type_visitor visit_type.cbs["tupletable"] = default_type_visitor - visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["array"] = default_type_visitor visit_type.cbs["map"] = default_type_visitor visit_type.cbs["enum"] = default_type_visitor diff --git a/tl.tl b/tl.tl index 7d1c5d180..8722ec297 100644 --- a/tl.tl +++ b/tl.tl @@ -1607,17 +1607,23 @@ local record BooleanType where self.typename == "boolean" end -local record TypeDeclType +local interface HasTypeArgs is Type - where self.typename == "typedecl" + where self.typeargs typeargs: {TypeArgType} +end + +local record TypeDeclType + is Type, HasTypeArgs + where self.typename == "typedecl" + def: Type closed: boolean end local record TypeAliasType - is Type + is Type, HasTypeArgs where self.typename == "typealias" alias_to: NominalType @@ -1645,13 +1651,6 @@ local record Scope narrows: {string:boolean} end -local interface HasTypeArgs - is Type - where self.typeargs - - typeargs: {TypeArgType} -end - local interface HasDeclName declname: string end @@ -2216,6 +2215,7 @@ local parse_statements: function(ParseState, integer, ? boolean): integer, Node local parse_argument_list: function(ParseState, integer): integer, Node, integer local parse_argument_type_list: function(ParseState, integer): integer, TupleType, boolean, integer local parse_type: function(ParseState, integer): integer, Type, integer +local parse_type_declaration: function(ps: ParseState, i: integer, node_name: NodeKind): integer, Node local parse_newtype: function(ps: ParseState, i: integer): integer, Node local parse_interface_name: function(ps: ParseState, i: integer): integer, Type, integer @@ -3771,14 +3771,19 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, no elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then i = i + 1 local iv = i - local v: Node - i, v = verify_kind(ps, i, "identifier", "type_identifier") + + local lt: Node + i, lt = parse_type_declaration(ps, i, "local_type") -- local_type Node will be discarded + if not lt then + return fail(ps, i, "expected a type definition") + end + + local v = lt.var if not v then return fail(ps, i, "expected a variable name") end - i = verify_tk(ps, i, "=") - local nt: Node - i, nt = parse_newtype(ps, i) + + local nt = lt.value if not nt or not nt.newtype then return fail(ps, i, "expected a type definition") end @@ -4000,9 +4005,7 @@ local function parse_variable_declarations(ps: ParseState, i: integer, node_name return i, asgn end -local function parse_type_declaration(ps: ParseState, i: integer, node_name: NodeKind): integer, Node - i = i + 2 -- skip `local` or `global`, and `type` - +parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKind): integer, Node local asgn: Node = new_node(ps, i, node_name) local var: Node @@ -4048,6 +4051,10 @@ local function parse_type_declaration(ps: ParseState, i: integer, node_name: Nod def.declname = asgn.var.tk end end + elseif nt is TypeAliasType then + if typeargs then + nt.typeargs = typeargs + end end return i, asgn @@ -4078,7 +4085,7 @@ local function parse_type_constructor(ps: ParseState, i: integer, node_name: Nod end local function skip_type_declaration(ps: ParseState, i: integer): integer, Node - return parse_type_declaration(ps, i - 1, "local_type") + return parse_type_declaration(ps, i + 1, "local_type") end local function parse_local_macroexp(ps: ParseState, i: integer): integer, Node @@ -4096,8 +4103,8 @@ local function parse_local(ps: ParseState, i: integer): integer, Node local tn = ntk as TypeName if ntk == "function" then return parse_local_function(ps, i) - elseif ntk == "type" and ps.tokens[i+2].kind == "identifier" then - return parse_type_declaration(ps, i, "local_type") + elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_declaration(ps, i + 2, "local_type") elseif ntk == "macroexp" and ps.tokens[i+2].kind == "identifier" then return parse_local_macroexp(ps, i) elseif parse_type_body_fns[tn] and ps.tokens[i+2].kind == "identifier" then @@ -4111,8 +4118,8 @@ local function parse_global(ps: ParseState, i: integer): integer, Node local tn = ntk as TypeName if ntk == "function" then return parse_function(ps, i + 1, "global") - elseif ntk == "type" and ps.tokens[i+2].kind == "identifier" then - return parse_type_declaration(ps, i, "global_type") + elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_declaration(ps, i + 2, "global_type") elseif parse_type_body_fns[tn] and ps.tokens[i+2].kind == "identifier" then return parse_type_constructor(ps, i, "global_type", tn, parse_type_body_fns[tn]) elseif ps.tokens[i+1].kind == "identifier" then @@ -4434,6 +4441,11 @@ local function recurse_type(s: S, ast: Type, visit: Visitor visit_type = { cbs = { - ["function"] = { - before = function(self: TypeChecker, _typ: Type) - self:begin_scope() - end, - after = function(self: TypeChecker, typ: Type, _children: {Type}): Type - self:end_scope() - return self:ensure_fresh_typeargs(typ) - end, - }, ["record"] = { before = function(self: TypeChecker, typ: RecordType) self:begin_scope() @@ -12332,11 +12382,13 @@ do } visit_type.cbs["interface"] = visit_type.cbs["record"] - visit_type.cbs["typedecl"] = visit_type.cbs["function"] + + visit_type.cbs["function"] = visit_type_with_typeargs + visit_type.cbs["typedecl"] = visit_type_with_typeargs + visit_type.cbs["typealias"] = visit_type_with_typeargs visit_type.cbs["string"] = default_type_visitor visit_type.cbs["tupletable"] = default_type_visitor - visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["array"] = default_type_visitor visit_type.cbs["map"] = default_type_visitor visit_type.cbs["enum"] = default_type_visitor