Skip to content

Commit

Permalink
fix: concrete type Number => parameterization
Browse files Browse the repository at this point in the history
  • Loading branch information
skyleaworlder committed Feb 5, 2023
1 parent bddbb4b commit 85a6eff
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ function (g::GlobalMeanPool)(x)
end

"""
GlobalLPNormPool(p::Float64)
GlobalLPNormPool(p::T)
Global lp norm pooling layer.
Expand All @@ -636,16 +636,16 @@ by performing lp norm pooling on the complete (w,h)-shaped feature maps.
See also [`LPNormPool`](@ref).
```jldoctest
julia> xs = rand(Float32, 100, 100, 3, 50)
julia> xs = rand(Float32, 100, 100, 3, 50);
julia> m = Chain(Conv((3,3), 3 => 7), GlobalLPNormPool(2.0))
julia> m = Chain(Conv((3,3), 3 => 7), GlobalLPNormPool(2.0));
julia> m(xs) |> size
(1, 1, 7, 50)
```
"""
struct GlobalLPNormPool
p::Float64
struct GlobalLPNormPool{T<:Number}
p::T
end

function (g::GlobalLPNormPool)(x)
Expand Down Expand Up @@ -778,7 +778,7 @@ function Base.show(io::IO, m::MeanPool)
end

"""
LPNormPool(window::NTuple, p::Float64; pad=0, stride=window)
LPNormPool(window::NTuple, p::T; pad=0, stride=window)
Lp norm pooling layer, calculating p-norm distance for each window,
also known as LPPool in pytorch.
Expand All @@ -801,7 +801,7 @@ julia> xs = rand(Float32, 100, 100, 3, 50);
julia> m = Chain(Conv((5,5), 3 => 7), LPNormPool((5,5), 2.0; pad=SamePad()))
Chain(
Conv((5, 5), 3 => 7), # 532 parameters
LPNormPool((5, 5), p=2, pad=2),
LPNormPool((5, 5), 2.0, pad=2),
)
julia> m[1](xs) |> size
Expand All @@ -811,20 +811,20 @@ julia> m(xs) |> size
(20, 20, 7, 50)
julia> layer = LPNormPool((5,), 2.0, pad=2, stride=(3,)) # one-dimensional window
LPNormPool((5,), p=2, pad=2, stride=3)
LPNormPool((5,), 2.0, pad=2, stride=3)
julia> layer(rand(Float32, 100, 7, 50)) |> size
(34, 7, 50)
```
"""
struct LPNormPool{N,M}
struct LPNormPool{N,M,T<:Number}
k::NTuple{N,Int}
p::Float64
p::T
pad::NTuple{M,Int}
stride::NTuple{N,Int}
end

function LPNormPool(k::NTuple{N,Integer}, p::Float64; pad = 0, stride = k) where N
function LPNormPool(k::NTuple{N,Integer}, p::T; pad = 0, stride = k) where {N,T}
stride = expand(Val(N), stride)
pad = calc_padding(LPNormPool, pad, k, 1, stride)
return LPNormPool(k, p, pad, stride)
Expand Down

0 comments on commit 85a6eff

Please sign in to comment.