Skip to content

Commit

Permalink
Merge pull request #1197 from AayushSabharwal/as/vector-num
Browse files Browse the repository at this point in the history
feat: support passing arrays of symbolics to registered functions
  • Loading branch information
ChrisRackauckas authored Jul 30, 2024
2 parents ad852d5 + 550b1c4 commit 09d47f4
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 31 deletions.
26 changes: 13 additions & 13 deletions src/array-lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,12 @@ end
@wrapped function Base.adjoint(A::AbstractMatrix)
@syms i::Int j::Int
@arrayop (i, j) A[j, i] term = A'
end
end false

@wrapped function Base.adjoint(b::AbstractVector)
@syms i::Int
@arrayop (1, i) b[i] term = b'
end
end false

import Base: *, \

Expand Down Expand Up @@ -300,8 +300,8 @@ function _matmul(A, B)
return @arrayop (i, j) A[i, k] * B[k, j] term = (A * B)
end

@wrapped (*)(A::AbstractMatrix, B::AbstractMatrix) = _matmul(A, B)
@wrapped (*)(A::AbstractVector, B::AbstractMatrix) = _matmul(A, B)
@wrapped (*)(A::AbstractMatrix, B::AbstractMatrix) = _matmul(A, B) false
@wrapped (*)(A::AbstractVector, B::AbstractMatrix) = _matmul(A, B) false

function _matvec(A, b)
A = inner_unwrap(A)
Expand All @@ -314,19 +314,19 @@ function _matvec(A, b)
return sym_res
end
end
@wrapped (*)(A::AbstractMatrix, b::AbstractVector) = _matvec(A, b)
@wrapped (*)(A::AbstractMatrix, b::AbstractVector) = _matvec(A, b) false

# specialize `dot` to dispatch on `Symbolic{<:Number}` to eventually work for
# arrays of (possibly unwrapped) Symbolic types, see issue #831
@wrapped LinearAlgebra.dot(x::Number, y::Number) = conj(x) * y
@wrapped LinearAlgebra.dot(x::Number, y::Number) = conj(x) * y false

#################### MAP-REDUCE ################
#

@wrapped Base.map(f, x::AbstractArray) = _map(f, x)
@wrapped Base.map(f, x::AbstractArray, xs...) = _map(f, x, xs...)
@wrapped Base.map(f, x, y::AbstractArray, z...) = _map(f, x, y, z...)
@wrapped Base.map(f, x, y, z::AbstractArray, w...) = _map(f, x, y, z, w...)
@wrapped Base.map(f, x::AbstractArray) = _map(f, x) false
@wrapped Base.map(f, x::AbstractArray, xs...) = _map(f, x, xs...) false
@wrapped Base.map(f, x, y::AbstractArray, z...) = _map(f, x, y, z...) false
@wrapped Base.map(f, x, y, z::AbstractArray, w...) = _map(f, x, y, z, w...) false

function _map(f, x, xs...)
N = ndims(x)
Expand Down Expand Up @@ -368,7 +368,7 @@ end
expr,
g,
Term{Any}(_mapreduce, [f, g, x, dims, (kw...,)]))
end
end false

for (ff, opts) in [sum => (identity, +, false),
prod => (identity, *, true),
Expand All @@ -379,9 +379,9 @@ for (ff, opts) in [sum => (identity, +, false),
@eval @wrapped function (::$(typeof(ff)))(x::AbstractArray;
dims=:, init=$init)
mapreduce($f, $g, x, dims=dims, init=init)
end
end false
@eval @wrapped function (::$(typeof(ff)))(f::Function, x::AbstractArray;
dims=:, init=$init)
mapreduce(f, $g, x, dims=dims, init=init)
end
end false
end
8 changes: 4 additions & 4 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -662,12 +662,12 @@ end
@wrapped function Base.:(\)(A::AbstractMatrix, b::AbstractVecOrMat)
t = array_term(\, A, b)
setmetadata(t, ScalarizeCache, Ref{Any}(nothing))
end
end false

@wrapped function Base.inv(A::AbstractMatrix)
t = array_term(inv, A)
setmetadata(t, ScalarizeCache, Ref{Any}(nothing))
end
end false

_det(x, lp) = det(x, laplace=lp)

Expand All @@ -677,7 +677,7 @@ end

@wrapped function LinearAlgebra.det(x::AbstractMatrix; laplace=true)
Term{eltype(x)}(_det, [x, laplace])
end
end false


# A * x = b
Expand Down Expand Up @@ -754,7 +754,7 @@ function scalarize(arr)
end
end

@wrapped Base.isempty(x::AbstractArray) = shape(unwrap(x)) !== Unknown() && _iszero(length(x))
@wrapped Base.isempty(x::AbstractArray) = shape(unwrap(x)) !== Unknown() && _iszero(length(x)) false
Base.collect(x::Arr) = scalarize(x)
Base.collect(x::SymArray) = scalarize(x)
isarraysymbolic(x) = unwrap(x) isa Symbolic && SymbolicUtils.symtype(unwrap(x)) <: AbstractArray
Expand Down
36 changes: 26 additions & 10 deletions src/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ overwriting.
```
See `@register_array_symbolic` to register functions which return arrays.
"""
macro register_symbolic(expr, define_promotion = true, Ts = :([]))
macro register_symbolic(expr, define_promotion = true, Ts = :([]), wrap_arrays = true)
f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, Ts)

args′ = map((a, T) -> :($a::$T), argnames, Ts)
ret_type = isnothing(ret_type) ? Real : ret_type

fexpr = :(Symbolics.@wrapped function $f($(args′...))
args = [$(argnames...),]
unwrapped_args = map($unwrap, args)
res = if !any(x->$issym(x) || $iscall(x), unwrapped_args)
unwrapped_args = map($nested_unwrap, args)
res = if !any($is_symbolic_or_array_of_symbolic, unwrapped_args)
$f(unwrapped_args...) # partial-eval if all args are unwrapped
else
$Term{$ret_type}($f, unwrapped_args)
Expand All @@ -42,7 +42,7 @@ macro register_symbolic(expr, define_promotion = true, Ts = :([]))
else
return $wrap(res)
end
end)
end $wrap_arrays)

if define_promotion
fexpr = :($fexpr; (::$typeof($promote_symtype))(::$ftype, args...) = $ret_type)
Expand Down Expand Up @@ -80,8 +80,23 @@ function destructure_registration_expr(expr, Ts)
f, ftype, argnames, Ts, ret_type
end

nested_unwrap(x) = unwrap(x)
nested_unwrap(x::Arr) = unwrap(x)
nested_unwrap(x::AbstractArray) = unwrap.(x)

function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs = :(), define_promotion = true)
function is_symbolic_or_array_of_symbolic(x)
return issym(x) || iscall(x)
end
function is_symbolic_or_array_of_symbolic(arr::AbstractArray)
return any(is_symbolic_or_array_of_symbolic.(arr))
end

symbolic_eltype(x) = eltype(x)
symbolic_eltype(::AbstractArray{symT}) where {eT, symT <: SymbolicUtils.Symbolic{eT}} = eT
symbolic_eltype(::AbstractArray{Num}) = Real
symbolic_eltype(::AbstractArray{symT}) where {eT, symT <: Arr{eT}} = eT

function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs = :(), define_promotion = true, wrap_arrays = true)
def_assignments = MacroTools.rmlines(partial_defs).args
defs = map(def_assignments) do ex
@assert ex.head == :(=)
Expand All @@ -93,8 +108,9 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs
fexpr = quote
@wrapped function $f($(args′...))
args = [$(argnames...),]
unwrapped_args = map($unwrap, args)
res = if !any(x->$issym(x) || $iscall(x), unwrapped_args)
unwrapped_args = map($nested_unwrap, args)
eltype = $symbolic_eltype
res = if !any($is_symbolic_or_array_of_symbolic, unwrapped_args)
$f(unwrapped_args...) # partial-eval if all args are unwrapped
elseif $ret_type == nothing || ($ret_type <: AbstractArray)
$array_term($(Expr(:parameters, [Expr(:kw, k, v) for (k, v) in defs]...)), $f, unwrapped_args...)
Expand All @@ -107,7 +123,7 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs
else
return $wrap(res)
end
end
end $wrap_arrays
end |> esc

if define_promotion
Expand Down Expand Up @@ -177,7 +193,7 @@ overloads for one function, all the rest of the registers must set
`define_promotion` to `false` except for the first one, to avoid method
overwriting.
"""
macro register_array_symbolic(expr, block, define_promotion = true)
macro register_array_symbolic(expr, block, define_promotion = true, wrap_arrays = true)
f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, :([]))
register_array_symbolic(f, ftype, argnames, Ts, ret_type, block, define_promotion)
register_array_symbolic(f, ftype, argnames, Ts, ret_type, block, define_promotion, wrap_arrays)
end
26 changes: 22 additions & 4 deletions src/wrapper-types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function wraps_type end
has_symwrapper(::Type) = false
is_wrapper_type(::Type) = false

function wrap_func_expr(mod, expr)
function wrap_func_expr(mod, expr, wrap_arrays = true)
@assert expr.head == :function || (expr.head == :(=) &&
expr.args[1] isa Expr &&
expr.args[1].head == :call)
Expand Down Expand Up @@ -109,8 +109,26 @@ function wrap_func_expr(mod, expr)
# expected to be defined outside Symbolics
if arg isa Expr && arg.head == :(::)
T = Base.eval(mod, arg.args[2])
has_symwrapper(T) ? (T, :(Symbolics.SymbolicUtils.Symbolic{<:$T}), wrapper_type(T)) :
Ts = has_symwrapper(T) ? (T, :(Symbolics.SymbolicUtils.Symbolic{<:$T}), wrapper_type(T)) :
(T,:(Symbolics.SymbolicUtils.Symbolic{<:$T}))
if T <: AbstractArray && wrap_arrays
eT = eltype(T)
if eT == Any
eT = Real
end
_arr_type_fn = if hasmethod(ndims, Tuple{Type{T}})
(elT) -> :(AbstractArray{T, $(ndims(T))} where {T <: $elT})
else
(elT) -> :(AbstractArray{T} where {T <: $elT})
end
if has_symwrapper(eT)
Ts = (Ts..., # _arr_type_fn(:(Symbolics.SymbolicUtils.Symbolic{<:$eT})),
_arr_type_fn(wrapper_type(eT)))
# else
# Ts = (Ts..., _arr_type_fn(:(Symbolics.SymbolicUtils.Symbolic{<:$eT})))
end
end
Ts
elseif arg isa Expr && arg.head == :(...)
Ts = type_options(arg.args[1])
map(x->Vararg{x},Ts)
Expand Down Expand Up @@ -153,6 +171,6 @@ function wrap_func_expr(mod, expr)
end |> esc
end

macro wrapped(expr)
wrap_func_expr(__module__, expr)
macro wrapped(expr, wrap_arrays = true)
wrap_func_expr(__module__, expr, wrap_arrays)
end
20 changes: 20 additions & 0 deletions test/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ let
@test eltype(gg) == Real
@test symtype(unwrap(gg)) == SymMatrix{Real, 2}
@test promote_symtype(ggg, symtype(unwrap(x))) == Any # no promote_symtype defined

gg = ggg([a, 2a])
@test ndims(gg) == 2
@test size(gg) == (4, 4)
@test eltype(gg) == Real
@test symtype(unwrap(gg)) == SymMatrix{Real, 2}
@test promote_symtype(ggg, Vector{symtype(typeof(a))}) == Any
end
let
# redefine with promote_symtype
Expand Down Expand Up @@ -64,6 +71,12 @@ hh = ccwa(gg, x)
@test eltype(hh) == Real
@test isequal(arguments(unwrap(hh)), unwrap.([gg, x]))

_args = [[a 2a; 4a 6a; 3a 5a], [4a, 6a]]
hh = ccwa(_args...)
@test size(hh) == (3, 2, 10)
@test eltype(hh) == Real
@test isequal(arguments(unwrap(hh)), unwrap.(_args))

@test all(t->getsource(t)[1] === :variables, many_vars)
@test getdefaultval(t) == 0
@test getdefaultval(a) == 1
Expand Down Expand Up @@ -218,3 +231,10 @@ yyy = yy(t)
@test !isequal(yyy, y)
@variables y(..)
@test isequal(yyy, y(t))

spam(x) = 2x
@register_symbolic spam(x::AbstractArray)

sym = spam([a, 2a])
@test sym isa Num
@test unwrap(sym) isa BasicSymbolic{Real}

0 comments on commit 09d47f4

Please sign in to comment.