From 4fc85606c7078f3c188f06320e504bb2d2e54d3c Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 26 Dec 2020 20:13:05 +0100 Subject: [PATCH 1/9] dense keyword handling + docstring --- src/layers/basic.jl | 45 ++++++++++++++++++++++++++++---------------- test/layers/basic.jl | 3 +++ 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 38a5a1eef9..74ce2f217a 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -69,25 +69,22 @@ extraChain(::Tuple{}, x) = () """ - Dense(in, out, σ=identity; initW=glorot_uniform, initb=zeros, bias=true) - Dense(W, b, σ=identity) + Dense(in, out, σ=identity; bias=true) + Dense(W::AbstractMatrix, bias, [σ]) -Create a traditional `Dense` layer with in×out weight matrix `W` and -bias vector `b` of length `out`. The forward pass is given by: +Create a traditional `Dense` layer, whose forward pass is given by: y = σ.(W * x .+ b) -The input `x` must be a vector of length `in`, a batch of vectors represented -as an `in × N` matrix, or a higher order tensor where all dimensions -after the first one will be treated as batch dimensions. +The input `x` should be a vector of length `in`, or batch of vectors represented +as an `in × N` matrix, or any array with `size(x,1) == in`. +The out `y` will be a vector of length `out`, or a batch with +`size(y) == (out, size(x)[2:end]...)` -The out `y` will be a vector of length `out` or a batch whose first -dimension is `out` and the remaining dimensions are the same as in the input. - -Setting `bias` to `false` will switch the bias off for the layer. - -`initW` and `initb` are callables used to initialize weights and biases respectively, -through the calls `initW(out, in)` and `initb(out)`. +Keyword `bias=false` will switch off trainable bias for the layer. +The initialisation of the weight matrix is `W = init(out, in)`, controlled by +another keyword, with default `init=glorot_uniform`. 
The weight matrix and the +weight vector may also be provided explicitly. # Examples @@ -113,8 +110,24 @@ end Dense(W, b) = Dense(W, b, identity) function Dense(in::Integer, out::Integer, σ = identity; - initW = glorot_uniform, initb = zeros, bias=true) - return Dense(initW(out, in), create_bias(bias, initb, out), σ) + initW = nothing, initb = nothing, + init = glorot_uniform, bias=true) + + W = if initW !== nothing + @warn "keyword initW is deprecated, please use init" maxlog=1 _id=hash(initW) + initW(out, in) + else + init(out, in) + end + + b = if bias === true && initb !== nothing + @warn "keyword initb is deprecated, please simply supply the " maxlog=1 _id=hash(initW) + create_bias(bias, initb, out) + else + create_bias(bias, zeros, out) + end + + return Dense(W, b, σ) end @functor Dense diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 073182c03c..0f6014784f 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -46,11 +46,14 @@ import Flux: activations @test size(Dense(10, 5)(randn(10,2,3,4))) == (5,2,3,4) end @testset "zeros" begin + # old keywords @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1) @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2) @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1) @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] @test Dense(10, 2, identity, initW = ones, bias = false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] + # new + @test Dense(10, 1, identity, init = ones)(ones(10,2)) == 10*ones(1, 2) end end From 31b9411c885bcde010b7127d858ee9c61788965e Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 26 Dec 2020 20:18:36 +0100 Subject: [PATCH 2/9] fixes, add news --- NEWS.md | 1 + src/layers/basic.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/NEWS.md b/NEWS.md index f83f3d681b..9b4bc15963 100644 --- a/NEWS.md +++ 
b/NEWS.md @@ -6,6 +6,7 @@ * Added [Focal Loss function](https://github.com/FluxML/Flux.jl/pull/1489) to Losses module * The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405). * Dense and Conv layers no longer perform [implicit type conversion](https://github.com/FluxML/Flux.jl/pull/1394). +* The keyword `initW` of Dense layers is now `init`, to agree with convolutional layers. * Excise datasets in favour of other providers in the julia ecosystem. * Added option to set `bias` to [false](https://github.com/FluxML/Flux.jl/pull/1379) to eliminating `bias` from being trained. * Add [CTC loss function](https://github.com/FluxML/Flux.jl/pull/1287) to Losses module diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 74ce2f217a..eee9d291f7 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -107,7 +107,7 @@ struct Dense{F,S<:AbstractArray,T<:Union{Zeros, AbstractVector}} σ::F end -Dense(W, b) = Dense(W, b, identity) +Dense(W, bias) = Dense(W, create_bias(bias, zeros, size(W,1)), identity) function Dense(in::Integer, out::Integer, σ = identity; initW = nothing, initb = nothing, @@ -121,7 +121,7 @@ function Dense(in::Integer, out::Integer, σ = identity; end b = if bias === true && initb !== nothing - @warn "keyword initb is deprecated, please simply supply the " maxlog=1 _id=hash(initW) + @warn "keyword initb is deprecated, please simply supply the " maxlog=1 _id=hash(initb) create_bias(bias, initb, out) else create_bias(bias, zeros, out) From e4b11432c0644862d5f4b71be6e6dc4f3074dcbd Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 12 Feb 2021 21:44:11 +0100 Subject: [PATCH 3/9] squash pr 1440 Update Dense, Bilinear and Diagonal layers to match Conv in how they handle keywords. 
--- src/deprecations.jl | 11 +++ src/layers/basic.jl | 160 ++++++++++++++++++++++++++----------------- src/layers/conv.jl | 16 ++--- src/utils.jl | 27 +++++--- src/zeros.jl | 3 + test/layers/basic.jl | 50 +++++++++++--- test/layers/conv.jl | 5 ++ test/outputsize.jl | 4 ++ test/utils.jl | 15 ++-- 9 files changed, 191 insertions(+), 100 deletions(-) diff --git a/src/deprecations.jl b/src/deprecations.jl index 78eb55b733..91ab097092 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -7,3 +7,14 @@ @deprecate Conv(; weight, bias, activation=identity, kws...) Conv(weight, bias, activation; kws...) @deprecate ConvTranspose(; weight, bias, activation=identity, kws...) ConvTranspose(weight, bias, activation; kws...) @deprecate DepthwiseConv(; weight, bias, activation=identity, kws...) DepthwiseConv(weight, bias, activation; kws...) + +function Base.getproperty(a::Dense, s::Symbol) + if s === :W + Base.depwarn("field name dense.W is deprecated in favour of dense.weight", :Dense) + return getfield(a, :weight) + elseif s === :b + Base.depwarn("field name dense.b is deprecated in favour of dense.bias", :Dense) + return getfield(a, :bias) + end + return getfield(a, s) +end diff --git a/src/layers/basic.jl b/src/layers/basic.jl index eee9d291f7..9bf3f72364 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -69,12 +69,12 @@ extraChain(::Tuple{}, x) = () """ - Dense(in, out, σ=identity; bias=true) - Dense(W::AbstractMatrix, bias, [σ]) + Dense(in, out, σ=identity; bias=true, init=glorot_uniform) + Dense(W::AbstractMatrix, [bias, σ]) Create a traditional `Dense` layer, whose forward pass is given by: - y = σ.(W * x .+ b) + y = σ.(W * x .+ bias) The input `x` should be a vector of length `in`, or batch of vectors represented as an `in × N` matrix, or any array with `size(x,1) == in`. 
@@ -82,49 +82,59 @@ The out `y` will be a vector of length `out`, or a batch with `size(y) == (out, size(x)[2:end]...)` Keyword `bias=false` will switch off trainable bias for the layer. -The initialisation of the weight matrix is `W = init(out, in)`, controlled by -another keyword, with default `init=glorot_uniform`. The weight matrix and the -weight vector may also be provided explicitly. +The initialisation of the weight matrix is `W = init(out, in)`, calling the function +given to keyword `init`, with default [`glorot_uniform`](@doc Flux.glorot_uniform). +The weight matrix and/or the bias vector (of length `out`) may also be provided explicitly. # Examples - -```julia-repl +```jldoctest julia> d = Dense(5, 2) Dense(5, 2) -julia> d(rand(Float32, 5)) -2-element Array{Float32,1}: - -0.16210233 - 0.123119034 +julia> d(rand(Float32, 5, 64)) |> size +(2, 64) -julia> d = Dense(5, 2; bias=false) -Dense(5, 2) +julia> d(rand(Float32, 5, 1, 1, 64)) |> size +(2, 1, 1, 64) + +julia> d1 = Dense(ones(2, 5), false, tanh) +Dense(5, 2, tanh; bias=false) + +julia> d1(ones(5)) +2-element Array{Float64,1}: + 0.9999092042625951 + 0.9999092042625951 + +julia> Flux.params(d1) # no trainable bias +Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]]) ``` """ -struct Dense{F,S<:AbstractArray,T<:Union{Zeros, AbstractVector}} - W::S - b::T +struct Dense{F, M<:AbstractMatrix, B} + weight::M + bias::B σ::F + function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F} + b = create_bias(W, bias, size(W,1)) + new{F,M,typeof(b)}(W, b, σ) + end end -Dense(W, bias) = Dense(W, create_bias(bias, zeros, size(W,1)), identity) - function Dense(in::Integer, out::Integer, σ = identity; initW = nothing, initb = nothing, init = glorot_uniform, bias=true) W = if initW !== nothing - @warn "keyword initW is deprecated, please use init" maxlog=1 _id=hash(initW) + Base.depwarn("keyword initW is deprecated, please use init (which similarly accepts a funtion like randn)", :Dense) initW(out, in) 
else init(out, in) end b = if bias === true && initb !== nothing - @warn "keyword initb is deprecated, please simply supply the " maxlog=1 _id=hash(initb) - create_bias(bias, initb, out) + Base.depwarn("keyword initb is deprecated, please simply supply the bias vector, bias=initb(out)", :Dense) + initb(out) else - create_bias(bias, zeros, out) + bias end return Dense(W, b, σ) @@ -133,41 +143,52 @@ end @functor Dense function (a::Dense)(x::AbstractArray) - W, b, σ = a.W, a.b, a.σ + W, b, σ = getfield(a, :weight), getfield(a, :bias), getfield(a, :σ) sz = size(x) - x = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions + x = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions x = σ.(W*x .+ b) return reshape(x, :, sz[2:end]...) end function Base.show(io::IO, l::Dense) - print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1)) + print(io, "Dense(", size(l.weight, 2), ", ", size(l.weight, 1)) l.σ == identity || print(io, ", ", l.σ) + l.bias == Zeros() && print(io, "; bias=false") print(io, ")") end """ Diagonal(α, β) - Diagonal(sz::Integer...; initα=ones, initβ=zeros) + Diagonal(size::Integer...) -Create an element-wise linear layer with learnable -arrays `α` and `β` of size `sz`. The layer performs +Create an element-wise linear layer, which performs y = α .* x .+ β -The input `x` must have size broadcast-compatible with `α` and `β`. -The parameters will be created with the calls -`α = initα(sz)` and `β = initβ(sz)`. +The learnable arrays are initialised `α = ones(Float32, size)` and +`β = zeros(Float32, size)`. + +Used by [`LayerNorm`](@ref). """ struct Diagonal{T} α::T β::T end -function Diagonal(sz::Integer...; - initα = i -> ones(Float32, i), - initβ = i -> zeros(Float32, i)) - Diagonal(initα(sz), initβ(sz)) +function Diagonal(sz::Integer...; initα = nothing, initβ = nothing) + α = if initα !== nothing + Base.depwarn("keyword initα is deprecated, please simply supply the desired vectors", :Diagonal) + initα(sz...) 
+ else + ones(sz...) + end + β = if initβ !== nothing + Base.depwarn("keyword initβ is deprecated, please simply supply the desired vectors", :Diagonal) + initβ(sz...) + else + zeros(sz...) + end + Diagonal(α, β) end @functor Diagonal @@ -175,7 +196,7 @@ end (a::Diagonal)(x) = a.α .* x .+ a.β function Base.show(io::IO, l::Diagonal) - print(io, "Diagonal(", size(l.α), ")") + print(io, "Diagonal(", join(size(l.α), ", "), ")") end """ @@ -262,55 +283,65 @@ function Base.show(io::IO, b::SkipConnection) end """ - Bilinear(in1, in2, out) + Bilinear(in1, in2, out, σ=identity; bias=true, init=glorot_uniform) + Bilinear(weight::AbstractArray, [bias, σ]) Creates a Bilinear layer, which operates on two inputs at the same time. -It has parameters `W` and `b`, and its output given vectors `x`, `y` is of the form +It has parameters `weight` and `bias`, and its output given vectors `x`, `y` is +another vector `z` with, for `i ∈ 1:out`: - z[i] = σ.(x' * W[i,:,:] * y .+ b[i]) + z[i] = σ(dot(x' * weight[i,:,:] * y + bias[i]) If `x` and `y` are matrices, then each column of the output `z = B(x, y)` is of this form, -given that `B` is a Bilinear layer of appropriate size. +with `B` a Bilinear layer. If `y` is not given, it is taken to be equal to `x`, i.e. `B(x) == B(x, x)` The two inputs may also be provided as a tuple, `B((x, y)) == B(x, y)`, which is accepted as the input to a `Chain`. -```julia -# using Bilinear to generate interactions, on one input -x = randn(Float32, 11, 7) -B = Bilinear(11, 11, 3) -size(B(x)) == (3, 7) - -# using Bilinear on two data streams at once, as a tuple -x = randn(Float32, 10, 9) -y = randn(Float32, 2, 9) -m = Chain(Bilinear(10, 2, 3), Dense(3, 1)) -size(m((x, y))) == (1, 9) - -# using Bilinear as the recombinator in a SkipConnection -x = randn(Float32, 10, 9) -sc = SkipConnection(Dense(10, 10), Bilinear(10, 10, 5)) -size(sc(x)) == (5, 9) +Keywords `init` and `bias` work as for [`Dense`](@ref) layer. 
+ +# Examples + +```jldoctest +julia> x, y = randn(Float32, 5, 32), randn(Float32, 5, 32); + +julia> B = Flux.Bilinear(5, 5, 7); + +julia> B(x) |> size # interactions based on one input +(7, 32) + +julia> B(x,y) == B((x,y)) # two inputs, may be given as a tuple +true + +julia> sc = SkipConnection( + Chain(Dense(5, 20, tanh), Dense(20, 9, tanh)), + Flux.Bilinear(9, 5, 3, bias=false), + ); # used as the recombinator, with skip as the second input + +julia> sc(x) |> size +(3, 32) ``` """ struct Bilinear{A,B,S} - W::A - b::B + weight::A + bias::B σ::S end @functor Bilinear -Bilinear(W, b) = Bilinear(W, b, identity) +Bilinear(weight::AbstractArray, bias = true) = Bilinear(weight, create_bias(weight, bias, size(weight,1)), identity) function Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity; - initW = glorot_uniform, initb = zeros) - return Bilinear(initW(out, in1, in2), initb(out), σ) + init = glorot_uniform, bias = true) + W = init(out, in1, in2) + b = create_bias(W, bias, out) + return Bilinear(W, b, σ) end function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix) - W, b, σ = a.W, a.b, a.σ + W, b, σ = a.weight, a.bias, a.σ d_z, d_x, d_y = size(W) d_x == size(x,1) && d_y == size(y,1) || throw(DimensionMismatch("number of rows in data must match W")) @@ -332,8 +363,9 @@ end (a::Bilinear)(x::NTuple{2, AbstractArray}) = a(x[1], x[2]) function Base.show(io::IO, l::Bilinear) - print(io, "Bilinear(", size(l.W, 2), ", ", size(l.W, 3), ", ", size(l.W, 1)) + print(io, "Bilinear(", size(l.weight, 2), ", ", size(l.weight, 3), ", ", size(l.weight, 1)) l.σ == identity || print(io, ", ", l.σ) + l.bias == Flux.Zeros() && print(io, ", bias=false") print(io, ")") end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 8724810762..63c1167e8a 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -30,7 +30,7 @@ function calc_padding(lt, ::SamePad, k::NTuple{N,T}, dilation, stride) where {N, end """ - Conv(filter, in => out, σ=identity; stride=1, pad=0, 
dilation=1) + Conv(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init]) Standard convolutional layer. `filter` is a tuple of integers specifying the size of the convolutional kernel; @@ -122,7 +122,7 @@ function Conv(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, stride = expand(Val(N-2), stride) dilation = expand(Val(N-2), dilation) pad = calc_padding(Conv, pad, size(w)[1:N-2], dilation, stride) - bias = create_bias(b, zeros, size(w, N)) + bias = create_bias(w, b, size(w, N)) return Conv(σ, w, bias, stride, pad, dilation) end @@ -166,7 +166,7 @@ end """ - ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1) + ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init]) Standard convolutional transpose layer. `filter` is a tuple of integers specifying the size of the convolutional kernel, while @@ -216,7 +216,7 @@ function ConvTranspose(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVect stride = expand(Val(N-2), stride) dilation = expand(Val(N-2), dilation) pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride) - bias = create_bias(b, zeros, size(w, N-1)) + bias = create_bias(w, b, size(w, N-1)) return ConvTranspose(σ, w, bias, stride, pad, dilation) end @@ -266,7 +266,7 @@ function calc_padding(::Type{ConvTranspose}, pad::SamePad, k::NTuple{N,T}, dilat end """ - DepthwiseConv(filter, in=>out, σ=identity; stride=1, pad=0, dilation=1) + DepthwiseConv(filter, in=>out, σ=identity; stride=1, pad=0, dilation=1, [bias, init]) Depthwise convolutional layer. 
`filter` is a tuple of integers specifying the size of the convolutional kernel, while @@ -313,7 +313,7 @@ function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVect stride = expand(Val(N-2), stride) dilation = expand(Val(N-2), dilation) pad = calc_padding(DepthwiseConv, pad, size(w)[1:N-2], dilation, stride) - bias = create_bias(b, zeros, prod(size(w)[N-1:end])) + bias = create_bias(w, b, prod(size(w)[N-1:end])) return DepthwiseConv(σ, w, bias, stride, pad, dilation) end @@ -355,7 +355,7 @@ end """ - CrossCor(filter, in => out, σ=identity; stride=1, pad=0, dilation=1) + CrossCor(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init]) Standard cross convolutional layer. `filter` is a tuple of integers specifying the size of the convolutional kernel; @@ -401,7 +401,7 @@ function CrossCor(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T} stride = expand(Val(N-2), stride) dilation = expand(Val(N-2), dilation) pad = calc_padding(CrossCor, pad, size(w)[1:N-2], dilation, stride) - bias = create_bias(b, zeros, size(w, N)) + bias = create_bias(w, b, size(w, N)) return CrossCor(σ, w, bias, stride, pad, dilation) end diff --git a/src/utils.jl b/src/utils.jl index 30edab37d3..641f2cd0b6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -177,10 +177,10 @@ kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaimi """ orthogonal([rng=GLOBAL_RNG], dims...; gain = 1) -Return an `Array` of size `dims` which is a (semi) orthogonal matrix, as described in [1]. +Return an `Array` of size `dims` which is a (semi) orthogonal matrix, as described in [1]. The input must have at least 2 dimensions. -For `length(dims) > 2`, a `prod(dims[1:(end - 1)])` by `dims[end]` orthogonal matrix +For `length(dims) > 2`, a `prod(dims[1:(end - 1)])` by `dims[end]` orthogonal matrix is computed before reshaping it to the original dimensions. # Examples @@ -291,18 +291,23 @@ ones(dims...) = Base.ones(Float32, dims...) zeros(dims...) 
= Base.zeros(Float32, dims...) """ - create_bias(shallcreate::Bool, iftrue, dims...) - create_bias(x, ::Any...) + create_bias(weights, bias, length) -Return a bias parameter for a layer. +Return a bias parameter for a layer, based on the value given +to the constructor's keyword `bias=bias`. -Essentially handles the allowed input options for the `bias` keyword: - If `false`: Return the `Zeros` type which turns bias off. - If `true` : Return the result of `iftrue(dims)`. - If not a boolean, return self to handle the case of bias=somearray. +* `bias == true` creates a zero vector, of the same type as weights. +* `bias == false` returns `Zeros()`, a special struct which exists to encode the absence of bias. +* `bias::AbstractArray` uses the array provided, provided it has the correct size and eltype. If the type is wrong, it will be converted. """ -create_bias(shallcreate::Bool, iftrue, dims...) = shallcreate ? iftrue(dims...) : Zeros() -create_bias(x, ::Any...) = x +function create_bias(weights::AbstractArray{T}, bias::Union{Bool, AbstractArray}, dims::Integer...) 
where {T} + bias===true && return fill!(similar(weights, dims...), 0) + bias===false && return Zeros() + size(bias) == dims || throw(DimensionMismatch("expected bias of size $dims, but got $(size(bias))")) + eltype(bias) == T && return bias + @warn "converting bias to match element type of weights" typeof(weights) typeof(bias) maxlog=3 _id=hash(dims) + return T.(bias) +end """ unsqueeze(xs, dim) diff --git a/src/zeros.jl b/src/zeros.jl index fef9d1862e..1281f4c87a 100644 --- a/src/zeros.jl +++ b/src/zeros.jl @@ -47,3 +47,6 @@ broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = -b @adjoint broadcasted(::typeof(*), a::AbstractArray, b::Zeros) = zero(a), _ -> (nothing, zero(a), nothing) @adjoint broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = zero(b), _ -> (nothing, nothing, zero(b)) @adjoint broadcasted(::typeof(/), a::Zeros, b::AbstractArray) = zero(b), _ -> (nothing, nothing, zero(b)) + +# Pass-through for layer constructors +create_bias(weights::AbstractArray, bias::Flux.Zeros, dims::Integer...) 
= bias diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 0f6014784f..7047e095bc 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -29,11 +29,27 @@ import Flux: activations @testset "Dense" begin @testset "constructors" begin - @test size(Dense(10, 100).W) == (100, 10) - @test Dense(rand(100,10), rand(10)).σ == identity + @test size(Dense(10, 100).weight) == (100, 10) + @test size(Dense(10, 100).bias) == (100,) + @test Dense(rand(100,10), rand(100)).σ == identity + @test Dense(rand(100,10)).σ == identity + + @test Dense(rand(100,10), false).σ == identity + @test Dense(rand(100,10), false, tanh).σ == tanh + @test Dense(rand(100,10), rand(100)).σ == identity + @test Dense(rand(Float16, 100,10), true).bias isa Vector{Float16} # creates matching type + @test Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match + + @test Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64} + @test Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64} @test_throws MethodError Dense(10, 10.5) @test_throws MethodError Dense(10, 10.5, tanh) + @test_throws DimensionMismatch Dense(3,4; bias=rand(5)) + @test_throws DimensionMismatch Dense(rand(4,3), rand(5)) + @test_throws MethodError Dense(rand(5)) + @test_throws MethodError Dense(rand(5), rand(5)) + @test_throws MethodError Dense(rand(5), rand(5), tanh) end @testset "dimensions" begin @test length(Dense(10, 5)(randn(10))) == 5 @@ -44,16 +60,14 @@ import Flux: activations @test size(Dense(10, 5)(randn(10,2))) == (5,2) @test size(Dense(10, 5)(randn(10,2,3))) == (5,2,3) @test size(Dense(10, 5)(randn(10,2,3,4))) == (5,2,3,4) + @test_throws DimensionMismatch Dense(10, 5)(randn(11,2,3)) end @testset "zeros" begin - # old keywords - @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1) - @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2) - @test Dense(10, 2, identity, initW = ones, initb = 
zeros)(ones(10,1)) == 10*ones(2, 1) - @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] - @test Dense(10, 2, identity, initW = ones, bias = false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] - # new + @test Dense(10, 1, identity, init = ones)(ones(10,1)) == 10*ones(1, 1) @test Dense(10, 1, identity, init = ones)(ones(10,2)) == 10*ones(1, 2) + @test Dense(10, 2, identity, init = ones)(ones(10,1)) == 10*ones(2, 1) + @test Dense(10, 2, identity, init = ones)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] + @test Dense(10, 2, identity, init = ones, bias = false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] end end @@ -138,8 +152,24 @@ import Flux: activations @test size(b(x)) == (3,7) @test_nowarn gs = gradient(() -> sum(abs2.(b(x))), params(b)) end + + @testset "constructors" begin + b1 = Flux.Bilinear(randn(3,4,5)) + @test b1.bias isa Vector{Float64} + @test b1.σ == identity + + b2 = Flux.Bilinear(randn(3,4,5), false) + @test b2.bias == Flux.Zeros() + + b3 = Flux.Bilinear(randn(3,4,5), true, tanh) + @test b3.σ == tanh + @test size(b3(rand(4), rand(5))) == (3,) + + b4 = Flux.Bilinear(3,3,7; bias=1:7, init=Flux.zeros) + @test b4.bias isa Vector{Float32} + end end - + @testset "Parallel" begin @testset "zero sum" begin input = randn(10, 10, 10, 10) diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 458ce69191..6806a7f175 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -183,3 +183,8 @@ end l = ltype(k, pad=SamePad()) @test size(l(data))[1:end-2] == cld.(size(data)[1:end-2], k) end + +@testset "bugs fixed" begin + # https://github.com/FluxML/Flux.jl/issues/1421 + @test Conv((5, 5), 10 => 20, identity; init = Base.randn).bias isa Vector{Float64} +end diff --git a/test/outputsize.jl b/test/outputsize.jl index eec96023df..3db6f42942 100644 --- a/test/outputsize.jl +++ b/test/outputsize.jl @@ -25,6 +25,10 @@ m = Flux.unsqueeze(3) @test outputsize(m, (5, 7, 13)) == (5, 7, 1, 13) + m = Flux.Bilinear(10, 
10, 7) + @test outputsize(m, (10,)) == (7,) + @test outputsize(m, (10, 32)) == (7, 32) + m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024, 10)) @test outputsize(m, (10, 10, 3, 50)) == (10, 50) @test outputsize(m, (10, 10, 3, 2)) == (10, 2) diff --git a/test/utils.jl b/test/utils.jl index f5140fd072..65b042c0de 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -282,8 +282,8 @@ end end @testset "Param remapping" begin - ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) - dl(nin, nout, bias) = Dense(ls(nin, nout), bias(nout)) + ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) # accepts dims in reverse order to Dense + dl(nin, nout, bias) = Dense(ls(nout, nin), bias(nout)) dm(bias) = Chain( dl(3, 5, bias), dl(5, 4, bias), @@ -299,21 +299,22 @@ end @testset "loadparams!" begin import Flux: loadparams! - pars(w, b::Zeros) = [w, zeros(size(w,2))] pars(w, b) = [w, b] + import Flux: loadparams!, Zeros + pars(w, b::Zeros) = [w, Flux.zeros(size(w,1))] pars(l) = pars(l.W, l.b) pararray(m) = mapreduce(pars, vcat, m) weights(m) = mapreduce(l -> [l.W], vcat, m) - @testset "Bias type $bt" for bt in (zeros, nobias) + @testset "Bias type $bt" for bt in (Flux.zeros, nobias) m = dm(bt) loadparams!(m, params(m)) testdense(m, bt) end @testset "$b1 to $b2" for (b1, b2, be) in ( - (zeros, ones, ones), # Load ones as bias to a model with zeros as bias -> model gets ones as bias - (ones, nobias, zeros), # Load Zeros as bias to a model with ones as bias-> model gets zeros as bias - (nobias, ones, nobias), # Load ones as bias to a model with Zeros as bias-> model bias does not change + (Flux.zeros, ones, ones), # Load ones as bias to a model with zeros as bias -> model gets ones as bias + (ones, nobias, Flux.zeros), # Load Zeros as bias to a model with ones as bias-> model gets zeros as bias + (nobias, ones, nobias), # Load ones as bias to a model with Zeros as bias-> model bias does not change ) m1 = dm(b1) m2 = dm(b2) From 
3954f046a7bb9214c3ae8e4d6f76f3605b6028f8 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 12 Feb 2021 22:01:30 +0100 Subject: [PATCH 4/9] doc tweaks --- docs/src/models/layers.md | 4 ++-- src/layers/basic.jl | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index f0a2f08fee..cee27941de 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -5,7 +5,6 @@ These core layers form the foundation of almost all neural networks. ```@docs Chain Dense -Flux.Diagonal ``` ## Convolution and Pooling Layers @@ -57,7 +56,8 @@ But in contrast to the layers described in the other sections are not readily gr Maxout SkipConnection Parallel -Bilinear +Flux.Bilinear +Flux.Diagonal ``` ## Normalisation & Regularisation diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 9bf3f72364..5ab86902f8 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -94,10 +94,10 @@ Dense(5, 2) julia> d(rand(Float32, 5, 64)) |> size (2, 64) -julia> d(rand(Float32, 5, 1, 1, 64)) |> size +julia> d(rand(Float32, 5, 1, 1, 64)) |> size # treated as three batch dimensions (2, 1, 1, 64) -julia> d1 = Dense(ones(2, 5), false, tanh) +julia> d1 = Dense(ones(2, 5), false, tanh) # using provided weight matrix Dense(5, 2, tanh; bias=false) julia> d1(ones(5)) @@ -284,13 +284,13 @@ end """ Bilinear(in1, in2, out, σ=identity; bias=true, init=glorot_uniform) - Bilinear(weight::AbstractArray, [bias, σ]) + Bilinear(W::AbstractArray, [bias, σ]) Creates a Bilinear layer, which operates on two inputs at the same time. 
-It has parameters `weight` and `bias`, and its output given vectors `x`, `y` is -another vector `z` with, for `i ∈ 1:out`: +It its output, given vectors `x`, `y` is another vector `z` with, +for all `i ∈ 1:out`: - z[i] = σ(dot(x' * weight[i,:,:] * y + bias[i]) + z[i] = σ(x' * W[i,:,:] * y + bias[i]) If `x` and `y` are matrices, then each column of the output `z = B(x, y)` is of this form, with `B` a Bilinear layer. @@ -299,7 +299,9 @@ If `y` is not given, it is taken to be equal to `x`, i.e. `B(x) == B(x, x)` The two inputs may also be provided as a tuple, `B((x, y)) == B(x, y)`, which is accepted as the input to a `Chain`. -Keywords `init` and `bias` work as for [`Dense`](@ref) layer. +The initialisation works as for [`Dense`](@ref) layer, with `W = init(out, in1, in2)`. +By default the bias vector is `zeros(Float32, out)`, option `bias=false` will switch off +trainable bias. Either of these may be provided explicitly. # Examples @@ -370,7 +372,7 @@ function Base.show(io::IO, l::Bilinear) end """ -Parallel(connection, layers...) + Parallel(connection, layers...) Create a 'Parallel' layer that passes an input array to each path in `layers`, reducing the output with `connection`. From 644ec8c08799ddf4855f13e6dd0a6a4d6a8554cd Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 13 Feb 2021 10:25:28 +0100 Subject: [PATCH 5/9] split create_bias into two --- src/utils.jl | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 641f2cd0b6..14262f6e50 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -297,16 +297,20 @@ Return a bias parameter for a layer, based on the value given to the constructor's keyword `bias=bias`. * `bias == true` creates a zero vector, of the same type as weights. -* `bias == false` returns `Zeros()`, a special struct which exists to encode the absence of bias. +* `bias == false` returns `Zeros()`, a special struct which exists only to encode the absence of bias. 
* `bias::AbstractArray` uses the array provided, provided it has the correct size and eltype. If the type is wrong, it will be converted. """ -function create_bias(weights::AbstractArray{T}, bias::Union{Bool, AbstractArray}, dims::Integer...) where {T} - bias===true && return fill!(similar(weights, dims...), 0) - bias===false && return Zeros() - size(bias) == dims || throw(DimensionMismatch("expected bias of size $dims, but got $(size(bias))")) - eltype(bias) == T && return bias - @warn "converting bias to match element type of weights" typeof(weights) typeof(bias) maxlog=3 _id=hash(dims) - return T.(bias) +function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...) + bias ? fill!(similar(weights, dims...), 0) : Zeros() +end +function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...) + size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))")) + if eltype(bias) == eltype(weights) + return bias + else + @warn "converting bias to match element type of weights" typeof(weights) typeof(bias) maxlog=3 _id=hash(dims) + return broadcast(eltype(weights), bias) + end end """ From 3c4875e4e00c0dab391ead7dfe246d840859b650 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 13 Feb 2021 10:25:57 +0100 Subject: [PATCH 6/9] use inner constructor for Bilinear, more like Dense --- src/layers/basic.jl | 22 +++++++++++++--------- test/layers/basic.jl | 7 ++++++- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 5ab86902f8..6eb7313ece 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -287,7 +287,7 @@ end Bilinear(W::AbstractArray, [bias, σ]) Creates a Bilinear layer, which operates on two inputs at the same time. 
-It its output, given vectors `x`, `y` is another vector `z` with, +Its output, given vectors `x` & `y`, is another vector `z` with, for all `i ∈ 1:out`: z[i] = σ(x' * W[i,:,:] * y + bias[i]) @@ -323,23 +323,27 @@ julia> sc = SkipConnection( julia> sc(x) |> size (3, 32) + +julia> Flux.Bilinear(rand(4,8,16), false, tanh) # first dim of weight is the output +Bilinear(8, 16, 4, tanh, bias=false) ``` """ -struct Bilinear{A,B,S} +struct Bilinear{F,A,B} weight::A bias::B - σ::S + σ::F + function Bilinear(W::A, bias = true, σ::F = identity) where {A<:AbstractArray, F} + ndims(A) == 3 || throw(ArgumentError("expected a 3-array of weights")) + b = create_bias(W, bias, size(W,1)) + new{F,A,typeof(b)}(W, b, σ) + end end @functor Bilinear -Bilinear(weight::AbstractArray, bias = true) = Bilinear(weight, create_bias(weight, bias, size(weight,1)), identity) - function Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity; - init = glorot_uniform, bias = true) - W = init(out, in1, in2) - b = create_bias(W, bias, out) - return Bilinear(W, b, σ) + init = glorot_uniform, bias = true) + Bilinear(init(out, in1, in2), bias, σ) end function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 7047e095bc..d76056806a 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -161,12 +161,17 @@ import Flux: activations b2 = Flux.Bilinear(randn(3,4,5), false) @test b2.bias == Flux.Zeros() - b3 = Flux.Bilinear(randn(3,4,5), true, tanh) + b3 = Flux.Bilinear(randn(Float16, 3,4,5), true, tanh) @test b3.σ == tanh + @test b2.bias isa Vector{Float16} @test size(b3(rand(4), rand(5))) == (3,) b4 = Flux.Bilinear(3,3,7; bias=1:7, init=Flux.zeros) @test b4.bias isa Vector{Float32} + + @test_throws ArgumentError Flux.Bilinear(rand(3)) # expects a 3-array + @test_throws ArgumentError Flux.Bilinear(rand(3,4), false, tanh) + @test_throws DimensionMismatch Flux.Bilinear(rand(3,4,5), rand(6), tanh) # wrong length bias end end From 
bc7f7c7526aa8894087291d0ff163af8dd5ceddc Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 13 Feb 2021 10:56:00 +0100 Subject: [PATCH 7/9] make Conv constructors more forgiving about bias type --- src/layers/conv.jl | 38 +++++++++++++++++++------------------- test/layers/conv.jl | 12 ++++++++++++ 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 63c1167e8a..bef5d94b62 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -95,8 +95,8 @@ struct Conv{N,M,F,A,V} end """ - Conv(weight::AbstractArray, bias, [activation; stride, pad, dilation]) - + Conv(weight::AbstractArray, [bias, activation; stride, pad, dilation]) + Constructs a convolutional layer with the given weight and bias. Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3=>7, relu)` method. @@ -117,13 +117,13 @@ julia> params(c1) |> length 2 ``` """ -function Conv(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity; +function Conv(w::AbstractArray{T,N}, bias = true, σ = identity; stride = 1, pad = 0, dilation = 1) where {T,N} stride = expand(Val(N-2), stride) dilation = expand(Val(N-2), dilation) pad = calc_padding(Conv, pad, size(w)[1:N-2], dilation, stride) - bias = create_bias(w, b, size(w, N)) - return Conv(σ, w, bias, stride, pad, dilation) + b = create_bias(w, bias, size(w, N)) + return Conv(σ, w, b, stride, pad, dilation) end function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; @@ -206,18 +206,18 @@ struct ConvTranspose{N,M,F,A,V} end """ - ConvTranspose(weight::AbstractArray, bias, [activation; stride, pad, dilation]) - + ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation]) + Constructs a layer with the given weight and bias arrays. Accepts the same keywords as the `ConvTranspose((4,4), 3=>7, relu)` method. 
""" -function ConvTranspose(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity; +function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity; stride = 1, pad = 0, dilation = 1) where {T,N} stride = expand(Val(N-2), stride) dilation = expand(Val(N-2), dilation) pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride) - bias = create_bias(w, b, size(w, N-1)) - return ConvTranspose(σ, w, bias, stride, pad, dilation) + b = create_bias(w, bias, size(w, N-1)) + return ConvTranspose(σ, w, b, stride, pad, dilation) end function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; @@ -304,17 +304,17 @@ end """ DepthwiseConv(weight::AbstractArray, bias, [activation; stride, pad, dilation]) - + Constructs a layer with the given weight and bias arrays. Accepts the same keywords as the `DepthwiseConv((4,4), 3=>6, relu)` method. """ -function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity; +function DepthwiseConv(w::AbstractArray{T,N}, bias = true, σ = identity; stride = 1, pad = 0, dilation = 1) where {T,N} stride = expand(Val(N-2), stride) dilation = expand(Val(N-2), dilation) pad = calc_padding(DepthwiseConv, pad, size(w)[1:N-2], dilation, stride) - bias = create_bias(w, b, prod(size(w)[N-1:end])) - return DepthwiseConv(σ, w, bias, stride, pad, dilation) + b = create_bias(w, bias, prod(size(w)[N-1:end])) + return DepthwiseConv(σ, w, b, stride, pad, dilation) end function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; @@ -391,18 +391,18 @@ struct CrossCor{N,M,F,A,V} end """ - CrossCor(weight::AbstractArray, bias, [activation; stride, pad, dilation]) - + CrossCor(weight::AbstractArray, [bias, activation; stride, pad, dilation]) + Constructs a layer with the given weight and bias arrays. Accepts the same keywords as the `CrossCor((4,4), 3=>7, relu)` method. 
""" -function CrossCor(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity; +function CrossCor(w::AbstractArray{T,N}, bias = true, σ = identity; stride = 1, pad = 0, dilation = 1) where {T,N} stride = expand(Val(N-2), stride) dilation = expand(Val(N-2), dilation) pad = calc_padding(CrossCor, pad, size(w)[1:N-2], dilation, stride) - bias = create_bias(w, b, size(w, N)) - return CrossCor(σ, w, bias, stride, pad, dilation) + b = create_bias(w, bias, size(w, N)) + return CrossCor(σ, w, b, stride, pad, dilation) end function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 6806a7f175..d8840031ae 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -188,3 +188,15 @@ end # https://github.com/FluxML/Flux.jl/issues/1421 @test Conv((5, 5), 10 => 20, identity; init = Base.randn).bias isa Vector{Float64} end + +@testset "constructors: $fun" for fun in [Conv, CrossCor, ConvTranspose, DepthwiseConv] + @test fun(rand(2,3,4)).bias isa Vector{Float64} + @test fun(rand(2,3,4,5), false).bias isa Flux.Zeros + if fun == Conv + @test fun(rand(2,3,4,5,6), rand(6)).bias isa Vector{Float64} + @test fun(rand(2,3,4,5,6), 1:6).bias isa Vector{Float64} + elseif fun == DepthwiseConv + @test fun(rand(2,3,4,5,6), rand(30)).bias isa Vector{Float64} + end + @test_throws DimensionMismatch fun(rand(2,3,4), rand(6)) +end From 5a27ff5b7c9a4db49fe9a019c83595a8c4b8f09e Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 13 Feb 2021 10:57:43 +0100 Subject: [PATCH 8/9] typo in a test --- test/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index d76056806a..cb321cccfb 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -163,7 +163,7 @@ import Flux: activations b3 = Flux.Bilinear(randn(Float16, 3,4,5), true, tanh) @test b3.σ == tanh - @test b2.bias isa Vector{Float16} + @test b3.bias isa 
Vector{Float16} @test size(b3(rand(4), rand(5))) == (3,) b4 = Flux.Bilinear(3,3,7; bias=1:7, init=Flux.zeros) From ae879cc5ebdef009fbbcfaec8b4c8b7e93dfc076 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 13 Feb 2021 11:05:09 +0100 Subject: [PATCH 9/9] make Dense(x) prettier --- src/layers/basic.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 6eb7313ece..cae16801f6 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -143,11 +143,11 @@ end @functor Dense function (a::Dense)(x::AbstractArray) - W, b, σ = getfield(a, :weight), getfield(a, :bias), getfield(a, :σ) + W, b, σ = a.weight, a.bias, a.σ sz = size(x) - x = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions - x = σ.(W*x .+ b) - return reshape(x, :, sz[2:end]...) + y = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions + z = σ.(W*y .+ b) + return reshape(z, :, sz[2:end]...) end function Base.show(io::IO, l::Dense)