diff --git a/Project.toml b/Project.toml index 44eab571f..ae7ab52de 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "1.92.5" +version = "1.92.6" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/utils.jl b/src/utils.jl index d013c9ac9..b2c8a57a9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -205,6 +205,19 @@ struct FunctionArgumentsError <: Exception f::Any end +function sig_has_vararg(sig) + isvarargtype(x) = x isa typeof(Vararg) + # unwrap unionalls + while sig isa UnionAll + sig = sig.body + end + if sig <: Tuple + return any(isvarargtype, sig.parameters) + else + error("SciMLBase is messing around with types that it doesn't understand. Please file an issue") + end +end + function Base.showerror(io::IO, e::FunctionArgumentsError) println(io, ARGUMENTS_ERROR_MESSAGE) print(io, "Offending function: ") @@ -255,8 +268,7 @@ function isinplace(f, inplace_param_number, fname = "f", iip_preferred = true; # Find if there's a `f(args...)` dispatch # If so, no error for i in 1:length(nargs) - if nargs[i] < inplace_param_number && - any(isequal(Vararg{Any}), methods(f).ms[1].sig.parameters) + if nargs[i] < inplace_param_number && sig_has_vararg(methods(f).ms[1].sig) # If varargs, assume iip return iip_preferred end @@ -274,8 +286,7 @@ function isinplace(f, inplace_param_number, fname = "f", iip_preferred = true; # Find if there's a `f(args...)` dispatch # If so, no error for i in 1:length(nargs) - if nargs[i] < inplace_param_number && - any(isequal(Vararg{Any}), methods(f).ms[1].sig.parameters) + if nargs[i] < inplace_param_number && sig_has_vararg(methods(f).ms[1].sig) # If varargs, assume iip return iip_preferred end diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index bb5414344..0544c37b0 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -463,3 +463,13 @@ optf(u) = 1.0 optf(u, p) = 1.0 OptimizationFunction(optf) OptimizationProblem(optf, 1.0) + +# Varargs +var1(u...) = 1.0 +var2(u::Vararg{Any, N}) where N = 1.0 +var3(u::Vararg{Int, N}) where N = 1.0 +var4(u::Vararg{Vector{T}, N}) where {T, N} = 1.0 +OptimizationFunction(var1) +OptimizationFunction(var2) +OptimizationFunction(var3) +OptimizationFunction(var4)