diff --git a/spec/cli/feat_spec.lua b/spec/cli/feat_spec.lua index 286033387..7741cc8f3 100644 --- a/spec/cli/feat_spec.lua +++ b/spec/cli/feat_spec.lua @@ -43,8 +43,8 @@ local test_cases = { status = 1, match = { "2 errors:", - ":9:22: wrong number of arguments (given 3, expects 2)", - ":19:22: wrong number of arguments (given 3, expects at least 1 and at most 2)", + ":9:22: wrong number of arguments (given 3, expects at most 2)", + ":19:22: wrong number of arguments (given 3, expects at most 2)", } } } diff --git a/spec/pragma/arity_spec.lua b/spec/pragma/arity_spec.lua new file mode 100644 index 000000000..02cf8ccc7 --- /dev/null +++ b/spec/pragma/arity_spec.lua @@ -0,0 +1,237 @@ +local util = require("spec.util") + +describe("pragma arity", function() + describe("on", function() + it("rejects function calls with missing arguments", util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + ]], { + { msg = "wrong number of arguments (given 1, expects 2)" } + })) + + it("accepts optional arguments", util.check([[ + --#pragma arity on + + local function f(x: integer, y?: integer) + print(x + (y or 20)) + end + + print(f(10)) + ]])) + end) + + describe("off", function() + it("accepts function calls with missing arguments", util.check([[ + --#pragma arity off + + local function f(x: integer, y: integer) + print(x + (y or 20)) + end + + print(f(10)) + ]])) + + it("ignores optional argument annotations", util.check([[ + --#pragma arity off + + local function f(x: integer, y?: integer) + print(x + y) + end + + print(f(10)) + ]])) + end) + + describe("no propagation from required module upwards:", function() + it("on then off, with error in 'on'", function() + util.mock_io(finally, { + ["r.tl"] = [[ + --#pragma arity off + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + local r = require("r") + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + ]], { + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 15, msg = "wrong number of arguments (given 2, expects 4)" }, + })() + end) + + it("on then on, with errors in both", function() + util.mock_io(finally, { + ["r.tl"] = [[ + --#pragma arity on + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + local r = require("r") + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + ]], { + { filename = "r.tl", y = 5, msg = "wrong number of arguments (given 1, expects 3)" }, + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 15, msg = "wrong number of arguments (given 2, expects 4)" }, + })() + end) + + it("off then on, with error in 'on'", function() + util.mock_io(finally, { + ["r.tl"] = [[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + ]] + }) + util.check_type_error([[ + --#pragma arity off + + local r = require("r") + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + ]], { + { y = 7, filename = "r.tl", msg = "wrong number of arguments (given 1, expects 2)" } + })() + end) + end) + + describe("does propagate downwards into required module:", function() + it("can trigger errors in required modules", function() + util.mock_io(finally, { + ["r.tl"] = [[ + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + + return { + f = f + } + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + local r = require("r") + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + + r.f(10) + ]], { + { filename = "r.tl", y = 4, msg = "wrong number of arguments (given 1, expects 3)" }, + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 15, msg = "wrong number of arguments (given 2, expects 4)" }, + { filename = "foo.tl", y = 17, msg = "wrong number of arguments (given 1, expects 3)" }, + })() + end) + + it("can be used to load modules with different settings", function() + util.mock_io(finally, { + ["r.tl"] = [[ + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + + return { + f = f + } + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + --#pragma arity off + local r = require("r") + --#pragma arity on + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + + r.f(10) -- no error here! + ]], { + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 17, msg = "wrong number of arguments (given 2, expects 4)" }, + })() + end) + end) + + describe("invalid", function() + it("rejects invalid value", util.check_type_error([[ + --#pragma arity invalid_value + + local function f(x: integer, y?: integer) + print(x + y) + end + + print(f(10)) + ]], { + { y = 1, msg = "invalid value for pragma 'arity': invalid_value" } + })) + end) +end) diff --git a/tl.lua b/tl.lua index 846bb89e8..9fa2f23e8 100644 --- a/tl.lua +++ b/tl.lua @@ -6667,21 +6667,33 @@ function tl.search_module(module_name, search_dtl) return nil, nil, tried end -local function require_module(w, module_name, feat_lax, env) +local function require_module(w, module_name, opts, env) local mod = env.modules[module_name] if mod then return mod, env.module_filenames[module_name] end local found, fd = tl.search_module(module_name, true) - if found and (feat_lax or found:match("tl$")) then + if found and (opts.feat_lax == "on" or found:match("tl$")) then env.module_filenames[module_name] = found env.modules[module_name] = a_type(w, "typedecl", { def = a_type(w, "circular_require", {}) }) + local save_defaults = env.defaults + local defaults = { + feat_lax = opts.feat_lax or save_defaults.feat_lax, + feat_arity = opts.feat_arity or save_defaults.feat_arity, + gen_compat = opts.gen_compat or save_defaults.gen_compat, + gen_target = opts.gen_target or save_defaults.gen_target, + run_internal_compiler_checks = opts.run_internal_compiler_checks or save_defaults.run_internal_compiler_checks, + } + env.defaults = defaults + local found_result, err = tl.process(found, env, fd) assert(found_result, err) + env.defaults = save_defaults + env.modules[module_name] = found_result.type return found_result.type, found @@ -6884,7 +6896,11 @@ tl.new_env = function(opts) if opts.predefined_modules then for _, name in ipairs(opts.predefined_modules) do - local module_type = require_module(w, name, env.defaults.feat_lax == "on", env) + local tc_opts = { + feat_lax = env.defaults.feat_lax, + feat_arity = env.defaults.feat_arity, + } + local module_type = require_module(w, name, tc_opts, env) if module_type.typename == "invalid" then return nil, string.format("Error: could not predefine module '%s'", name) @@ -7157,9 +7173,15 @@ do local function show_arity(f) local nfargs = #f.args.tuple - return f.min_arity < nfargs and - "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) or - tostring(nfargs or 0) + if f.min_arity < nfargs then + if f.min_arity > 0 then + return "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) + else + return (f.args.is_va and "any number" or "at most " .. nfargs) + end + else + return tostring(nfargs or 0) + end end local function resolve_typedecl(t) @@ -8729,7 +8751,11 @@ a.types[i], b.types[i]), } if self.feat_lax and is_unknown(func) then local unk = func - func = a_function(func, { min_arity = 0, args = a_vararg(func, { unk }), rets = a_vararg(func, { unk }) }) + func = a_function(func, { + min_arity = 0, + args = a_vararg(func, { unk }), + rets = a_vararg(func, { unk }), + }) end func = self:to_structural(func) @@ -9367,9 +9393,9 @@ a.types[i], b.types[i]), } end end - function TypeChecker:add_function_definition_for_recursion(node, fnargs) + function TypeChecker:add_function_definition_for_recursion(node, fnargs, feat_arity) self:add_var(nil, node.name.tk, a_function(node, { - min_arity = node.min_arity, + min_arity = feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = fnargs, rets = self.get_rets(node.rets), @@ -10076,7 +10102,7 @@ a.types[i], b.types[i]), } local arg2 = node.e2[2] local msgh = table.remove(b.tuple, 1) local msgh_type = a_function(arg2, { - min_arity = 1, + min_arity = self.feat_arity and 1 or 0, args = a_type(arg2, "tuple", { tuple = { a_type(arg2, "any", {}) } }), rets = a_type(arg2, "tuple", { tuple = {} }), }) @@ -10164,7 +10190,11 @@ a.types[i], b.types[i]), } end local module_name = assert(node.e2[1].conststr) - local t, module_filename = require_module(node, module_name, self.feat_lax, self.env) + local tc_opts = { + feat_lax = self.feat_lax and "on" or "off", + feat_arity = self.feat_arity and "on" or "off", + } + local t, module_filename = require_module(node, module_name, tc_opts, self.env) if t.typename == "invalid" then if not module_filename then @@ -11279,7 +11309,7 @@ self:expand_type(node, values, elements) }) assert(args.typename == "tuple") self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self, node, children) local args = children[2] @@ -11290,7 +11320,7 @@ self:expand_type(node, values, elements) }) self:end_function_scope(node) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11319,7 +11349,7 @@ self:expand_type(node, values, elements) }) self:check_macroexp_arg_use(node.macrodef) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.macrodef.min_arity, + min_arity = self.feat_arity and node.macrodef.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11352,7 +11382,7 @@ self:expand_type(node, values, elements) }) assert(args.typename == "tuple") self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self, node, children) local args = children[2] @@ -11366,7 +11396,7 @@ self:expand_type(node, values, elements) }) end self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11428,7 +11458,7 @@ self:expand_type(node, values, elements) }) end local fn_type = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, is_method = node.is_method, typeargs = node.typeargs, args = args, @@ -11502,7 +11532,7 @@ self:expand_type(node, values, elements) }) self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11528,7 +11558,7 @@ self:expand_type(node, values, elements) }) self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = rets, @@ -12011,7 +12041,18 @@ self:expand_type(node, values, elements) }) end, }, ["pragma"] = { - after = function(_self, _node, _children) + after = function(self, node, _children) + if node.pkey == "arity" then + if node.pvalue == "on" then + self.feat_arity = true + elseif node.pvalue == "off" then + self.feat_arity = false + else + return self.errs:invalid_at(node, "invalid value for pragma 'arity': " .. node.pvalue) + end + else + return self.errs:invalid_at(node, "invalid pragma: " .. node.pkey) + end return NONE end, }, diff --git a/tl.tl b/tl.tl index aaf3ce513..f8fd52bd4 100644 --- a/tl.tl +++ b/tl.tl @@ -6667,21 +6667,33 @@ function tl.search_module(module_name: string, search_dtl: boolean): string, FIL return nil, nil, tried end -local function require_module(w: Where, module_name: string, feat_lax: boolean, env: Env): Type, string +local function require_module(w: Where, module_name: string, opts: TypeCheckOptions, env: Env): Type, string local mod = env.modules[module_name] if mod then return mod, env.module_filenames[module_name] end local found, fd = tl.search_module(module_name, true) - if found and (feat_lax or found:match("tl$") as boolean) then + if found and (opts.feat_lax == "on" or found:match("tl$") as boolean) then env.module_filenames[module_name] = found env.modules[module_name] = a_typedecl(w, a_type(w, "circular_require", {})) + local save_defaults = env.defaults + local defaults : TypeCheckOptions = { + feat_lax = opts.feat_lax or save_defaults.feat_lax, + feat_arity = opts.feat_arity or save_defaults.feat_arity, + gen_compat = opts.gen_compat or save_defaults.gen_compat, + gen_target = opts.gen_target or save_defaults.gen_target, + run_internal_compiler_checks = opts.run_internal_compiler_checks or save_defaults.run_internal_compiler_checks, + } + env.defaults = defaults + local found_result, err: Result, string = tl.process(found, env, fd) assert(found_result, err) + env.defaults = save_defaults + env.modules[module_name] = found_result.type return found_result.type, found @@ -6884,7 +6896,11 @@ tl.new_env = function(opts?: EnvOptions): Env, string if opts.predefined_modules then for _, name in ipairs(opts.predefined_modules) do - local module_type = require_module(w, name, env.defaults.feat_lax == "on", env) + local tc_opts = { + feat_lax = env.defaults.feat_lax, + feat_arity = env.defaults.feat_arity, + } + local module_type = require_module(w, name, tc_opts, env) if module_type is InvalidType then return nil, string.format("Error: could not predefine module '%s'", name) @@ -7157,9 +7173,15 @@ do local function show_arity(f: FunctionType): string local nfargs = #f.args.tuple - return f.min_arity < nfargs - and "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) - or tostring(nfargs or 0) + if f.min_arity < nfargs then + if f.min_arity > 0 then + return "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) + else + return (f.args.is_va and "any number" or "at most " .. nfargs) + end + else + return tostring(nfargs or 0) + end end local function resolve_typedecl(t: Type): Type @@ -8729,7 +8751,11 @@ do -- resolve unknown in lax mode, produce a general unknown function if self.feat_lax and is_unknown(func) then local unk = func - func = a_function(func, { min_arity = 0, args = a_vararg(func, { unk }), rets = a_vararg(func, { unk }) }) + func = a_function(func, { + min_arity = 0, + args = a_vararg(func, { unk }), + rets = a_vararg(func, { unk }) + }) end -- unwrap if tuple, resolve if nominal func = self:to_structural(func) @@ -9367,9 +9393,9 @@ do end end - function TypeChecker:add_function_definition_for_recursion(node: Node, fnargs: TupleType) + function TypeChecker:add_function_definition_for_recursion(node: Node, fnargs: TupleType, feat_arity: boolean) self:add_var(nil, node.name.tk, a_function(node, { - min_arity = node.min_arity, + min_arity = feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = fnargs, rets = self.get_rets(node.rets), @@ -10076,7 +10102,7 @@ do local arg2 = node.e2[2] local msgh = table.remove(b.tuple, 1) local msgh_type = a_function(arg2, { - min_arity = 1, + min_arity = self.feat_arity and 1 or 0, args = a_tuple(arg2, { a_type(arg2, "any", {}) }), rets = a_tuple(arg2, {}) }) @@ -10164,7 +10190,11 @@ do end local module_name = assert(node.e2[1].conststr) - local t, module_filename = require_module(node, module_name, self.feat_lax, self.env) + local tc_opts: TypeCheckOptions = { + feat_lax = self.feat_lax and "on" or "off", + feat_arity = self.feat_arity and "on" or "off", + } + local t, module_filename = require_module(node, module_name, tc_opts, self.env) if t.typename == "invalid" then if not module_filename then @@ -11279,7 +11309,7 @@ do assert(args is TupleType) self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] @@ -11290,7 +11320,7 @@ do self:end_function_scope(node) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11319,7 +11349,7 @@ do self:check_macroexp_arg_use(node.macrodef) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.macrodef.min_arity, + min_arity = self.feat_arity and node.macrodef.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11352,7 +11382,7 @@ do assert(args is TupleType) self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] @@ -11366,7 +11396,7 @@ do end self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11428,7 +11458,7 @@ do end local fn_type = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, is_method = node.is_method, typeargs = node.typeargs, args = args, @@ -11502,7 +11532,7 @@ do self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11528,7 +11558,7 @@ do self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = rets, @@ -12011,7 +12041,18 @@ do end, }, ["pragma"] = { - after = function(_self: TypeChecker, _node: Node, _children: {Type}): Type + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + if node.pkey == "arity" then + if node.pvalue == "on" then + self.feat_arity = true + elseif node.pvalue == "off" then + self.feat_arity = false + else + return self.errs:invalid_at(node, "invalid value for pragma 'arity': " .. node.pvalue) + end + else + return self.errs:invalid_at(node, "invalid pragma: " .. node.pkey) + end return NONE end, },