Skip to content

Commit

Permalink
metatables: check metamethod types in metatable definition
Browse files Browse the repository at this point in the history
Add special-case behavior to specialize a type `metatable<R>` using the
definition of `metamethod` entries from `R` (and not just a type-variable
application of `R` into the definition of `global record metatable<T>` from
the standard library definition.

See tests in spec/declaration/metatable_spec.lua for examples of the added
checks.

Fixes #633.

(At least the extent of it that can be resolved at this time,
without explicit `nil` support -- a good explanation as to why
the second case isn't resolved is given by @bjornbm in
#633 (comment) :

"the record definition defines what keys/values are valid, but not that they
are defined (or more generally perhaps the values may be nil, since nil is a
valid value of every type). What is checked is that values for the defined
keys have the right type, and that no other keys are added to the record.
  • Loading branch information
hishamhm committed Sep 3, 2024
1 parent 3cc78b6 commit fcdd861
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 54 deletions.
145 changes: 145 additions & 0 deletions spec/declaration/metatable_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
local util = require("spec.util")

describe("metatable declaration", function()
it("checks metamethod declarations in record against a general contract", util.check_type_error([[
local type Rec = record
n: integer
metamethod __sub: function(self: Rec, b: integer, wat: integer): Rec
end
local rec_mt: metatable<Rec>
rec_mt = {
__add = function(self: Rec, b: Rec): Rec
return { n = self.n + b.n }
end,
}
local r: Rec = setmetatable({ n = 10 }, rec_mt)
print((r - 3).n)
]], {
{ y = 3, x = 28, msg = "__sub does not follow metatable contract: got function(Rec, integer, integer): Rec, expected function<A, B, C>(A, B): C" },
{ y = 14, x = 16, msg = "wrong number of arguments" },
}))

it("checks metatable against metamethod declarations", util.check_type_error([[
local type Rec = record
n: integer
metamethod __add: function(self: Rec, b: integer): Rec
end
local rec_mt: metatable<Rec>
rec_mt = {
__add = function(self: Rec, b: Rec): Rec
return { n = self.n + b.n }
end,
}
local r: Rec = setmetatable({ n = 10 }, rec_mt)
print((r + 9).n)
print((9 + r).n)
]], {
{ y = 8, x = 41, msg = "in record field: __add: argument 2: got Rec, expected integer" },
{ y = 15, x = 14, msg = "argument 1: got integer, expected Rec" },
}))

it("checks non-method metamethods with self in any position", util.check_type_error([[
local type Rec = record
n: integer
metamethod __mul: function(a: integer, b: Rec): integer
end
local rec_mt: metatable<Rec>
rec_mt = {
__mul = function(a: integer, b: Rec): integer
return a * b.n
end,
}
local r: Rec = setmetatable({ n = 10 }, rec_mt)
print((9 * r) + 3.0)
print((r * 9) + 3.0)
]], {
{ y = 15, x = 14, msg = "argument 1: got Rec, expected integer" },
}))

it("checks metamethods with multiple entries of the type", util.check_type_error([[
local type Rec = record
n: integer
metamethod __div: function(a: Rec, b: Rec): integer
end
local rec_mt: metatable<Rec>
rec_mt = {
__div = function(a: Rec, b: Rec): integer
return a.n // b.n
end,
}
local r: Rec = setmetatable({ n = 10 }, rec_mt)
print((r / 9) + 3.0)
print((r / r) + 3.0)
]], {
{ y = 14, x = 18, msg = "argument 2: got integer, expected Rec" },
}))

it("checks metamethods with method-like self", util.check_type_error([[
local type Rec = record
n: integer
metamethod __index: function(Rec, s: string): Rec
end
local rec_mt: metatable<Rec>
rec_mt = {
__index = function(self: Rec, k: string): Rec
return { n = #k }
end,
}
local r: Rec = setmetatable({ n = 10 }, rec_mt)
print(r["hello"])
print(r[true])
]], {
{ y = 15, x = 15, msg = "argument 1: got boolean, expected string" },
}))

it("checks metamethods with method-like self (explicit self)", util.check_type_error([[
local type Rec = record
n: integer
metamethod __index: function(self: Rec, s: string): Rec
end
local rec_mt: metatable<Rec>
rec_mt = {
__index = function(r: Rec, k: string): Rec
return { n = #k }
end,
}
local r: Rec = setmetatable({ n = 10 }, rec_mt)
print(r["hello"])
print(r[true])
]], {
{ y = 15, x = 15, msg = "argument 1: got boolean, expected string" },
}))

it("checks metamethods with method-like self (other name)", util.check_type_error([[
local type Rec = record
n: integer
metamethod __index: function(r: Rec, s: string): Rec
end
local rec_mt: metatable<Rec>
rec_mt = {
__index = function(r: Rec, k: string): Rec
return { n = #k }
end,
}
local r: Rec = setmetatable({ n = 10 }, rec_mt)
print(r["hello"])
print(r[true])
]], {
{ y = 15, x = 15, msg = "argument 1: got boolean, expected string" },
}))

end)
4 changes: 2 additions & 2 deletions spec/metamethods/index_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ describe("metamethod __index", function()
end
local rec_mt: metatable<Rec> = {
__index = function(self: Rec, s: string, n: number): string
return tostring(self.x + n) .. s
__index = function(self: Rec, s: string): string
return tostring(self.x) .. s
end
}
Expand Down
4 changes: 2 additions & 2 deletions spec/metamethods/le_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ describe("binary metamethod __le using <=", function()
it("can be used via the second argument", util.check([[
local type Rec = record
x: number
metamethod __le: function(number, Rec): Rec
metamethod __le: function(number, Rec): boolean
end
local rec_mt: metatable<Rec>
Expand Down Expand Up @@ -153,7 +153,7 @@ describe("binary metamethod __le using >=", function()
it("can be used via the second argument", util.check([[
local type Rec = record
x: number
metamethod __le: function(number, Rec): Rec
metamethod __le: function(number, Rec): boolean
end
local rec_mt: metatable<Rec>
Expand Down
4 changes: 2 additions & 2 deletions spec/metamethods/lt_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ describe("binary metamethod __lt using <", function()
it("can be used via the second argument", util.check([[
local type Rec = record
x: number
metamethod __lt: function(number, Rec): Rec
metamethod __lt: function(number, Rec): boolean
end
local rec_mt: metatable<Rec>
Expand Down Expand Up @@ -153,7 +153,7 @@ describe("binary metamethod __lt using >", function()
it("can be used via the second argument", util.check([[
local type Rec = record
x: number
metamethod __lt: function(number, Rec): Rec
metamethod __lt: function(number, Rec): boolean
end
local rec_mt: metatable<Rec>
Expand Down
96 changes: 72 additions & 24 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -229,35 +229,35 @@ do
__mode: Mode
__name: string
__tostring: function(T): string
__pairs: function<K, V>(T): (function(): (K, V))
__pairs: function<K, V>(T): function(): (K, V)
__index: any --[[FIXME: function | table | anything with an __index metamethod]]
__newindex: any --[[FIXME: function | table | anything with an __index metamethod]]
__gc: function(T)
__close: function(T)
__add: function(any, any): any
__sub: function(any, any): any
__mul: function(any, any): any
__div: function(any, any): any
__idiv: function(any, any): any
__mod: function(any, any): any
__pow: function(any, any): any
__band: function(any, any): any
__bor: function(any, any): any
__bxor: function(any, any): any
__shl: function(any, any): any
__shr: function(any, any): any
__concat: function(any, any): any
__len: function(T): any
__unm: function(T): any
__bnot: function(T): any
__eq: function(any, any): boolean
__lt: function(any, any): boolean
__le: function(any, any): boolean
__add: function<A, B, C>(A, B): C
__sub: function<A, B, C>(A, B): C
__mul: function<A, B, C>(A, B): C
__div: function<A, B, C>(A, B): C
__idiv: function<A, B, C>(A, B): C
__mod: function<A, B, C>(A, B): C
__pow: function<A, B, C>(A, B): C
__band: function<A, B, C>(A, B): C
__bor: function<A, B, C>(A, B): C
__bxor: function<A, B, C>(A, B): C
__shl: function<A, B, C>(A, B): C
__shr: function<A, B, C>(A, B): C
__concat: function<A, B, C>(A, B): C
__len: function<A>(T): A
__unm: function<A>(T): A
__bnot: function<A>(T): A
__eq: function<A, B>(A, B): boolean
__lt: function<A, B>(A, B): boolean
__le: function<A, B>(A, B): boolean
end
global record os
Expand Down Expand Up @@ -6330,7 +6330,7 @@ end
function Errors:fail_unresolved_nominals(scope, global_scope)
if global_scope and scope.pending_nominals then
for name, types in pairs(scope.pending_nominals) do
if not global_scope.pending_global_types[name] then
if not global_scope.pending_global_types[name] and name ~= "metatable" then
for _, typ in ipairs(types) do
assert(typ.x)
assert(typ.y)
Expand Down Expand Up @@ -7120,6 +7120,8 @@ do








Expand Down Expand Up @@ -7220,6 +7222,9 @@ do
function TypeChecker:find_type(names, accept_typearg)
local typ = self:find_var_type(names[1], "use_type")
if not typ then
if #names == 1 and names[1] == "metatable" then
return self:find_type({ "_metatable" })
end
return nil
end
if typ.typename == "nominal" and typ.found then
Expand Down Expand Up @@ -7891,6 +7896,27 @@ do
self:add_var(nil, def.typeargs[i].typearg, tt)
end
local ret = self:resolve_typevars_at(t, def)

if def == self.cache_std_metatable_type then
local tv = t.typevals[1]
if tv.typename == "nominal" then
local found = tv.found
if found and found.typename == "typedecl" then
local rec = found.def
if rec.fields and rec.meta_fields and ret.fields then
for fname, ftype in pairs(rec.meta_fields) do
if ret.fields[fname] then
if not self:is_a(ftype, ret.fields[fname]) then
self.errs:add(ftype, fname .. " does not follow metatable contract: got %s, expected %s", ftype, ret.fields[fname])
end
end
ret.fields[fname] = ftype
end
end
end
end
end

self:end_scope()
return ret
elseif t.typevals then
Expand Down Expand Up @@ -9440,7 +9466,12 @@ a.types[i], b.types[i]), }
e2[2] = node.e2
args.tuple[2] = orig_b
end
return self:to_structural(resolve_tuple((self:type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator

local mtdelta = metamethod.typename == "function" and metamethod.is_method and -1 or 0
local ret_call = self:type_check_function_call(node, metamethod, args, mtdelta, node, e2)
local ret_unary = resolve_tuple(ret_call)
local ret = self:to_structural(ret_unary)
return ret, meta_on_operator
else
return nil, nil
end
Expand Down Expand Up @@ -12566,6 +12597,20 @@ self:expand_type(node, values, elements) })
return true
end

local metamethod_is_method = {
["__bnot"] = true,
["__call"] = true,
["__close"] = true,
["__gc"] = true,
["__index"] = true,
["__is"] = true,
["__len"] = true,
["__newindex"] = true,
["__pairs"] = true,
["__tostring"] = true,
["__unm"] = true,
}

local visit_type
visit_type = {
cbs = {
Expand Down Expand Up @@ -12643,6 +12688,7 @@ self:expand_type(node, values, elements) })
fmacros = fmacros or {}
table.insert(fmacros, ftype)
end
ftype.is_method = metamethod_is_method[name]
end
typ.meta_fields[name] = ftype
i = i + 1
Expand Down Expand Up @@ -12874,6 +12920,8 @@ self:expand_type(node, values, elements) })
type_priorities = TypeChecker.type_priorities,
}

self.cache_std_metatable_type = env.globals["metatable"] and (env.globals["metatable"].t).def

setmetatable(self, { __index = TypeChecker })

self.feat_lax = set_feat(opts.feat_lax or env.defaults.feat_lax, false)
Expand Down
Loading

0 comments on commit fcdd861

Please sign in to comment.