Skip to content

Commit

Permalink
Merge pull request #1361 from AayushSabharwal/as/get-variables-callable
Browse files Browse the repository at this point in the history
fix: fix searching inside callable symbolics in `get_variables!`
  • Loading branch information
ChrisRackauckas authored Nov 15, 2024
2 parents 0dd679d + 0514859 commit 3485fbb
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
8 changes: 5 additions & 3 deletions ext/SymbolicsLuxExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ using Symbolics
using Lux.LuxCore
using Symbolics.SymbolicUtils

function Lux.NilSizePropagation.recursively_nillify(x::SymbolicUtils.BasicSymbolic{<:Vector{<:Real}})
Lux.NilSizePropagation.recursively_nillify(Symbolics.wrap(x))
@static if isdefined(Lux.NilSizePropagation, :recursively_nillify)
function Lux.NilSizePropagation.recursively_nillify(x::SymbolicUtils.BasicSymbolic{<:Vector{<:Real}})
Lux.NilSizePropagation.recursively_nillify(Symbolics.wrap(x))
end
end

@register_array_symbolic LuxCore.stateless_apply(
model::LuxCore.AbstractLuxLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin
size = LuxCore.outputsize(model, x, LuxCore.Random.default_rng())
size = LuxCore.outputsize(model, Symbolics.wrap(x), LuxCore.Random.default_rng())
eltype = Real
end

Expand Down
3 changes: 2 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function is_singleton(e)
op = operation(e)
op === getindex && return true
iscall(op) && return is_singleton(op) # recurse to reach getindex for array element variables
return issym(op)
return issym(op) && !hasmetadata(e, CallWithParent)
else
return issym(e)
end
Expand All @@ -76,6 +76,7 @@ function get_variables!(vars, e::Symbolic, varlist=nothing)
push!(vars, e)
end
else
get_variables!(vars, operation(e), varlist)
foreach(x -> get_variables!(vars, x, varlist), arguments(e))
end
return (vars isa AbstractVector) ? unique!(vars) : vars
Expand Down
8 changes: 8 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ using Symbolics: symbolic_to_float, var_from_nested_derivative, unwrap

sorted_vars2 = Symbolics.get_variables(ex2; sort = true)
@test isequal(sorted_vars2, [x, y])

@variables c(..)
ex3 = c(x) + c(t) - c(c(t) + y)
vars3 = Symbolics.get_variables(ex3)
@test length(vars3) == 4

sorted_vars3 = Symbolics.get_variables(ex3; sort = true)
@test isequal(sorted_vars3, [c.f, t, x, y])
end

@testset "symbolic_to_float" begin
Expand Down

0 comments on commit 3485fbb

Please sign in to comment.