From fc020c9f925f3e6108b7eaa7b6a13baa1ac9b891 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20de=20Souza?= Date: Fri, 31 May 2024 09:58:36 +0200 Subject: [PATCH] Remove unused type parameter N --- src/layers/basic.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index a5a279f80..5b8831d42 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -492,8 +492,8 @@ This layer is often used to store word embeddings and retrieve them using indice end function Embedding( - (in_dims, out_dims)::Pair{<:Union{Integer, NTuple{N, <:Integer}}, <:Integer}; - init_weight=randn32) where {N} + (in_dims, out_dims)::Pair{<:Union{Integer, NTuple{<:Any, <:Integer}}, <:Integer}; + init_weight=randn32) return Embedding(in_dims, out_dims, init_weight) end @@ -508,17 +508,15 @@ end function (e::Embedding)(x::AbstractArray{<:Integer}, ps, st::NamedTuple) return reshape(e(vec(x), ps, st)[1], :, size(x)...), st end -function (e::Embedding)(x::NTuple{N, <:Integer}, ps, st::NamedTuple) where {N} +function (e::Embedding)(x::NTuple{<:Any, <:Integer}, ps, st::NamedTuple) view(ps.weight, :, x...), st end -function (e::Embedding)( - x::NTuple{N, <:AbstractVector{<:Integer}}, ps, st::NamedTuple) where {N} +function (e::Embedding)(x::NTuple{<:Any, <:AbstractVector{<:Integer}}, ps, st::NamedTuple) sizes = size.(x) @argcheck allequal(sizes) DimensionMismatch("Input vectors must have the same shape") return NNlib.gather(ps.weight, x...), st end -function (e::Embedding)( - x::NTuple{N, <:AbstractArray{<:Integer}}, ps, st::NamedTuple) where {N} +function (e::Embedding)(x::NTuple{<:Any, <:AbstractArray{<:Integer}}, ps, st::NamedTuple) sizes = size.(x) @argcheck allequal(sizes) DimensionMismatch("Input arrays must have the same shape") return reshape(e(vec.(x), ps, st)[1], :, first(sizes)...), st