Skip to content

Commit

Permalink
we can now reexport nested types
Browse files Browse the repository at this point in the history
Fixes #765.
  • Loading branch information
hishamhm committed Jul 23, 2024
1 parent 48bca4d commit e69d91b
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 92 deletions.
34 changes: 34 additions & 0 deletions spec/declaration/record_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,40 @@ for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces
]])()
end)

it("can reexport types as nested " .. name, function()
util.mock_io(finally, {
["inner.tl"] = [[
local record inner
]]..statement..[[ SubType<K> ]]..array(i, "{integer}")..[[
item: K
end
end
return inner
]],
["outer.tl"] = [[
local core = require("inner")
local record mod
f: string
type SubType<K> = core.SubType<K>
end
return mod
]],
})
util.run_check_type_error([[
local mod = require("outer")
print(mod.f)
local v: mod.SubType<integer> = {
item = "hello"
}
]], {
{ msg = 'in record field: item: got string "hello", expected integer' }
})
end)

it("resolves aliasing of nested " .. name .. " (see #400)", util.check([[
local ]]..statement..[[ Foo ]]..array(i, "{Foo}")..[[
]]..statement..[[ Bar ]]..array(i, "{Bar}")..[[
Expand Down
6 changes: 6 additions & 0 deletions spec/util.lua
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,12 @@ local function check(lax, code, unknowns, gen_target)
gen_compat = "off"
end
local result = tl.type_check(ast, "foo.lua", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat })

for _, mname in pairs(result.env.loaded_order) do
local mresult = result.env.loaded[mname]
batch:add(assert.same, {}, mresult.syntax_errors or {}, "Code was not expected to have syntax errors")
end

batch:add(assert.same, {}, result.type_errors)

if unknowns then
Expand Down
130 changes: 91 additions & 39 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1889,7 +1889,6 @@ end






local TruthyFact = {}
Expand Down Expand Up @@ -2216,6 +2215,7 @@ do
local parse_argument_list
local parse_argument_type_list
local parse_type
local parse_type_declaration
local parse_newtype
local parse_interface_name

Expand Down Expand Up @@ -3771,14 +3771,19 @@ do
elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then
i = i + 1
local iv = i
local v
i, v = verify_kind(ps, i, "identifier", "type_identifier")

local lt
i, lt = parse_type_declaration(ps, i, "local_type")
if not lt then
return fail(ps, i, "expected a type definition")
end

local v = lt.var
if not v then
return fail(ps, i, "expected a variable name")
end
i = verify_tk(ps, i, "=")
local nt
i, nt = parse_newtype(ps, i)

local nt = lt.value
if not nt or not nt.newtype then
return fail(ps, i, "expected a type definition")
end
Expand Down Expand Up @@ -4000,9 +4005,7 @@ do
return i, asgn
end

local function parse_type_declaration(ps, i, node_name)
i = i + 2

parse_type_declaration = function(ps, i, node_name)
local asgn = new_node(ps, i, node_name)
local var

Expand Down Expand Up @@ -4048,6 +4051,10 @@ do
def.declname = asgn.var.tk
end
end
elseif nt.typename == "typealias" then
if typeargs then
nt.typeargs = typeargs
end
end

return i, asgn
Expand Down Expand Up @@ -4078,7 +4085,7 @@ do
end

local function skip_type_declaration(ps, i)
return parse_type_declaration(ps, i - 1, "local_type")
return parse_type_declaration(ps, i + 1, "local_type")
end

local function parse_local_macroexp(ps, i)
Expand All @@ -4097,7 +4104,7 @@ do
if ntk == "function" then
return parse_local_function(ps, i)
elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then
return parse_type_declaration(ps, i, "local_type")
return parse_type_declaration(ps, i + 2, "local_type")
elseif ntk == "macroexp" and ps.tokens[i + 2].kind == "identifier" then
return parse_local_macroexp(ps, i)
elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then
Expand All @@ -4112,7 +4119,7 @@ do
if ntk == "function" then
return parse_function(ps, i + 1, "global")
elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then
return parse_type_declaration(ps, i, "global_type")
return parse_type_declaration(ps, i + 2, "global_type")
elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then
return parse_type_constructor(ps, i, "global_type", tn, parse_type_body_fns[tn])
elseif ps.tokens[i + 1].kind == "identifier" then
Expand Down Expand Up @@ -4434,6 +4441,11 @@ local function recurse_type(s, ast, visit)
table.insert(xs, recurse_type(s, ast.vtype, visit))
end
elseif ast.typename == "typealias" then
if ast.typeargs then
for _, child in ipairs(ast.typeargs) do
table.insert(xs, recurse_type(s, child, visit))
end
end
table.insert(xs, recurse_type(s, ast.alias_to, visit))
elseif ast.typename == "typedecl" then
if ast.typeargs then
Expand Down Expand Up @@ -6998,7 +7010,7 @@ do
fresh_typevar_ctr = fresh_typevar_ctr + 1
local ok
ok, t = typevar_resolver(nil, t, fresh_typevar, fresh_typearg)
assert(ok, "Internal Compiler Error: error creating fresh type variables")
assert(ok and t, "Internal Compiler Error: error creating fresh type variables")
return t
end

Expand Down Expand Up @@ -7711,6 +7723,10 @@ do
found = found.alias_to.found
end

if not found then
return self.errs:invalid_at(t, table.concat(t.names, ".") .. " is not a resolved type")
end

if not (found.typename == "typedecl") then
return self.errs:invalid_at(t, table.concat(t.names, ".") .. " is not a type")
end
Expand All @@ -7731,7 +7747,7 @@ do
local function resolve_decl_into_nominal(self, t, found)
local def = found.def
local resolved
if def.typename == "record" or def.typename == "function" then
if def.fields or def.typename == "function" then
resolved = match_typevals(self, t, def)
if not resolved then
return self.errs:invalid_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope")
Expand All @@ -7757,7 +7773,7 @@ do
local t = typealias.alias_to

local immediate, found = find_nominal_type_decl(self, t)
if immediate then
if type(immediate) == "table" then
return immediate
end

Expand Down Expand Up @@ -8132,6 +8148,16 @@ do
end


function TypeChecker:forall_are_subtype_of(xs, t)
for _, x in ipairs(xs.types) do
if not self:is_a(x, t) then
return false
end
end
return true
end


local emptytable_relations = {
["array"] = compare_true,
["map"] = compare_true,
Expand Down Expand Up @@ -8263,6 +8289,15 @@ do
["*"] = compare_true,
},
["union"] = {
["nominal"] = function(self, a, b)

local rb = self:resolve_nominal(b)
if rb.typename == "union" then
return self:is_a(a, rb)
end

return self:forall_are_subtype_of(a, b)
end,
["union"] = function(self, a, b)
local used = {}
for _, t in ipairs(a.types) do
Expand All @@ -8281,14 +8316,7 @@ do
end
return true
end,
["*"] = function(self, a, b)
for _, t in ipairs(a.types) do
if not self:is_a(t, b) then
return false
end
end
return true
end,
["*"] = TypeChecker.forall_are_subtype_of,
},
["poly"] = {
["*"] = function(self, a, b)
Expand All @@ -8305,21 +8333,36 @@ do
return true
end

local rb = self:resolve_nominal(b)
if rb.typename == "interface" then

local ra = self:resolve_nominal(a)
local rb = self:resolve_nominal(b)
if ra.typename == "union" and rb.typename == "union" then
return self:is_a(ra, rb)
end
if ra.typename == "union" then
return self:is_a(ra, b)
end
if rb.typename == "union" then
return self:is_a(a, rb)
end

local ra = self:resolve_nominal(a)
if ra.typename == "union" or rb.typename == "union" then

return self:is_a(ra, rb)
if rb.typename == "interface" then
return self:is_a(a, rb)
end


return ok, errs
end,
["union"] = function(self, a, b)

local ra = self:resolve_nominal(a)
if ra.typename == "union" then
return self:is_a(ra, b)
end

return not not self:exists_supertype_in(a, b)
end,
["*"] = TypeChecker.subtype_nominal,
},
["enum"] = {
Expand Down Expand Up @@ -10518,7 +10561,10 @@ self:expand_type(node, values, elements) })
value.e1.tk == "require" then

local t = special_functions["require"](self, value, self:find_var_type("require"), a_type(value.e2, "tuple", { tuple = { a_type(value.e2[1], "string", {}) } }), 0)


local ty = t.typename == "tuple" and t.tuple[1] or t

ty = (ty.typename == "typealias") and self:resolve_typealias(ty) or ty
local td = (ty.typename == "typedecl") and ty or a_type(value, "typedecl", { def = ty })
return td
Expand Down Expand Up @@ -10660,6 +10706,9 @@ self:expand_type(node, values, elements) })
if node.value then
local resolved, aliasing = self:get_typedecl(node.value)
local added = self:add_global(node.var, name, resolved)
if resolved.typename == "invalid" then
return
end
node.value.newtype = resolved
if aliasing then
added.aliasing = aliasing
Expand Down Expand Up @@ -12143,18 +12192,19 @@ self:expand_type(node, values, elements) })
end
end

local visit_type_with_typeargs = {
before = function(self, _typ)
self:begin_scope()
end,
after = function(self, typ, _children)
self:end_scope()
return self:ensure_fresh_typeargs(typ)
end,
}

local visit_type
visit_type = {
cbs = {
["function"] = {
before = function(self, _typ)
self:begin_scope()
end,
after = function(self, typ, _children)
self:end_scope()
return self:ensure_fresh_typeargs(typ)
end,
},
["record"] = {
before = function(self, typ)
self:begin_scope()
Expand Down Expand Up @@ -12332,11 +12382,13 @@ self:expand_type(node, values, elements) })
}

visit_type.cbs["interface"] = visit_type.cbs["record"]
visit_type.cbs["typedecl"] = visit_type.cbs["function"]

visit_type.cbs["function"] = visit_type_with_typeargs
visit_type.cbs["typedecl"] = visit_type_with_typeargs
visit_type.cbs["typealias"] = visit_type_with_typeargs

visit_type.cbs["string"] = default_type_visitor
visit_type.cbs["tupletable"] = default_type_visitor
visit_type.cbs["typealias"] = default_type_visitor
visit_type.cbs["array"] = default_type_visitor
visit_type.cbs["map"] = default_type_visitor
visit_type.cbs["enum"] = default_type_visitor
Expand Down
Loading

0 comments on commit e69d91b

Please sign in to comment.