-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix sum(bc::Broadcasted; dims = 1, init = 0)
#43736
Conversation
The only usage of Lines 248 to 261 in 8404e45
|
has_fast_linear_indexing
rely on IndexStyle
/ndims
sum(bc::Broadcasted; dims = 1, init = 0)
julia> using Base.Broadcast: broadcasted, instantiate
julia> bc = instantiate(broadcasted(+, [1 2], (3, 4)))
Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}}(+, ([1 2], (3, 4)))
julia> sum(bc)
20
julia> sum(bc; dims=1, init=0)
ERROR: MethodError: no method matching has_fast_linear_indexing(::Tuple{Int64, Int64}) Broadcasts which materialise to a tuple still fail, elsewhere. Maybe that's OK since you can't sum a tuple with julia> Base.has_fast_linear_indexing(_) = false # or the PR, I think
julia> bc2 = instantiate(broadcasted(+, 1, (2, 3, 4)))
Base.Broadcast.Broadcasted(+, (1, (2, 3, 4)))
julia> sum(bc2; dims=1, init=0)
ERROR: MethodError: no method matching similar(::Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(+), Tuple{Int64, Tuple{Int64, Int64, Int64}}}, ::Type{Int64}, ::Tuple{Base.OneTo{Int64}})
julia> copy(bc2)
(3, 4, 5)
julia> sum(ans; dims=1, init=0)
ERROR: MethodError: no method matching mapfoldl(::typeof(identity), ::typeof(Base.add_sum), ::Tuple{Int64, Int64, Int64}; dims::Int64, init::Int64) The array .+ tuple case came up in JuliaDiff/ChainRules.jl#705 |
It works for me though: julia> using Base.Broadcast: broadcasted, instantiate
julia> bc = instantiate(broadcasted(+, [1 2], (3, 4)))
Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}}(+, ([1 2], (3, 4)))
julia> sum(bc)
20
julia> sum(bc; dims=1, init=0)
1×2 Matrix{Int64}:
9 11 I guess the patch didn't work because of the old |
Oh I'm an idiot, sorry, that's right. The PR removes the recursive method which caused the problem. |
This PR make `has_fast_linear_indexing` rely on `IndexStyle`/`ndims` to fix `mapreduce` for `Broadcasted` with `dim > 1`. Before: ```julia julia> a = randn(100,100); julia> bc = Broadcast.instantiate(Base.broadcasted(+,a,a)); julia> sum(bc,dims = 1,init = 0.0) == sum(collect(bc), dims = 1) ERROR: MethodError: no method matching LinearIndices(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(+), Tuple{Matrix{Float64}, Matrix{Float64}}}) ``` After: ```julia julia> sum(bc,dims = 1,init = 0.0) == sum(collect(bc), dims = 1) true ``` This should extend the optimized fallback to more `AbstractArray`. (e.g. `SubArray`) Test added.
This PR make
has_fast_linear_indexing
rely onIndexStyle
/ndims
to fixmapreduce
forBroadcasted
withdim > 1
.Before:
After:
This should extend the optimized fallback to more
AbstractArray
. (e.g.SubArray
)Test added.