From 8f650aced2bdd704c589915c6cf81f32322b7167 Mon Sep 17 00:00:00 2001
From: Michael Abbott
Date: Fri, 22 Jan 2021 17:12:57 +0100
Subject: [PATCH] squash PR 1407, eleven commits, 2020

---
 src/Flux.jl         |   1 -
 src/deprecations.jl |  10 ++++
 src/layers/basic.jl |   3 +-
 src/layers/conv.jl  |  16 +++---
 src/utils.jl        |   4 +-
 src/zeros.jl        |  52 ---------------------
 test/optimise.jl    |   2 +-
 test/utils.jl       | 108 ++++++++++++--------------------------------
 8 files changed, 52 insertions(+), 144 deletions(-)
 delete mode 100644 src/zeros.jl

diff --git a/src/Flux.jl b/src/Flux.jl
index 5e6776d601..09854c385a 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -34,7 +34,6 @@ using CUDA
 const use_cuda = Ref(false)
 
 include("utils.jl")
-include("zeros.jl")
 include("onehot.jl")
 include("functor.jl")
 
diff --git a/src/deprecations.jl b/src/deprecations.jl
index 91ab097092..4353790308 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -3,7 +3,9 @@
 @deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, true, true, nothing)
 @deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, true, true, nothing)
 @deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing)
+
 @deprecate outdims(f, inputsize) outputsize(f, inputsize)
+
 @deprecate Conv(; weight, bias, activation=identity, kws...) Conv(weight, bias, activation; kws...)
 @deprecate ConvTranspose(; weight, bias, activation=identity, kws...) ConvTranspose(weight, bias, activation; kws...)
 @deprecate DepthwiseConv(; weight, bias, activation=identity, kws...) DepthwiseConv(weight, bias, activation; kws...)
@@ -18,3 +20,11 @@ function Base.getproperty(a::Dense, s::Symbol)
   end
   return getfield(a, s)
 end
+
+struct Zeros # was used as both Dense(10, 2, initb = Zeros) and Dense(rand(2,10), Zeros())
+  function Zeros()
+    Base.depwarn("Zeros() and Zeros(dims...) are deprecated, please simply use bias=false instead", :Zeros)
+    false
+  end
+end
+Zeros(args...) = Zeros()
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index cae16801f6..1fbdb36761 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -67,7 +67,6 @@ end
 extraChain(::Tuple{}, x) = ()
 
-
 """
     Dense(in, out, σ=identity; bias=true, init=glorot_uniform)
     Dense(W::AbstractMatrix, [bias, σ])
 
@@ -153,7 +152,7 @@ end
 function Base.show(io::IO, l::Dense)
   print(io, "Dense(", size(l.weight, 2), ", ", size(l.weight, 1))
   l.σ == identity || print(io, ", ", l.σ)
-  l.bias == Zeros() && print(io, "; bias=false")
+  l.bias == false && print(io, "; bias=false")
   print(io, ")")
 end
 
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index bef5d94b62..4f9688f52c 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -6,6 +6,10 @@ _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]
 expand(N, i::Tuple) = i
 expand(N, i::Integer) = ntuple(_ -> i, N)
 
+conv_reshape_bias(c) = c.bias isa AbstractVector ?
+  reshape(c.bias, map(_->1, c.stride)..., :, 1) :
+  c.bias
+
 """
     SamePad()
 
@@ -152,9 +156,8 @@ convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
 function (c::Conv)(x::AbstractArray)
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-  σ, b = c.σ, reshape(c.bias, ntuple(_->1, length(c.stride))..., :, 1)
   cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-  σ.(conv(x, c.weight, cdims) .+ b)
+  (c.σ).(conv(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::Conv)
@@ -248,9 +251,8 @@ end
 
 function (c::ConvTranspose)(x::AbstractArray)
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
   cdims = conv_transpose_dims(c, x)
-  σ.(∇conv_data(x, c.weight, cdims) .+ b)
+  (c.σ).(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::ConvTranspose)
@@ -341,9 +343,8 @@ depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
                     init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])
 
 function (c::DepthwiseConv)(x)
-  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
   cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-  σ.(depthwiseconv(x, c.weight, cdims) .+ b)
+  (c.σ).(depthwiseconv(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::DepthwiseConv)
@@ -422,9 +423,8 @@ end
 function (c::CrossCor)(x::AbstractArray)
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
   cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-  σ.(crosscor(x, c.weight, cdims) .+ b)
+  (c.σ).(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::CrossCor)
diff --git a/src/utils.jl b/src/utils.jl
index 14262f6e50..30d046e62a 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -297,11 +297,11 @@ Return a bias parameter for a layer, based on the value given
 to the constructor's keyword `bias=bias`.
 
 * `bias == true` creates a zero vector, of the same type as weights.
-* `bias == false` returns `Zeros()`, a special struct which exists only to encode the absence of bias.
+* `bias == false` returns `false`, to indicate no trainable bias.
 * `bias::AbstractArray` uses the array provided, provided it has the correct size and eltype. If the type is wrong, it will be converted.
 """
 function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
-  bias ? fill!(similar(weights, dims...), 0) : Zeros()
+  bias ? fill!(similar(weights, dims...), 0) : false
 end
 function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
   size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
diff --git a/src/zeros.jl b/src/zeros.jl
deleted file mode 100644
index 1281f4c87a..0000000000
--- a/src/zeros.jl
+++ /dev/null
@@ -1,52 +0,0 @@
-import Base: +, -, *,/, reshape, broadcasted
-
-"""
-    Zeros()
-
-Acts as a stand-in for an array of zeros that can be
-used during training which is ignored by the optimisers.
-
-Useful to turn bias off for a forward pass of a layer.
-
-## Examples
-
-```julia-repl
-julia> bias_less_conv = Conv((2,2), 1=>3; bias = false)
-Conv((2, 2), 1=>3)
-
-julia> params(bias_less_conv) |> length
-1
-
-julia> bias_less_conv.bias
-Flux.Zeros()
-```
-"""
-struct Zeros end
-# To allow for things like Dense(10, 2, initb = Zeros)
-Zeros(args...) = Zeros()
-
-Base.reshape(x::Zeros, dims...) = x
-
-+(::Zeros, b::AbstractArray) = b
-+(a::AbstractArray, ::Zeros) = a
-+(a::Zeros, ::Zeros) = a
-
--(::Zeros, b::AbstractArray) = -b
--(a::AbstractArray, ::Zeros) = a
--(a::Zeros, ::Zeros) = a
-
-# Some opportunities to avoid scalar indexing, intermediaries
-# Since it replicates a little of what we expect Base to do,
-# it should be possible to remove in the future, but for now,
-# these help with performance.
-broadcasted(::typeof(+), a::AbstractArray, b::Zeros) = a
-broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = b
-broadcasted(::typeof(-), a::AbstractArray, b::Zeros) = a
-broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = -b
-# Need adjoints for these or else the gradient w.r.t to the non-Zeros arg will be nothing as well
-@adjoint broadcasted(::typeof(*), a::AbstractArray, b::Zeros) = zero(a), _ -> (nothing, zero(a), nothing)
-@adjoint broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = zero(b), _ -> (nothing, nothing, zero(b))
-@adjoint broadcasted(::typeof(/), a::Zeros, b::AbstractArray) = zero(b), _ -> (nothing, nothing, zero(b))
-
-# Pass-through for layer constructors
-create_bias(weights::AbstractArray, bias::Flux.Zeros, dims::Integer...) = bias
diff --git a/test/optimise.jl b/test/optimise.jl
index 04cbf6f6c0..63ab91d58b 100644
--- a/test/optimise.jl
+++ b/test/optimise.jl
@@ -14,7 +14,7 @@ using Random
       Nesterov(), RMSProp(), Momentum()]
     Random.seed!(42)
     w′ = randn(10, 10)
-    b = Flux.Zeros()
+    b = false
    loss(x) = Flux.Losses.mse(w*x, w′*x .+ b)
     for t = 1: 10^5
       θ = params([w′, b])
diff --git a/test/utils.jl b/test/utils.jl
index 65b042c0de..a539858745 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -187,88 +187,39 @@ end
   @test eltype(f32(f64(m))[1].W) == Float32
 end
 
-@testset "Zeros" begin
+@testset "Without bias" begin
   m = Dense(3,2; bias=false)
+  @test f64(m).b === m.b === false === Zeros() # Zeros() is deprecated
+  @test f32(m).b === m.b === false
 
   @testset "Gradients for broadcasted $op with sizes $s" for op in (+,-,*), s in ((1,), (2,3))
     o = ones(s)
     z = zeros(s)
-    Z = Zeros()
 
     @testset "Explicit" begin
       gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...)
       g = gfun(o, z)
-      @test gfun(o, Z) == (g[1], nothing)
+      @test gfun(o, false) == (g[1], nothing)
 
       g = gfun(z, o)
-      @test gfun(Z, o) == (nothing, g[2])
+      @test gfun(false, o) == (nothing, g[2])
     end
 
     @testset "Implicit" begin
      gfun(args...) = gradient(() -> sum(op.(args...)), params(collect(args)))
      g = gfun(o, z)
-      gres = gfun(o, Z)
+      gres = gfun(o, false)
       @test gres[o] == g[o]
-      @test Z ∉ gres.params
+      @test false ∉ gres.params
+      @test length(gres.params) == 1
 
       g = gfun(z, o)
-      gres = gfun(Z, o)
-      @test gres[o] == g[o]
-      @test Z ∉ gres.params
-    end
-  end
-
-  @testset "Gradients for broadcasted / with sizes $s" for s in ((1,), (2,3))
-    o = ones(s)
-    z = zeros(s)
-    Z = Zeros() # Only defined for 0-dim
-
-    @testset "Explicit" begin
-      gfun(args...) = gradient((x, y) -> sum(x ./ y), args...)
-      g = gfun(z, o)
-      @test gfun(Z, o) == (nothing, g[2])
-    end
-
-    @testset "Implicit" begin
-      gfun(x,y) = gradient(() -> sum(x ./ y), params([x,y]))
-
-      g = gfun(z, o)
-      gres = gfun(Z, o)
-      @test gres[o] == g[o]
-      @test Z ∉ gres.params
-    end
-  end
-
-  @testset "Gradients for $op with sizes $s" for op in (+,-), s in (tuple(), (1,), (2,3))
-    o = ones(s)
-    z = zeros(s)
-    Z = Zeros()
-
-
-    @testset "Explicit" begin
-      gfun(args...) = gradient((x, y) -> sum(op(x,y)), args...)
-
-      g = gfun(o, z)
-      @test gfun(o, Z) == (g[1], nothing)
-
-      g = gfun(z, o)
-      @test gfun(Z, o) == (nothing, g[2])
-    end
-    @testset "Implicit" begin
-      gfun(args...) = gradient(() -> sum(op(args...)), params(collect(args)))
-      g = gfun(o, z)
-      gres = gfun(o, Z)
+      gres = gfun(false, o)
       @test gres[o] == g[o]
-      @test Z ∉ gres.params
-
-      g = gfun(z, o)
-      gres = gfun(Z, o)
-      @test gres[o] == g[o]
-      @test Z ∉ gres.params
+      @test false ∉ gres.params
+      @test length(gres.params) == 1
     end
   end
 end
@@ -281,52 +232,53 @@ end
   @test stack(unstack(stacked_array, 1), 1) == stacked_array
 end
 
+
 @testset "Param remapping" begin
-  ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) # accepts dims in reverse order to Dense
-  dl(nin, nout, bias) = Dense(ls(nout, nin), bias(nout))
-  dm(bias) = Chain(
-    dl(3, 5, bias),
-    dl(5, 4, bias),
-    dl(4, 3, bias)
+  count32(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) # accepts dims in reverse order to Dense
+  dl(nin, nout, bt) = Dense(count32(nout, nin), bt(nout)) # this accepts dims in same order as Dense
+  densechain(bt) = Chain(
+    dl(3, 5, bt),
+    dl(5, 4, bt),
+    dl(4, 3, bt)
   )
+  nobias(n) = false
 
-  nobias(n) = Zeros()
-  testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
-    @test l1.W == l2.W
-    @test l1.b == l2.b
-    @test typeof(l1.b) === typeof(l2.b)
+  testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, densechain(bt)))
+    @test l1.weight == l2.weight
+    @test l1.bias == l2.bias
+    @test typeof(l1.bias) === typeof(l2.bias)
   end
 
   @testset "loadparams!" begin
-    import Flux: loadparams!
     pars(w, b) = [w, b]
     import Flux: loadparams!, Zeros
     pars(w, b::Zeros) = [w, Flux.zeros(size(w,1))]
     pars(l) = pars(l.W, l.b)
     pararray(m) = mapreduce(pars, vcat, m)
     weights(m) = mapreduce(l -> [l.W], vcat, m)
-    @testset "Bias type $bt" for bt in (Flux.zeros, nobias)
-      m = dm(bt)
+    @testset "Bias type $bt" for bt in (zeros, nobias)
+      m = densechain(bt)
       loadparams!(m, params(m))
       testdense(m, bt)
    end
-
+    #=
    @testset "$b1 to $b2" for (b1, b2, be) in (
      (Flux.zeros, ones, ones),   # Load ones as bias to a model with zeros as bias -> model gets ones as bias
      (ones, nobias, Flux.zeros), # Load Zeros as bias to a model with ones as bias-> model gets zeros as bias
      (nobias, ones, nobias),     # Load ones as bias to a model with Zeros as bias-> model bias does not change
    )
-      m1 = dm(b1)
-      m2 = dm(b2)
+      m1 = densechain(b1)
+      m2 = densechain(b2)
      loadparams!(m1, b1 == nobias ? weights(m2) : pararray(m2))
      testdense(m1, be)
    end
+    =#
  end

  @testset "destructure" begin
    import Flux: destructure
    @testset "Bias type $bt" for bt in (zeros, nobias)
-      m = dm(bt)
+      m = densechain(bt)
      p, re = destructure(m)
      testdense(re(p), bt)
    end
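
Usage sketch (not part of the patch): after this change, a layer built with `bias=false` simply stores `false` in its bias field, and `params`/the optimisers ignore it. The snippet below is illustrative only, assuming a Flux build with this patch applied; the variable names `d`, `c`, and `x` are invented for the example.

```julia
using Flux

d = Dense(3, 2; bias=false)
d.bias === false              # the bias is literally `false`, not a Zeros() struct
length(Flux.params(d)) == 1   # only the weight matrix is trainable

c = Conv((3, 3), 1 => 8; bias=false)
x = rand(Float32, 28, 28, 1, 1)
size(c(x))                    # (26, 26, 8, 1); conv_reshape_bias(c) just passes `false` through
```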