Skip to content

Commit

Permalink
Merge pull request #458 from JuliaSymbolics/s/fix-broadcast-single
Browse files Browse the repository at this point in the history
fix extruded dim in broadcast when all args have size 1
  • Loading branch information
shashi authored Nov 30, 2021
2 parents 745cbe8 + da095d5 commit 911137e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
12 changes: 11 additions & 1 deletion src/array-lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,19 @@ function Broadcast.materialize(bc::Broadcast.Broadcasted{SymBroadcast})
ndim = mapfoldl(ndims, max, bc.args, init=0)
subscripts = makesubscripts(ndim)

onedim_count = mapreduce(+, bc.args) do x
if ndims(x) != 0
map(i-> isonedim(x, i) ? 1 : 0, 1:ndim)
else
map(i-> 1, 1:ndim)
end
end

extruded = map(x->x < length(bc.args), onedim_count)

expr_args′ = map(bc.args) do x
if ndims(x) != 0
subs = map(i-> isonedim(x, i) ?
subs = map(i-> extruded[i] && isonedim(x, i) ?
1 : subscripts[i], 1:ndims(x))
x[subs...]
else
Expand Down
2 changes: 1 addition & 1 deletion src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ function scalarize(arr::ArrayOp, idx)
contracted = setdiff(iidx, arr.output_idx)

dict = Dict(oi => (unwrap(i) isa Symbolic ? unwrap(i) : axs[oi][i])
for (oi, i) in zip(arr.output_idx, idx))
for (oi, i) in zip(arr.output_idx, idx) if unwrap(oi) isa Symbolic)
partial = replace_by_scalarizing(arr.expr, dict)

axes = [axs[c] for c in contracted]
Expand Down
8 changes: 7 additions & 1 deletion test/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ end

getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue)
@testset "broadcast & scalarize" begin
@variables A[1:5,1:3]=42 b[1:3]=[2, 3, 5] t x[1:4](t)
@variables A[1:5,1:3]=42 b[1:3]=[2, 3, 5] t x[1:4](t) u[1:1]
AA = Symbolics.scalarize(A)
bb = Symbolics.scalarize(b)
@test all(isequal(42), getdef.(AA))
Expand All @@ -66,6 +66,12 @@ getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue)
D = Differential(t)
@test isequal(collect(D.(x) ~ x), map(i->D(x[i]) ~ x[i], eachindex(x)))
@test_throws ArgumentError A ~ t

# #448
@test isequal(Symbolics.scalarize(u + u), [2u[1]])

# #417
@test isequal(Symbolics.scalarize(x', (1,1)), x[1])
end

@testset "Parent" begin
Expand Down

0 comments on commit 911137e

Please sign in to comment.