diff --git a/spec/declaration/record_method_spec.lua b/spec/declaration/record_method_spec.lua index a9f0c3632..b920c17a2 100644 --- a/spec/declaration/record_method_spec.lua +++ b/spec/declaration/record_method_spec.lua @@ -460,6 +460,21 @@ describe("record method", function() end ]])) + it("inherits type variables from the record definition (regression test for #657)", util.check([[ + local record Test + value: T + end + + function Test.new(value: T): Test + return setmetatable({ value = value }, { __index = Test }) + end + + function Test:print() + local t: T + t = self.value + end + ]])) + describe("redeclaration: ", function() it("an inconsistent arity in redeclaration produces an error (regression test for #496)", util.check_type_error([[ local record Y diff --git a/tl.lua b/tl.lua index 7ae8b9b5e..f0038cc92 100644 --- a/tl.lua +++ b/tl.lua @@ -1326,6 +1326,7 @@ local is_attribute = attributes + local function is_array_type(t) @@ -3200,6 +3201,8 @@ end + + @@ -3542,6 +3545,7 @@ local function recurse_node(root, recurse_typeargs(ast, visit_type) xs[1] = recurse(ast.fn_owner) xs[2] = recurse(ast.name) + extra_callback("before_arguments", ast, xs, visit_node) xs[3] = recurse(ast.args) xs[4] = recurse_type(ast.rets, visit_type) extra_callback("before_statements", ast, xs, visit_node) @@ -8110,10 +8114,18 @@ tl.type_check = function(ast, opts) end if t.typename == "typetype" then + local typevals + if t.def.typeargs then + typevals = {} + for _, a in ipairs(t.def.typeargs) do + table.insert(typevals, a_type({ typename = "typevar", typevar = a.typearg })) + end + end return a_type({ y = exp.y, x = exp.x, typename = "nominal", + typevals = typevals, names = { exp.tk }, found = t, }) @@ -9629,10 +9641,25 @@ tl.type_check = function(ast, opts) widen_all_unions() begin_scope(node) end, + before_arguments = function(node, children) + node.rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) + + + if node.rtype.typeargs then + for _, typ in ipairs(node.rtype.typeargs) do + add_var(nil, typ.typearg, a_type({ + y = typ.y, + x = typ.x, + typename = "typearg", + typearg = typ.typearg, + })) + end + end + end, before_statements = function(node, children) add_internal_function_variables(node) - local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) + local rtype = node.rtype if rtype.typename == "emptytable" then rtype.typename = "record" rtype.fields = {} diff --git a/tl.tl b/tl.tl index f4ac554fb..fabd0aff8 100644 --- a/tl.tl +++ b/tl.tl @@ -1272,6 +1272,7 @@ local record Node body: Node implicit_global_function: boolean is_predeclared_local_function: boolean + rtype: Type name: Node @@ -3182,12 +3183,14 @@ end local record VisitorCallbacks before: function(N, {T}) before_expressions: function({N}, {T}) + before_arguments: function({N}, {T}) before_statements: function({N}, {T}) before_e2: function({N}, {T}) after: function(N, {T}, T): T end local enum VisitorExtraCallback + "before_arguments" "before_statements" "before_expressions" "before_e2" @@ -3542,6 +3545,7 @@ local function recurse_node(root: Node, recurse_typeargs(ast, visit_type) xs[1] = recurse(ast.fn_owner) xs[2] = recurse(ast.name) + extra_callback("before_arguments", ast, xs, visit_node) xs[3] = recurse(ast.args) xs[4] = recurse_type(ast.rets, visit_type) extra_callback("before_statements", ast, xs, visit_node) @@ -8110,10 +8114,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if t.typename == "typetype" then + local typevals: Type + if t.def.typeargs then + typevals = {} + for _, a in ipairs(t.def.typeargs) do + table.insert(typevals, a_type { typename = "typevar", typevar = a.typearg }) + end + end return a_type { y = exp.y, x = exp.x, typename = "nominal", + typevals = typevals, names = { exp.tk }, found = t, } @@ -9629,10 +9641,25 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string widen_all_unions() begin_scope(node) end, + before_arguments = function(node: Node, children: {Type}) + node.rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) + + -- add type arguments from the record implicitly + if node.rtype.typeargs then + for _, typ in ipairs(node.rtype.typeargs) do + add_var(nil, typ.typearg, a_type { + y = typ.y, + x = typ.x, + typename = "typearg", + typearg = typ.typearg, + }) + end + end + end, before_statements = function(node: Node, children: {Type}) add_internal_function_variables(node) - local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) + local rtype = node.rtype if rtype.typename == "emptytable" then rtype.typename = "record" rtype.fields = {}