From 976480b356cd3dd848a74d7dbe58ab606dce793d Mon Sep 17 00:00:00 2001 From: Francesco Gazzetta Date: Thu, 3 Nov 2022 16:33:14 +0100 Subject: [PATCH] Type narrowing proof of concept --- spec/is/is.lua | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++ tl.lua | 53 +++++++++++++++++++++++++++++++++++++++------ tl.tl | 53 +++++++++++++++++++++++++++++++++++++++------ 3 files changed, 152 insertions(+), 12 deletions(-) create mode 100644 spec/is/is.lua diff --git a/spec/is/is.lua b/spec/is/is.lua new file mode 100644 index 00000000..eb7c2f25 --- /dev/null +++ b/spec/is/is.lua @@ -0,0 +1,58 @@ +local util = require("spec.util") + +describe("Is:", function() + + it("is_ function", util.check [[ + local record Is end + + local record MyRecord + a: number + end + + local record OtherRecord + a: boolean + end + + local r : MyRecord | OtherRecord = { a = 1 } + + local n : number + + local function is_myrecord(x: any): Is + if x is table then + local a = x.a + return (a is number) + else return false end + end + if is_myrecord(r) then + n = r.a + end + ]]) + + it("is_ method", util.check [[ + local record Is end + + local record A + is_b : function(self : A | B) : Is + end + + local record B + is_b : function(self : A | B) : Is + b_field : string + end + + local b1 : B = { + is_b = function(self : A | B) : Is + return (self as {string:any}).b_field ~= nil + end, + b_field = "yes", + } + + local ab : A | B = b1 + + if ab:is_b() then + local b2 : B = ab + local s : string = b2.b_field + end + ]]) + +end) diff --git a/tl.lua b/tl.lua index cb73c7f4..b74f14ea 100644 --- a/tl.lua +++ b/tl.lua @@ -5585,7 +5585,6 @@ tl.type_check = function(ast, opts) local function is_valid_union(typ) - local n_table_types = 0 local n_function_types = 0 local n_userdata_types = 0 local n_string_enum = 0 @@ -5597,11 +5596,6 @@ tl.type_check = function(ast, opts) if n_userdata_types > 1 then return false, "cannot discriminate a union between multiple userdata types: %s" end - elseif ut == "table" then - n_table_types = n_table_types + 1 - if n_table_types > 1 then - return false, "cannot discriminate a union between multiple table types: %s" - end elseif ut == "function" then n_function_types = n_function_types + 1 if n_function_types > 1 then @@ -6711,6 +6705,10 @@ tl.type_check = function(ast, opts) return false, terr(t1, "enum is incompatible with %s", t2) end elseif t1.typename == "integer" and t2.typename == "number" then + return true + elseif t1.typename == "boolean" and t2.typename == "nominal" and t2.tk == "Is" then + + return true elseif t1.typename == "string" and t2.typename == "enum" then local ok = t1.tk and t2.enumset[unquote(t1.tk)] @@ -7190,6 +7188,28 @@ tl.type_check = function(ast, opts) else return nil, "invalid key '" .. key .. "' in type %s" end + + elseif tbl.typename == "union" then + assert(tbl.types[1], "Union has no members") + local field + for _, t in ipairs(tbl.types) do + + t = resolve_tuple_and_nominal(t) + t = resolve_typetype(t) + + + if not is_record_type(t) then + return nil, "cannot index key '" .. key .. "' in '" .. t.tk .. "' from union %s (not a record)" + end + assert(t.fields, "record has no fields!?") + + if not t.fields[key] then + return nil, "invalid key '" .. key .. "' in type '" .. t.tk .. "' from union %s" + else + field = t.fields[key] + end + end + return field elseif tbl.typename == "emptytable" or is_unknown(tbl) then if lax then return INVALID @@ -9114,6 +9134,27 @@ node.exps[3] and node.exps[3].type, } end end elseif node.op.op == "@funcall" then + + + local is_is_function = a.typename == "function" and a.rets[1] and a.rets[1].tk == "Is" + local is_method = node.e1.op and node.e1.op.op == ":" + local first_arg = node.e2[1] + if is_is_function and (first_arg or is_method) then + local refined_var + + + if is_method then + refined_var = node.e1.e1.tk + else + refined_var = first_arg.tk + end + node.known = Fact({ + fact = "is", + var = refined_var, + typ = a.rets[1].typevals[1], + where = node, + }) + end if lax and is_unknown(a) then if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) diff --git a/tl.tl b/tl.tl index 6934c5ee..c595fe66 100644 --- a/tl.tl +++ b/tl.tl @@ -5585,7 +5585,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function is_valid_union(typ: Type): boolean, string -- check for limitations in our union support -- due to codegen limitations (we only check with type() so far) - local n_table_types = 0 local n_function_types = 0 local n_userdata_types = 0 local n_string_enum = 0 @@ -5597,11 +5596,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if n_userdata_types > 1 then return false, "cannot discriminate a union between multiple userdata types: %s" end - elseif ut == "table" then - n_table_types = n_table_types + 1 - if n_table_types > 1 then - return false, "cannot discriminate a union between multiple table types: %s" - end elseif ut == "function" then n_function_types = n_function_types + 1 if n_function_types > 1 then @@ -6712,6 +6706,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end elseif t1.typename == "integer" and t2.typename == "number" then return true + elseif t1.typename == "boolean" and t2.typename == "nominal" and t2.tk == "Is" then + -- Treat booleans as Is<>, so that is_() functions don't have to cast + -- their return values. + return true elseif t1.typename == "string" and t2.typename == "enum" then local ok = t1.tk and t2.enumset[unquote(t1.tk)] if ok then @@ -7190,6 +7188,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else return nil, "invalid key '" .. key .. "' in type %s" end + -- Union of records + elseif tbl.typename == "union" then + assert(tbl.types[1], "Union has no members") + local field : Type + for _,t in ipairs(tbl.types) do + -- TODO probably doing too much stuff + t = resolve_tuple_and_nominal(t) + t = resolve_typetype(t) + -- TODO support unions of unions, recursively + -- (properly, so that eg. we do all those extra checks outside of this if) + if not is_record_type(t) then + return nil, "cannot index key '" .. key .. "' in '" .. t.tk .. "' from union %s (not a record)" + end + assert(t.fields, "record has no fields!?") + -- key should be in all records + if not t.fields[key] then + return nil, "invalid key '" .. key .. "' in type '" .. t.tk .. "' from union %s" + else + field = t.fields[key] + end + end + return field elseif tbl.typename == "emptytable" or is_unknown(tbl) then if lax then return INVALID @@ -9114,6 +9134,27 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end elseif node.op.op == "@funcall" then + -- Huge hack, Is should be its own type, or its definition should + -- be provided by teal. + local is_is_function = a.typename == "function" and a.rets[1] and a.rets[1].tk == "Is" + local is_method = node.e1.op and node.e1.op.op == ":" + local first_arg = node.e2[1] + if is_is_function and (first_arg or is_method) then + local refined_var : string + -- If it's a method call, we refine self, otherwise the first + -- argument of the function. + if is_method then + refined_var = node.e1.e1.tk + else + refined_var = first_arg.tk + end + node.known = Fact { + fact = "is", + var = refined_var, + typ = a.rets[1].typevals[1], -- type argument of Is<> + where = node, + } + end if lax and is_unknown(a) then if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk)