diff --git a/ext/SymbolicsLuxExt.jl b/ext/SymbolicsLuxExt.jl index b4dfcab51..60e748950 100644 --- a/ext/SymbolicsLuxExt.jl +++ b/ext/SymbolicsLuxExt.jl @@ -5,13 +5,15 @@ using Symbolics using Lux.LuxCore using Symbolics.SymbolicUtils -function Lux.NilSizePropagation.recursively_nillify(x::SymbolicUtils.BasicSymbolic{<:Vector{<:Real}}) - Lux.NilSizePropagation.recursively_nillify(Symbolics.wrap(x)) +@static if isdefined(Lux.NilSizePropagation, :recursively_nillify) + function Lux.NilSizePropagation.recursively_nillify(x::SymbolicUtils.BasicSymbolic{<:Vector{<:Real}}) + Lux.NilSizePropagation.recursively_nillify(Symbolics.wrap(x)) + end end @register_array_symbolic LuxCore.stateless_apply( model::LuxCore.AbstractLuxLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin - size = LuxCore.outputsize(model, x, LuxCore.Random.default_rng()) + size = LuxCore.outputsize(model, Symbolics.wrap(x), LuxCore.Random.default_rng()) eltype = Real end diff --git a/src/utils.jl b/src/utils.jl index 5fffae8a5..17d91a971 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -62,7 +62,7 @@ function is_singleton(e) op = operation(e) op === getindex && return true iscall(op) && return is_singleton(op) # recurse to reach getindex for array element variables - return issym(op) + return issym(op) && !hasmetadata(e, CallWithParent) else return issym(e) end @@ -76,6 +76,7 @@ function get_variables!(vars, e::Symbolic, varlist=nothing) push!(vars, e) end else + get_variables!(vars, operation(e), varlist) foreach(x -> get_variables!(vars, x, varlist), arguments(e)) end return (vars isa AbstractVector) ? unique!(vars) : vars diff --git a/test/utils.jl b/test/utils.jl index 727c53366..709024da5 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -19,6 +19,14 @@ using Symbolics: symbolic_to_float, var_from_nested_derivative, unwrap sorted_vars2 = Symbolics.get_variables(ex2; sort = true) @test isequal(sorted_vars2, [x, y]) + + @variables c(..) + ex3 = c(x) + c(t) - c(c(t) + y) + vars3 = Symbolics.get_variables(ex3) + @test length(vars3) == 4 + + sorted_vars3 = Symbolics.get_variables(ex3; sort = true) + @test isequal(sorted_vars3, [c.f, t, x, y]) end @testset "symbolic_to_float" begin