Skip to content

Commit

Permalink
move code
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 16, 2022
1 parent 3f57415 commit 57b38a8
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
3 changes: 0 additions & 3 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -687,9 +687,6 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
(m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractArray{<:Integer}) = reshape(m(vec(x)), :, size(x)...)

(m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1))
(m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...)

function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
return m(onecold(x))
Expand Down
3 changes: 3 additions & 0 deletions src/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ end

## fixes for layers that don't work out of the box

(m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1))
(m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...)

for (fn, Dims) in ((:conv, DenseConvDims),)
@eval begin
function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims)
Expand Down

0 comments on commit 57b38a8

Please sign in to comment.