Skip to content

Commit

Permalink
Merge pull request #151 from theabhirath/conv_bn
Browse files Browse the repository at this point in the history
Improved time to first gradient
  • Loading branch information
ToucheSir authored May 1, 2022
2 parents e88e478 + 9f5295a commit 792076f
Show file tree
Hide file tree
Showing 14 changed files with 138 additions and 135 deletions.
8 changes: 4 additions & 4 deletions src/convnets/convmixer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ Creates a ConvMixer model.
function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
                   patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000)
    # Patch embedding stem: a conv_bn block whose stride equals the patch size,
    # so each patch is embedded independently.
    stem = conv_bn(patch_size, inchannels, planes, activation; preact = true, stride = patch_size[1])
    # Each ConvMixer block: depthwise conv (groups = planes) inside a residual
    # connection, followed by a pointwise (1x1) conv. `conv_bn` returns a Chain,
    # so no splatting is needed.
    blocks = [Chain(SkipConnection(conv_bn(kernel_size, planes, planes, activation;
                                           preact = true, groups = planes, pad = SamePad()), +),
                    conv_bn((1, 1), planes, planes, activation; preact = true)) for _ in 1:depth]
    # Classification head: global average pool -> flatten -> dense.
    head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
    return Chain(Chain(stem, Chain(blocks)), head)
end

convmixer_config = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9),
Expand Down
4 changes: 2 additions & 2 deletions src/convnets/convnext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Creates a single block of ConvNeXt.
- `λ`: Init value for LayerScale
"""
function convnextblock(planes, drop_path_rate = 0., λ = 1f-6)
layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
swapdims((3, 1, 2, 4)),
LayerNorm(planes; ϵ = 1f-6),
mlp_block(planes, 4 * planes),
Expand Down Expand Up @@ -61,7 +61,7 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0., λ = 1f-6
LayerNorm(planes[end]),
Dense(planes[end], nclasses))

return Chain(Chain(backbone...), head)
return Chain(Chain(backbone), head)
end

# Configurations for ConvNeXt models
Expand Down
18 changes: 9 additions & 9 deletions src/convnets/densenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ Create a Densenet bottleneck layer
"""
function dense_bottleneck(inplanes, outplanes)
    # DenseNet bottleneck widens to 4x the output width before the 3x3 conv.
    inner_channels = 4 * outplanes
    # `rev = true` places BatchNorm/activation before the convolution
    # (pre-activation ordering, as in the DenseNet paper).
    m = Chain(conv_bn((1, 1), inplanes, inner_channels; bias = false, rev = true),
              conv_bn((3, 3), inner_channels, outplanes; pad = 1, bias = false, rev = true))

    # Dense connectivity: concatenate the block input with its output
    # along the channel dimension.
    SkipConnection(m, cat_channels)
end

"""
Expand All @@ -28,8 +28,7 @@ Create a DenseNet transition sequence
- `outplanes`: number of output feature maps
"""
# DenseNet transition layer: 1x1 conv_bn (pre-activation, `rev = true`) to
# reduce channels, then 2x2 mean pooling to halve spatial resolution.
transition(inplanes, outplanes) =
    Chain(conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true), MeanPool((2, 2)))

"""
dense_block(inplanes, growth_rates)
Expand Down Expand Up @@ -60,20 +59,21 @@ Create a DenseNet model
- `nclasses`: the number of output classes
"""
function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000)
layers = conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false)
layers = []
push!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false))
push!(layers, MaxPool((3, 3), stride = 2, pad = (1, 1)))

outplanes = 0
for (i, rates) in enumerate(growth_rates)
outplanes = inplanes + sum(rates)
append!(layers, dense_block(inplanes, rates))
(i != length(growth_rates)) &&
append!(layers, transition(outplanes, floor(Int, outplanes * reduction)))
(i != length(growth_rates)) &&
push!(layers, transition(outplanes, floor(Int, outplanes * reduction)))
inplanes = floor(Int, outplanes * reduction)
end
push!(layers, BatchNorm(outplanes, relu))

return Chain(Chain(layers...),
return Chain(Chain(layers),
Chain(AdaptiveMeanPool((1, 1)),
MLUtils.flatten,
Dense(outplanes, nclasses)))
Expand Down
8 changes: 4 additions & 4 deletions src/convnets/googlenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ Create an inception module for use in GoogLeNet
"""
function _inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj)
    # Plain 1x1 convolution branch.
    b1 = Chain(Conv((1, 1), inplanes => out_1x1))

    # 1x1 channel reduction followed by a 3x3 convolution.
    b2 = Chain(Conv((1, 1), inplanes => red_3x3),
               Conv((3, 3), red_3x3 => out_3x3; pad = 1))

    # 1x1 channel reduction followed by a 5x5 convolution.
    b3 = Chain(Conv((1, 1), inplanes => red_5x5),
               Conv((5, 5), red_5x5 => out_5x5; pad = 2))

    # 3x3 max pooling (stride 1, padded) followed by a 1x1 projection.
    b4 = Chain(MaxPool((3, 3), stride = 1, pad = 1),
               Conv((1, 1), inplanes => pool_proj))

    # Run all four branches on the same input and concatenate channels.
    return Parallel(cat_channels, b1, b2, b3, b4)
end
Expand Down
92 changes: 46 additions & 46 deletions src/convnets/inception.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@ Create an Inception-v3 style-A module
- `pool_proj`: the number of output feature maps for the pooling projection
"""
function inception_a(inplanes, pool_proj)
branch1x1 = Chain(conv_bn((1, 1), inplanes, 64)...)

branch5x5 = Chain(conv_bn((1, 1), inplanes, 48)...,
conv_bn((5, 5), 48, 64; pad = 2)...)
branch1x1 = conv_bn((1, 1), inplanes, 64)

branch3x3 = Chain(conv_bn((1, 1), inplanes, 64)...,
conv_bn((3, 3), 64, 96; pad = 1)...,
conv_bn((3, 3), 96, 96; pad = 1)...)
branch5x5 = Chain(conv_bn((1, 1), inplanes, 48),
conv_bn((5, 5), 48, 64; pad = 2))

branch3x3 = Chain(conv_bn((1, 1), inplanes, 64),
conv_bn((3, 3), 64, 96; pad = 1),
conv_bn((3, 3), 96, 96; pad = 1))

branch_pool = Chain(MeanPool((3, 3), pad = 1, stride = 1),
conv_bn((1, 1), inplanes, pool_proj)...)
conv_bn((1, 1), inplanes, pool_proj))

return Parallel(cat_channels,
branch1x1, branch5x5, branch3x3, branch_pool)
Expand All @@ -35,13 +35,13 @@ Create an Inception-v3 style-B module
- `inplanes`: number of input feature maps
"""
function inception_b(inplanes)
branch3x3_1 = Chain(conv_bn((3, 3), inplanes, 384; stride = 2)...)
branch3x3_1 = conv_bn((3, 3), inplanes, 384; stride = 2)

branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 64)...,
conv_bn((3, 3), 64, 96; pad = 1)...,
conv_bn((3, 3), 96, 96; stride = 2)...)
branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 64),
conv_bn((3, 3), 64, 96; pad = 1),
conv_bn((3, 3), 96, 96; stride = 2))

branch_pool = Chain(MaxPool((3, 3), stride = 2))
branch_pool = MaxPool((3, 3), stride = 2)

return Parallel(cat_channels,
branch3x3_1, branch3x3_2, branch_pool)
Expand All @@ -59,20 +59,20 @@ Create an Inception-v3 style-C module
- `n`: the "grid size" (kernel size) for the convolution layers
"""
function inception_c(inplanes, inner_planes, n = 7)
branch1x1 = Chain(conv_bn((1, 1), inplanes, 192)...)
branch1x1 = conv_bn((1, 1), inplanes, 192)

branch7x7_1 = Chain(conv_bn((1, 1), inplanes, inner_planes)...,
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))...,
conv_bn((n, 1), inner_planes, 192; pad = (3, 0))...)
branch7x7_1 = Chain(conv_bn((1, 1), inplanes, inner_planes),
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3)),
conv_bn((n, 1), inner_planes, 192; pad = (3, 0)))

branch7x7_2 = Chain(conv_bn((1, 1), inplanes, inner_planes)...,
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))...,
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))...,
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))...,
conv_bn((1, n), inner_planes, 192; pad = (0, 3))...)
branch7x7_2 = Chain(conv_bn((1, 1), inplanes, inner_planes),
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0)),
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3)),
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0)),
conv_bn((1, n), inner_planes, 192; pad = (0, 3)))

branch_pool = Chain(MeanPool((3, 3), pad = 1, stride=1),
conv_bn((1, 1), inplanes, 192)...)
branch_pool = Chain(MeanPool((3, 3), pad = 1, stride=1),
conv_bn((1, 1), inplanes, 192))

return Parallel(cat_channels,
branch1x1, branch7x7_1, branch7x7_2, branch_pool)
Expand All @@ -88,15 +88,15 @@ Create an Inception-v3 style-D module
- `inplanes`: number of input feature maps
"""
function inception_d(inplanes)
branch3x3 = Chain(conv_bn((1, 1), inplanes, 192)...,
conv_bn((3, 3), 192, 320; stride = 2)...)
branch3x3 = Chain(conv_bn((1, 1), inplanes, 192),
conv_bn((3, 3), 192, 320; stride = 2))

branch7x7x3 = Chain(conv_bn((1, 1), inplanes, 192)...,
conv_bn((1, 7), 192, 192; pad = (0, 3))...,
conv_bn((7, 1), 192, 192; pad = (3, 0))...,
conv_bn((3, 3), 192, 192; stride = 2)...)
branch7x7x3 = Chain(conv_bn((1, 1), inplanes, 192),
conv_bn((1, 7), 192, 192; pad = (0, 3)),
conv_bn((7, 1), 192, 192; pad = (3, 0)),
conv_bn((3, 3), 192, 192; stride = 2))

branch_pool = Chain(MaxPool((3, 3), stride=2))
branch_pool = MaxPool((3, 3), stride=2)

return Parallel(cat_channels,
branch3x3, branch7x7x3, branch_pool)
Expand All @@ -112,26 +112,26 @@ Create an Inception-v3 style-E module
- `inplanes`: number of input feature maps
"""
function inception_e(inplanes)
branch1x1 = Chain(conv_bn((1, 1), inplanes, 320)...)
branch1x1 = conv_bn((1, 1), inplanes, 320)

branch3x3_1 = Chain(conv_bn((1, 1), inplanes, 384)...)
branch3x3_1a = Chain(conv_bn((1, 3), 384, 384; pad = (0, 1))...)
branch3x3_1b = Chain(conv_bn((3, 1), 384, 384; pad = (1, 0))...)
branch3x3_1 = conv_bn((1, 1), inplanes, 384)
branch3x3_1a = conv_bn((1, 3), 384, 384; pad = (0, 1))
branch3x3_1b = conv_bn((3, 1), 384, 384; pad = (1, 0))

branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 448)...,
conv_bn((3, 3), 448, 384; pad = 1)...)
branch3x3_2a = Chain(conv_bn((1, 3), 384, 384; pad = (0, 1))...)
branch3x3_2b = Chain(conv_bn((3, 1), 384, 384; pad = (1, 0))...)
branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 448),
conv_bn((3, 3), 448, 384; pad = 1))
branch3x3_2a = conv_bn((1, 3), 384, 384; pad = (0, 1))
branch3x3_2b = conv_bn((3, 1), 384, 384; pad = (1, 0))

branch_pool = Chain(MeanPool((3, 3), pad = 1, stride = 1),
conv_bn((1, 1), inplanes, 192)...)
conv_bn((1, 1), inplanes, 192))

return Parallel(cat_channels,
branch1x1,
Chain(branch3x3_1,
Parallel(cat_channels,
branch3x3_1a, branch3x3_1b)),

Chain(branch3x3_2,
Parallel(cat_channels,
branch3x3_2a, branch3x3_2b)),
Expand All @@ -150,12 +150,12 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)).
`inception3` does not currently support pretrained weights.
"""
function inception3(; nclasses = 1000)
layer = Chain(Chain(conv_bn((3, 3), 3, 32; stride = 2)...,
conv_bn((3, 3), 32, 32)...,
conv_bn((3, 3), 32, 64; pad = 1)...,
layer = Chain(Chain(conv_bn((3, 3), 3, 32; stride = 2),
conv_bn((3, 3), 32, 32),
conv_bn((3, 3), 32, 64; pad = 1),
MaxPool((3, 3), stride = 2),
conv_bn((1, 1), 64, 80)...,
conv_bn((3, 3), 80, 192)...,
conv_bn((1, 1), 64, 80),
conv_bn((3, 3), 80, 192),
MaxPool((3, 3), stride = 2),
inception_a(192, 32),
inception_a(256, 64),
Expand Down
34 changes: 15 additions & 19 deletions src/convnets/mobilenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,15 @@ function mobilenetv1(width_mult, config;
for (dw, outch, stride, repeats) in config
outch = Int(outch * width_mult)
for _ in 1:repeats
layer = if dw
depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1)
else
conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1)
end
append!(layers, layer)
layer = dw ? depthwise_sep_conv_bn((3, 3), inchannels, outch, activation;
stride = stride, pad = 1) :
conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1)
push!(layers, layer)
inchannels = outch
end
end

return Chain(Chain(layers...),
return Chain(Chain(layers),
Chain(GlobalMeanPool(),
MLUtils.flatten,
Dense(inchannels, fcsize, activation),
Expand Down Expand Up @@ -120,7 +118,7 @@ function mobilenetv2(width_mult, configs; max_width = 1280, nclasses = 1000)
# building first layer
inplanes = _round_channels(32 * width_mult, width_mult == 0.1 ? 4 : 8)
layers = []
append!(layers, conv_bn((3, 3), 3, inplanes, stride = 2))
push!(layers, conv_bn((3, 3), 3, inplanes, stride = 2))

# building inverted residual blocks
for (t, c, n, s, a) in configs
Expand All @@ -136,8 +134,7 @@ function mobilenetv2(width_mult, configs; max_width = 1280, nclasses = 1000)
outplanes = (width_mult > 1) ? _round_channels(max_width * width_mult, width_mult == 0.1 ? 4 : 8) :
max_width

return Chain(Chain(layers...,
conv_bn((1, 1), inplanes, outplanes, relu6, bias = false)...),
return Chain(Chain(Chain(layers), conv_bn((1, 1), inplanes, outplanes, relu6, bias = false)),
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(outplanes, nclasses)))
end

Expand Down Expand Up @@ -186,7 +183,7 @@ end
(m::MobileNetv2)(x) = m.layers(x)

backbone(m::MobileNetv2) = m.layers[1]
# The classifier is the second (and last) stage of the top-level Chain;
# indexing with [2] returns it directly rather than as a one-element Chain.
classifier(m::MobileNetv2) = m.layers[2]

# MobileNetv3

Expand Down Expand Up @@ -214,7 +211,7 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000)
# building first layer
inplanes = _round_channels(16 * width_mult, 8)
layers = []
append!(layers, conv_bn((3, 3), 3, inplanes, hardswish; stride = 2))
push!(layers, conv_bn((3, 3), 3, inplanes, hardswish; stride = 2))
explanes = 0
# building inverted residual blocks
for (k, t, c, r, a, s) in configs
Expand All @@ -229,13 +226,12 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000)
# building last several layers
output_channel = max_width
output_channel = width_mult > 1.0 ? _round_channels(output_channel * width_mult, 8) : output_channel
classifier = (Dense(explanes, output_channel, hardswish),
Dropout(0.2),
Dense(output_channel, nclasses))
classifier = Chain(Dense(explanes, output_channel, hardswish),
Dropout(0.2),
Dense(output_channel, nclasses))

return Chain(Chain(layers...,
conv_bn((1, 1), inplanes, explanes, hardswish, bias = false)...),
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier...))
return Chain(Chain(Chain(layers), conv_bn((1, 1), inplanes, explanes, hardswish, bias = false)),
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier))
end

# Configurations for small and large mode for MobileNetv3
Expand Down Expand Up @@ -310,4 +306,4 @@ end
(m::MobileNetv3)(x) = m.layers(x)

backbone(m::MobileNetv3) = m.layers[1]
# The classifier is the second (and last) stage of the top-level Chain;
# indexing with [2] returns it directly rather than as a one-element Chain.
classifier(m::MobileNetv3) = m.layers[2]
16 changes: 8 additions & 8 deletions src/convnets/resnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ Create a basic residual block
"""
function basicblock(inplanes, outplanes, downsample = false)
    # Downsampling blocks halve the spatial resolution in the first conv.
    stride = downsample ? 2 : 1
    # Two 3x3 conv_bn stages; the second uses `identity` so the residual
    # addition happens before the final activation. `conv_bn` returns a
    # Chain, so the sub-chains are nested rather than splatted.
    Chain(conv_bn((3, 3), inplanes, outplanes[1]; stride = stride, pad = 1, bias = false),
          conv_bn((3, 3), outplanes[1], outplanes[2], identity; stride = 1, pad = 1, bias = false))
end

"""
Expand All @@ -36,9 +36,9 @@ The original paper uses `stride == [2, 1, 1]` when `downsample == true` instead.
"""
function bottleneck(inplanes, outplanes, downsample = false;
                    stride = [1, (downsample ? 2 : 1), 1])
    # 1x1 reduce -> 3x3 -> 1x1 expand; per-stage strides come from `stride`
    # (default places the downsampling stride on the 3x3 conv). The final
    # stage uses `identity` so activation follows the residual addition.
    Chain(conv_bn((1, 1), inplanes, outplanes[1]; stride = stride[1], bias = false),
          conv_bn((3, 3), outplanes[1], outplanes[2]; stride = stride[2], pad = 1, bias = false),
          conv_bn((1, 1), outplanes[2], outplanes[3], identity; stride = stride[3], bias = false))
end


Expand Down Expand Up @@ -82,7 +82,7 @@ function resnet(block, residuals::AbstractVector{<:NTuple{2, Any}}, connection =
inplanes = 64
baseplanes = 64
layers = []
append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = 3, bias = false))
push!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = 3, bias = false))
push!(layers, MaxPool((3, 3), stride = (2, 2), pad = (1, 1)))
for (i, nrepeats) in enumerate(block_config)
# output planes within a block
Expand All @@ -102,7 +102,7 @@ function resnet(block, residuals::AbstractVector{<:NTuple{2, Any}}, connection =
baseplanes *= 2
end

return Chain(Chain(layers...),
return Chain(Chain(layers),
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(inplanes, nclasses)))
end

Expand Down Expand Up @@ -246,7 +246,7 @@ function ResNet(depth::Int = 50; pretrain = false, nclasses = 1000)
model
end

# Compat with Metalhead 0.6; remove in 0.7
@deprecate ResNet18(; kw...) ResNet(18; kw...)
@deprecate ResNet34(; kw...) ResNet(34; kw...)
@deprecate ResNet50(; kw...) ResNet(50; kw...)
Expand Down
Loading

0 comments on commit 792076f

Please sign in to comment.