Skip to content

Commit

Permalink
pragma: arity on/off
Browse files Browse the repository at this point in the history
  • Loading branch information
hishamhm committed Jun 13, 2024
1 parent 910375e commit 59d1f5b
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 2 deletions.
157 changes: 157 additions & 0 deletions spec/pragma/arity_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
local util = require("spec.util")

describe("pragma arity", function()
describe("on", function()
it("rejects function calls with missing arguments", util.check_type_error([[
--! arity on
local function f(x: integer, y: integer)
print(x + y)
end
print(f(10))
]], {
{ msg = "wrong number of arguments (given 1, expects 2)" }
}))

it("accepts optional arguments", util.check([[
--! arity on
local function f(x: integer, y?: integer)
print(x + (y or 20))
end
print(f(10))
]]))
end)

describe("off", function()
it("accepts function calls with missing arguments", util.check([[
--! arity off
local function f(x: integer, y: integer)
print(x + (y or 20))
end
print(f(10))
]]))

it("ignores optional argument annotations", util.check([[
--! arity off
local function f(x: integer, y?: integer)
print(x + y)
end
print(f(10))
]]))
end)

describe("applies locally to a file", function()
it("on then off, with error in 'on'", function()
util.mock_io(finally, {
["r.tl"] = [[
--! arity off
local function f(x: integer, y: integer, z: integer)
print(x + (y or 20))
end
print(f(10))
]]
})
util.check_type_error([[
--! arity on
local function f(x: integer, y: integer)
print(x + y)
end
print(f(10))
local r = require("r")
local function g(x: integer, y: integer, z: integer, w: integer)
print(x + y)
end
print(g(10, 20))
]], {
{ filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" },
{ filename = "foo.tl", y = 15, msg = "wrong number of arguments (given 2, expects 4)" },
})()
end)

it("on then on, with errors in both", function()
util.mock_io(finally, {
["r.tl"] = [[
--! arity on
local function f(x: integer, y: integer, z: integer)
print(x + (y or 20))
end
print(f(10))
]]
})
util.check_type_error([[
--! arity on
local function f(x: integer, y: integer)
print(x + y)
end
print(f(10))
local r = require("r")
local function g(x: integer, y: integer, z: integer, w: integer)
print(x + y)
end
print(g(10, 20))
]], {
{ filename = "r.tl", y = 5, msg = "wrong number of arguments (given 1, expects 3)" },
{ filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" },
{ filename = "foo.tl", y = 15, msg = "wrong number of arguments (given 2, expects 4)" },
})()
end)

it("off then on, with error in 'on'", function()
util.mock_io(finally, {
["r.tl"] = [[
--! arity on
local function f(x: integer, y: integer)
print(x + y)
end
print(f(10))
]]
})
util.check_type_error([[
--! arity off
local r = require("r")
local function f(x: integer, y: integer)
print(x + y)
end
print(f(10))
]], {
{ y = 7, filename = "r.tl", msg = "wrong number of arguments (given 1, expects 2)" }
})()
end)
end)

describe("invalid", function()
it("rejects invalid value", util.check_type_error([[
--! arity invalid_value
local function f(x: integer, y?: integer)
print(x + y)
end
print(f(10))
]], {
{ y = 1, msg = "invalid value for pragma 'arity': invalid_value" }
}))
end)
end)
9 changes: 9 additions & 0 deletions spec/pragma/invalid_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
local util = require("spec.util")

describe("invalid pragma", function()
it("rejects invalid pragma", util.check_type_error([[
--! invalid_pragma on
]], {
{ y = 1, msg = "invalid pragma: invalid_pragma" }
}))
end)
2 changes: 0 additions & 2 deletions spec/util.lua
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,6 @@ local function check_type_error(lax, code, type_errors, gen_target)
end
local result = tl.type_check(ast, "foo.tl", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat })
local result_type_errors = combine_result(result, "type_errors")

batch_compare(batch, "type errors", type_errors, result_type_errors)
batch:assert()
end
Expand Down Expand Up @@ -489,7 +488,6 @@ end
function util.check_type_error(code, type_errors, gen_target)
assert(type(code) == "string")
assert(type(type_errors) == "table")

return check_type_error(false, code, type_errors, gen_target)
end

Expand Down
44 changes: 44 additions & 0 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,7 @@ end






do
Expand Down Expand Up @@ -1259,6 +1260,9 @@ do
elseif state == "got --" then
if c == "[" then
state = "got --["
elseif c == "!" then
end_token("pragma", "--!")
state = "any"
else
fwd = false
state = "comment short"
Expand Down Expand Up @@ -1890,6 +1894,7 @@ end






local TruthyFact = {}
Expand Down Expand Up @@ -2070,6 +2075,10 @@ local Node = {ExpectedContext = {}, }










Expand Down Expand Up @@ -4123,7 +4132,23 @@ do
return parse_function(ps, i, "record")
end

local function parse_pragma(ps, i)
i = i + 1
local pragma = new_node(ps, i, "pragma")
local pk, pv
i, pk = verify_kind(ps, i, "identifier")
if pk then
pragma.pkey = pk.tk
end
i, pv = verify_kind(ps, i, "identifier")
if pv then
pragma.pvalue = pv.tk
end
return i, pragma
end

local parse_statement_fns = {
["--!"] = parse_pragma,
["::"] = parse_label,
["do"] = parse_do,
["if"] = parse_if,
Expand Down Expand Up @@ -4489,6 +4514,7 @@ local no_recurse_node = {
["break"] = true,
["label"] = true,
["number"] = true,
["pragma"] = true,
["string"] = true,
["boolean"] = true,
["integer"] = true,
Expand Down Expand Up @@ -5385,6 +5411,8 @@ function tl.pretty_print_ast(ast, gen_target, mode)
return out
end,
},
["pragma"] = {},

}

local visit_type = {}
Expand Down Expand Up @@ -11930,6 +11958,22 @@ self:expand_type(node, values, elements) })
return node.newtype
end,
},
["pragma"] = {
after = function(self, node, _children)
if node.pkey == "arity" then
if node.pvalue == "on" then
self.feat_arity = true
elseif node.pvalue == "off" then
self.feat_arity = false
else
return self.errs:invalid_at(node, "invalid value for pragma 'arity': " .. node.pvalue)
end
else
return self.errs:invalid_at(node, "invalid pragma: " .. node.pkey)
end
return NONE
end,
},
["error_node"] = {
after = function(_self, node, _children)
return a_type(node, "invalid", {})
Expand Down
44 changes: 44 additions & 0 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,7 @@ local enum TokenKind
"identifier"
"number"
"integer"
"pragma"
"$ERR unfinished_comment$"
"$ERR invalid_string$"
"$ERR invalid_number$"
Expand Down Expand Up @@ -1259,6 +1260,9 @@ do
elseif state == "got --" then
if c == "[" then
state = "got --["
elseif c == "!" then
end_token("pragma", "--!")
state = "any"
else
fwd = false
state = "comment short"
Expand Down Expand Up @@ -1872,6 +1876,7 @@ local enum NodeKind
"macroexp"
"local_macroexp"
"interface"
"pragma"
"error_node"
end

Expand Down Expand Up @@ -2070,6 +2075,10 @@ local record Node
itemtype: Type
decltuple: TupleType

-- pragma
pkey: string
pvalue: string

opt: boolean

debug_type: Type
Expand Down Expand Up @@ -4123,7 +4132,23 @@ local function parse_record_function(ps: ParseState, i: integer): integer, Node
return parse_function(ps, i, "record")
end

local function parse_pragma(ps: ParseState, i: integer): integer, Node
i = i + 1 -- skip "--!"
local pragma = new_node(ps, i, "pragma")
local pk, pv: Node, Node
i, pk = verify_kind(ps, i, "identifier")
if pk then
pragma.pkey = pk.tk
end
i, pv = verify_kind(ps, i, "identifier")
if pv then
pragma.pvalue = pv.tk
end
return i, pragma
end

local parse_statement_fns: {string : function(ParseState, integer):(integer, Node)} = {
["--!"] = parse_pragma,
["::"] = parse_label,
["do"] = parse_do,
["if"] = parse_if,
Expand Down Expand Up @@ -4489,6 +4514,7 @@ local no_recurse_node: {NodeKind : boolean} = {
["break"] = true,
["label"] = true,
["number"] = true,
["pragma"] = true,
["string"] = true,
["boolean"] = true,
["integer"] = true,
Expand Down Expand Up @@ -5385,6 +5411,8 @@ function tl.pretty_print_ast(ast: Node, gen_target: GenTarget, mode: boolean | P
return out
end,
},
["pragma"] = {
},
}

local visit_type: Visitor<nil, TypeName, Type, Output> = {}
Expand Down Expand Up @@ -11930,6 +11958,22 @@ do
return node.newtype
end,
},
["pragma"] = {
after = function(self: TypeChecker, node: Node, _children: {Type}): Type
if node.pkey == "arity" then
if node.pvalue == "on" then
self.feat_arity = true
elseif node.pvalue == "off" then
self.feat_arity = false
else
return self.errs:invalid_at(node, "invalid value for pragma 'arity': " .. node.pvalue)
end
else
return self.errs:invalid_at(node, "invalid pragma: " .. node.pkey)
end
return NONE
end,
},
["error_node"] = {
after = function(_self: TypeChecker, node: Node, _children: {Type}): Type
return an_invalid(node)
Expand Down

0 comments on commit 59d1f5b

Please sign in to comment.