From 58ebd4b7ab2c57dcd19f5c5e29c69221d8a016aa Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 18 Sep 2024 22:22:18 -0300 Subject: [PATCH] return: module returns the nominal's type, including typeargs Fixes #804. --- spec/statement/return_spec.lua | 88 ++++++++++++++++++++++++++++++++++ tl.lua | 8 +++- tl.tl | 8 +++- 3 files changed, 102 insertions(+), 2 deletions(-) diff --git a/spec/statement/return_spec.lua b/spec/statement/return_spec.lua index b52fc912a..1b9f17949 100644 --- a/spec/statement/return_spec.lua +++ b/spec/statement/return_spec.lua @@ -153,6 +153,94 @@ describe("return", function() assert.same({}, result.syntax_errors) assert.same({}, result.type_errors) end) + + it("when exporting a generic (regression test for #804)", function () + util.mock_io(finally, { + ["foo.tl"] = [[ + local record Foo + bar: T + end + return Foo + ]], + ["main.tl"] = [[ + local Foo = require("foo") + + local foo: Foo + + foo = { + bar = 5 + } + + print(string.format("bar: %d", foo.bar + 1)) + ]], + }) + + local tl = require("tl") + local result, err = tl.process("main.tl", assert(tl.init_env())) + + assert.same(nil, err) + assert.same({}, result.syntax_errors) + assert.same({}, result.type_errors) + end) + + it("when exporting a typealias (variation on regression test for #804)", function () + util.mock_io(finally, { + ["foo.tl"] = [[ + local record Foo + bar: T + end + local type FooInteger = Foo + return FooInteger + ]], + ["main.tl"] = [[ + local Foo = require("foo") + + local foo: Foo + + foo = { + bar = 5 + } + + print(string.format("bar: %d", foo.bar + 1)) + ]], + }) + + local tl = require("tl") + local result, err = tl.process("main.tl", assert(tl.init_env())) + + assert.same(nil, err) + assert.same({}, result.syntax_errors) + assert.same({}, result.type_errors) + end) + + it("when exporting a non-generic (variation on regression test for #804)", function () + util.mock_io(finally, { + ["foo.tl"] = [[ + local record Foo + bar: integer + end + return Foo + ]], + ["main.tl"] = [[ + local Foo = require("foo") + + local foo: Foo + + foo = { + bar = 5 + } + + print(string.format("bar: %d", foo.bar + 1)) + ]], + }) + + local tl = require("tl") + local result, err = tl.process("main.tl", assert(tl.init_env())) + + assert.same(nil, err) + assert.same({}, result.syntax_errors) + assert.same({}, result.type_errors) + end) end) it("when exporting type alias through multiple levels", function () diff --git a/tl.lua b/tl.lua index 044dcfbfd..4e2daccd8 100644 --- a/tl.lua +++ b/tl.lua @@ -11527,7 +11527,13 @@ self:expand_type(node, values, elements) }) if not expected then expected = self:infer_at(node, got) - self.module_type = drop_constant_value(self:to_structural(resolve_tuple(expected))) + local module_type = resolve_tuple(expected) + if module_type.typename == "nominal" then + self:resolve_nominal(module_type) + self.module_type = module_type.found + else + self.module_type = drop_constant_value(module_type) + end self.st[2].vars["@return"] = { t = expected } end local expected_t = expected.tuple diff --git a/tl.tl b/tl.tl index dad8de5ab..bd3a043d0 100644 --- a/tl.tl +++ b/tl.tl @@ -11527,7 +11527,13 @@ do if not expected then -- if at the toplevel expected = self:infer_at(node, got) - self.module_type = drop_constant_value(self:to_structural(resolve_tuple(expected))) + local module_type = resolve_tuple(expected) + if module_type is NominalType then + self:resolve_nominal(module_type) + self.module_type = module_type.found + else + self.module_type = drop_constant_value(module_type) + end self.st[2].vars["@return"] = { t = expected } end local expected_t = expected.tuple