diff --git a/src/gather.jl b/src/gather.jl index 3d8a8949d..cc47a607d 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -20,15 +20,28 @@ or multiple `dst` columns. See [`gather`](@ref) for an allocating version. """ -function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) - dims = scatter_dims(src, dst, idx) - colons = ntuple(i -> Colon(), dims) +# function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) +# dims = scatter_dims(src, dst, idx) +# colons = ntuple(i -> Colon(), dims) +# for k in CartesianIndices(idx) +# _view(dst, colons, k) .= _view(src, colons, idx[k]) +# end +# return dst +# end + +""" +dst[:, ... , k, :,...] .= src[:, ... , idx[k]..., :,...] +""" +function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray; dims = nothing) + nbefore, nafter = scatter_dims(src, dst, idx, dims) + colbefore = ntuple(i -> Colon(), nbefore) + colafter = ntuple(i -> Colon(), nafter) for k in CartesianIndices(idx) - _view(dst, colons, k) .= _view(src, colons, idx[k]) + _view(dst, colbefore, k, colafter) .= _view(src, colbefore, idx[k], colafter) end - return dst end + """ NNlib.gather(src, idx) -> dst diff --git a/src/scatter.jl b/src/scatter.jl index 88bb42c65..3fadf8f3f 100644 --- a/src/scatter.jl +++ b/src/scatter.jl @@ -17,25 +17,41 @@ typelength(::Type{CartesianIndex{M}}) where M = M Performs dimensional consistency checks and return the dimensionality of the scattered objects. """ + function scatter_dims(X::AbstractArray{Tx,Nx}, Y::AbstractArray{Ty,Ny}, - idx::AbstractArray{Tidx,Nidx}) where {Tx,Ty,Tidx,Nx,Ny,Nidx} - M = typelength(Tidx) - dims = scatter_dims(Nx, Ny, M, Nidx) - size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes.")) - size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes.")) - return dims + idx::AbstractArray{Tidx,Nidx}, + dims::Union{Nothing, Integer} = nothing) where {Tx,Ty,Tidx,Nx,Ny,Nidx} + nsrcin = typelength(Tidx) + ndstin = Nidx + nbefore, nafter = scatter_dims(Nx, Ny, nsrcin, ndstin, dims) + size(Y)[1:nbefore] == size(X)[1:nbefore] || throw(ArgumentError("Incompatible input shapes.")) + size(Y)[nbefore+1:nbefore+ndstin] == size(idx) || throw(ArgumentError("Incompatible input shapes.")) + size(Y)[nbefore+ndstin+1:end] == size(X)[nbefore+nsrcin+1:end] || throw(ArgumentError("Incompatible input shapes.")) + return nbefore, nafter end -function scatter_dims(Nx, Ny, M, Nidx) - @assert Nx - M == Ny - Nidx "Incompatible input shapes of (dst, src, idx) = ($Nx, $Ny, $Nidx)." - dims = Nx - M - dims < 0 && throw(ArgumentError("dims must be non-negative but got dims=$dims.")) - return dims +function scatter_dims(Nx, Ny, nsrcin, ndstin, dims = nothing) + @assert Nx - nsrcin == Ny - ndstin "Incompatible input shapes of (dst, src, idx, Tidx) = ($Nx, $Ny, $ndstin, $nsrcin)." + + if dims === nothing + nbefore = Nx - nsrcin + nbefore < 0 && throw(ArgumentError("nbefore must be non-negative but got $nbefore.")) + nafter = 0 + return nbefore, nafter + else + nbefore = dims - 1 + nafter = Ny - ndstin - nbefore + nbefore < 0 && throw(ArgumentError("nbefore must be non-negative but got $nbefore.")) + nafter < 0 && throw(ArgumentError("nafter must be non-negative but got $nafter.")) + return nbefore, nafter + end end _view(X, colons, k) = view(X, colons..., k...) _view(X, colons, k::Union{Integer, CartesianIndex}) = view(X, colons..., k) +_view(X, colbefore, k, colafter) = view(X, colbefore..., k..., colafter...) +_view(X, colbefore, k::Union{Integer, CartesianIndex}, colafter) = view(X, colbefore..., k, colafter...) """ NNlib.scatter!(op, dst, src, idx) @@ -78,7 +94,7 @@ function scatter!(op::OP, dst::AbstractArray, src::AbstractArray, idx::AbstractA src_v = _view(src, colons, k) dst_v .= (op).(dst_v, src_v) end - dst + return dst end function scatter!(op::typeof(mean), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) diff --git a/test/gather.jl b/test/gather.jl index eb6b8f6f9..88cf94b49 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -149,3 +149,18 @@ end gradtest(xs -> gather!(dst, xs, index), src) gradtest(xs -> gather(xs, index), src) end + +using NNlib: gather!, gather + +@testset "gather! dims" begin + + src = reshape([1:15;], 3, 5) + index = [1, 3, 2, 2] + dst = zeros(Int, 4, 5) + gather!(dst, src, index, dims=1) + + @test dst == [1 4 7 10 13 + 3 6 9 12 15 + 2 5 8 11 14 + 2 5 8 11 14] +end