From e178c42bcb81fbe81599a07432246d5f5a466a95 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 13 Aug 2024 10:53:48 -0300 Subject: [PATCH] fix: do not close nested types too early Closes #775. --- spec/declaration/record_function_spec.lua | 20 ++++++++++++ tl.lua | 37 +++++++++++++++++------ tl.tl | 37 +++++++++++++++++------ 3 files changed, 74 insertions(+), 20 deletions(-) diff --git a/spec/declaration/record_function_spec.lua b/spec/declaration/record_function_spec.lua index 290f5e617..75bd616f5 100644 --- a/spec/declaration/record_function_spec.lua +++ b/spec/declaration/record_function_spec.lua @@ -107,5 +107,25 @@ describe("record function", function() ]], {}, { { y = 7, msg = "different number of return values: got 1, expected 0" }, })) + + it("does not close nested types too early (regression test for #775)", util.check([[ + -- declare a nested record + local record mul + record Fil + mime: function(Fil) + end + end + + -- declare an alias + local type Fil = mul.Fil + + -- this works + function mul.Fil:new_method1(self: Fil) + end + + -- should work as well for alias + function Fil:new_method2(self: Fil) + end + ]])) end) end) diff --git a/tl.lua b/tl.lua index 4f135af23..c8a56c598 100644 --- a/tl.lua +++ b/tl.lua @@ -12277,22 +12277,38 @@ self:expand_type(node, values, elements) }) end, } + function TypeChecker:begin_temporary_record_types(typ) + self:add_var(nil, "@self", type_at(typ, a_type(typ, "typedecl", { def = typ }))) + + for fname, ftype in fields_of(typ) do + if ftype.typename == "typealias" then + self:resolve_nominal(ftype.alias_to) + self:add_var(nil, fname, ftype) + elseif ftype.typename == "typedecl" then + self:add_var(nil, fname, ftype) + end + end + end + + function TypeChecker:end_temporary_record_types(typ) + + + local scope = self.st[#self.st] + scope.vars["@self"] = nil + for fname, ftype in fields_of(typ) do + if ftype.typename == "typealias" or ftype.typename == "typedecl" then + scope.vars[fname] = nil + end + end + end + local visit_type visit_type = { cbs = { ["record"] = { before = function(self, typ) self:begin_scope() - self:add_var(nil, "@self", type_at(typ, a_type(typ, "typedecl", { def = typ }))) - - for fname, ftype in fields_of(typ) do - if ftype.typename == "typealias" then - self:resolve_nominal(ftype.alias_to) - self:add_var(nil, fname, ftype) - elseif ftype.typename == "typedecl" then - self:add_var(nil, fname, ftype) - end - end + self:begin_temporary_record_types(typ) end, after = function(self, typ, children) local i = 1 @@ -12386,6 +12402,7 @@ self:expand_type(node, values, elements) }) end end + self:end_temporary_record_types(typ) self:end_scope() return typ diff --git a/tl.tl b/tl.tl index 7d6250047..957b891d4 100644 --- a/tl.tl +++ b/tl.tl @@ -12277,22 +12277,38 @@ do end, } + function TypeChecker:begin_temporary_record_types(typ: RecordType) + self:add_var(nil, "@self", type_at(typ, a_typedecl(typ, typ))) + + for fname, ftype in fields_of(typ) do + if ftype is TypeAliasType then + self:resolve_nominal(ftype.alias_to) + self:add_var(nil, fname, ftype) + elseif ftype is TypeDeclType then + self:add_var(nil, fname, ftype) + end + end + end + + function TypeChecker:end_temporary_record_types(typ: RecordType) + -- drop @self and nested records from scope + -- to avoid closing them prematurely in end_scope() + local scope = self.st[#self.st] + scope.vars["@self"] = nil + for fname, ftype in fields_of(typ) do + if ftype is TypeAliasType or ftype is TypeDeclType then + scope.vars[fname] = nil + end + end + end + local visit_type: Visitor visit_type = { cbs = { ["record"] = { before = function(self: TypeChecker, typ: RecordType) self:begin_scope() - self:add_var(nil, "@self", type_at(typ, a_typedecl(typ, typ))) - - for fname, ftype in fields_of(typ) do - if ftype is TypeAliasType then - self:resolve_nominal(ftype.alias_to) - self:add_var(nil, fname, ftype) - elseif ftype is TypeDeclType then - self:add_var(nil, fname, ftype) - end - end + self:begin_temporary_record_types(typ) end, after = function(self: TypeChecker, typ: RecordType, children: {Type}): Type local i = 1 @@ -12386,6 +12402,7 @@ do end end + self:end_temporary_record_types(typ) self:end_scope() return typ