Skip to content

Commit

Permalink
Remove unused type parameter N
Browse files Browse the repository at this point in the history
  • Loading branch information
ldeso committed May 31, 2024
1 parent 9ed484f commit fc020c9
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,8 @@ This layer is often used to store word embeddings and retrieve them using indice
end

function Embedding(
(in_dims, out_dims)::Pair{<:Union{Integer, NTuple{N, <:Integer}}, <:Integer};
init_weight=randn32) where {N}
(in_dims, out_dims)::Pair{<:Union{Integer, NTuple{<:Any, <:Integer}}, <:Integer};
init_weight=randn32)
return Embedding(in_dims, out_dims, init_weight)
end

Expand All @@ -508,17 +508,15 @@ end
function (e::Embedding)(x::AbstractArray{<:Integer}, ps, st::NamedTuple)
return reshape(e(vec(x), ps, st)[1], :, size(x)...), st
end
function (e::Embedding)(x::NTuple{N, <:Integer}, ps, st::NamedTuple) where {N}
function (e::Embedding)(x::NTuple{<:Any, <:Integer}, ps, st::NamedTuple)
view(ps.weight, :, x...), st
end
function (e::Embedding)(
x::NTuple{N, <:AbstractVector{<:Integer}}, ps, st::NamedTuple) where {N}
function (e::Embedding)(x::NTuple{<:Any, <:AbstractVector{<:Integer}}, ps, st::NamedTuple)
sizes = size.(x)
@argcheck allequal(sizes) DimensionMismatch("Input vectors must have the same shape")
return NNlib.gather(ps.weight, x...), st
end
function (e::Embedding)(
x::NTuple{N, <:AbstractArray{<:Integer}}, ps, st::NamedTuple) where {N}
function (e::Embedding)(x::NTuple{<:Any, <:AbstractArray{<:Integer}}, ps, st::NamedTuple)
sizes = size.(x)
@argcheck allequal(sizes) DimensionMismatch("Input arrays must have the same shape")
return reshape(e(vec.(x), ps, st)[1], :, first(sizes)...), st
Expand Down

0 comments on commit fc020c9

Please sign in to comment.