Skip to content

Commit

Permalink
Merge pull request #1421 from AayushSabharwal/as/build-function-simil…
Browse files Browse the repository at this point in the history
…arto

feat: add `similarto` keyword to `build_function`
  • Loading branch information
ChrisRackauckas authored Jan 29, 2025
2 parents 4d6af94 + 7152f58 commit 046f3b4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ Special Keyword Arguments:
- `force_SA`: Forces the output of the OOP version to be a StaticArray.
Defaults to `false`, and outputs a static array when the first argument
is a static array.
- `similarto`: An `AbstractArray` subtype which controls the type of the
returned array for the OOP version. If provided, it ignores the value of `force_SA`.
- `skipzeros`: Whether to skip filling zeros in the in-place version if the
filling function is 0.
- `fillzeros`: Whether to perform `fill(out,0)` before the calculations to ensure
Expand All @@ -288,6 +290,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
outputidxs=nothing,
skipzeros = false,
force_SA = false,
similarto = nothing,
wrap_code = (nothing, nothing),
fillzeros = skipzeros && !(rhss isa SparseMatrixCSC),
states = LazyState(),
Expand All @@ -301,7 +304,9 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
dargs = map((x) -> destructure_arg(x[2], !checkbounds,
Symbol("ˍ₋arg$(x[1])")), enumerate([args...]))
i = findfirst(x->x isa DestructuredArgs, dargs)
similarto = force_SA ? SArray : i === nothing ? Array : dargs[i].name
if similarto === nothing
similarto = force_SA ? SArray : i === nothing ? Array : dargs[i].name
end

oop, iip = iip_config
oop_body = if oop
Expand Down
9 changes: 9 additions & 0 deletions test/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,12 @@ end
T = value(x .^ 2)
@test_nowarn toexpr(T, NameState())
end

@testset "`similarto` keyword argument" begin
@variables x[1:2]
T = collect(value(x .^ 2))
fn = build_function(T, collect(x); expression = false)[1]
@test_throws MethodError fn((1.0, 2.0))
fn = build_function(T, collect(x); similarto = Array, expression = false)[1]
@test fn((1.0, 2.0)) [1.0, 4.0]
end

0 comments on commit 046f3b4

Please sign in to comment.