Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type narrowing proof of concept #577

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions spec/is/is.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
local util = require("spec.util")

describe("Is<T>:", function()

it("is_ function", util.check [[
local record Is<T> end

local record MyRecord
a: number
end

local record OtherRecord
a: boolean
end

local r : MyRecord | OtherRecord = { a = 1 }

local n : number

local function is_myrecord(x: any): Is<MyRecord>
if x is table then
local a = x.a
return (a is number)
else return false end
end
if is_myrecord(r) then
n = r.a
end
]])

it("is_ method", util.check [[
local record Is<T> end

local record A
is_b : function(self : A | B) : Is<B>
end

local record B
is_b : function(self : A | B) : Is<B>
b_field : string
end

local b1 : B = {
is_b = function(self : A | B) : Is<B>
return (self as {string:any}).b_field ~= nil
end,
b_field = "yes",
}

local ab : A | B = b1

if ab:is_b() then
local b2 : B = ab
local s : string = b2.b_field
end
]])

end)
53 changes: 47 additions & 6 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5585,7 +5585,6 @@ tl.type_check = function(ast, opts)
local function is_valid_union(typ)


local n_table_types = 0
local n_function_types = 0
local n_userdata_types = 0
local n_string_enum = 0
Expand All @@ -5597,11 +5596,6 @@ tl.type_check = function(ast, opts)
if n_userdata_types > 1 then
return false, "cannot discriminate a union between multiple userdata types: %s"
end
elseif ut == "table" then
n_table_types = n_table_types + 1
if n_table_types > 1 then
return false, "cannot discriminate a union between multiple table types: %s"
end
elseif ut == "function" then
n_function_types = n_function_types + 1
if n_function_types > 1 then
Expand Down Expand Up @@ -6711,6 +6705,10 @@ tl.type_check = function(ast, opts)
return false, terr(t1, "enum is incompatible with %s", t2)
end
elseif t1.typename == "integer" and t2.typename == "number" then
return true
elseif t1.typename == "boolean" and t2.typename == "nominal" and t2.tk == "Is" then


return true
elseif t1.typename == "string" and t2.typename == "enum" then
local ok = t1.tk and t2.enumset[unquote(t1.tk)]
Expand Down Expand Up @@ -7190,6 +7188,28 @@ tl.type_check = function(ast, opts)
else
return nil, "invalid key '" .. key .. "' in type %s"
end

elseif tbl.typename == "union" then
assert(tbl.types[1], "Union has no members")
local field
for _, t in ipairs(tbl.types) do

t = resolve_tuple_and_nominal(t)
t = resolve_typetype(t)


if not is_record_type(t) then
return nil, "cannot index key '" .. key .. "' in '" .. t.tk .. "' from union %s (not a record)"
end
assert(t.fields, "record has no fields!?")

if not t.fields[key] then
return nil, "invalid key '" .. key .. "' in type '" .. t.tk .. "' from union %s"
else
field = t.fields[key]
end
end
return field
elseif tbl.typename == "emptytable" or is_unknown(tbl) then
if lax then
return INVALID
Expand Down Expand Up @@ -9114,6 +9134,27 @@ node.exps[3] and node.exps[3].type, }
end
end
elseif node.op.op == "@funcall" then


local is_is_function = a.typename == "function" and a.rets[1] and a.rets[1].tk == "Is"
local is_method = node.e1.op and node.e1.op.op == ":"
local first_arg = node.e2[1]
if is_is_function and (first_arg or is_method) then
local refined_var


if is_method then
refined_var = node.e1.e1.tk
else
refined_var = first_arg.tk
end
node.known = Fact({
fact = "is",
var = refined_var,
typ = a.rets[1].typevals[1],
where = node,
})
end
if lax and is_unknown(a) then
if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then
add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk)
Expand Down
53 changes: 47 additions & 6 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -5585,7 +5585,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
local function is_valid_union(typ: Type): boolean, string
-- check for limitations in our union support
-- due to codegen limitations (we only check with type() so far)
local n_table_types = 0
local n_function_types = 0
local n_userdata_types = 0
local n_string_enum = 0
Expand All @@ -5597,11 +5596,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
if n_userdata_types > 1 then
return false, "cannot discriminate a union between multiple userdata types: %s"
end
elseif ut == "table" then
n_table_types = n_table_types + 1
if n_table_types > 1 then
return false, "cannot discriminate a union between multiple table types: %s"
end
elseif ut == "function" then
n_function_types = n_function_types + 1
if n_function_types > 1 then
Expand Down Expand Up @@ -6712,6 +6706,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
end
elseif t1.typename == "integer" and t2.typename == "number" then
return true
elseif t1.typename == "boolean" and t2.typename == "nominal" and t2.tk == "Is" then
-- Treat booleans as Is<>, so that is_() functions don't have to cast
-- their return values.
return true
elseif t1.typename == "string" and t2.typename == "enum" then
local ok = t1.tk and t2.enumset[unquote(t1.tk)]
if ok then
Expand Down Expand Up @@ -7190,6 +7188,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
else
return nil, "invalid key '" .. key .. "' in type %s"
end
-- Union of records
elseif tbl.typename == "union" then
assert(tbl.types[1], "Union has no members")
local field : Type
for _,t in ipairs(tbl.types) do
-- TODO probably doing too much stuff
t = resolve_tuple_and_nominal(t)
t = resolve_typetype(t)
-- TODO support unions of unions, recursively
-- (properly, so that eg. we do all those extra checks outside of this if)
if not is_record_type(t) then
return nil, "cannot index key '" .. key .. "' in '" .. t.tk .. "' from union %s (not a record)"
end
assert(t.fields, "record has no fields!?")
-- key should be in all records
if not t.fields[key] then
return nil, "invalid key '" .. key .. "' in type '" .. t.tk .. "' from union %s"
else
field = t.fields[key]
end
end
return field
elseif tbl.typename == "emptytable" or is_unknown(tbl) then
if lax then
return INVALID
Expand Down Expand Up @@ -9114,6 +9134,27 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
end
end
elseif node.op.op == "@funcall" then
-- Huge hack, Is should be its own type, or its definition should
-- be provided by teal.
local is_is_function = a.typename == "function" and a.rets[1] and a.rets[1].tk == "Is"
local is_method = node.e1.op and node.e1.op.op == ":"
local first_arg = node.e2[1]
if is_is_function and (first_arg or is_method) then
local refined_var : string
-- If it's a method call, we refine self, otherwise the first
-- argument of the function.
if is_method then
refined_var = node.e1.e1.tk
else
refined_var = first_arg.tk
end
node.known = Fact {
fact = "is",
var = refined_var,
typ = a.rets[1].typevals[1], -- type argument of Is<>
where = node,
}
end
if lax and is_unknown(a) then
if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then
add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk)
Expand Down