diff --git a/src/layers/basic.jl b/src/layers/basic.jl index f0172f83c8..52d8256782 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -687,9 +687,6 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini (m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x) (m::Embedding)(x::AbstractArray{<:Integer}) = reshape(m(vec(x)), :, size(x)...) -(m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1)) -(m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...) - function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L} size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L")) return m(onecold(x)) diff --git a/src/outputsize.jl b/src/outputsize.jl index f2bacca2ac..b9ebedf984 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -163,6 +163,9 @@ end ## fixes for layers that don't work out of the box +(m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1)) +(m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...) + for (fn, Dims) in ((:conv, DenseConvDims),) @eval begin function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims)