From 79dcc993a25c9feaf9deefcb437d54c507719bbd Mon Sep 17 00:00:00 2001 From: RafaelT00 <92472445+RafaelT00@users.noreply.github.com> Date: Fri, 3 Nov 2023 05:44:35 +0100 Subject: [PATCH 01/12] Create shufflenet.jl ShuffleNet model --- src/convnets/shufflenet.jl | 160 +++++++++++++++++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 src/convnets/shufflenet.jl diff --git a/src/convnets/shufflenet.jl b/src/convnets/shufflenet.jl new file mode 100644 index 00000000..d118f96a --- /dev/null +++ b/src/convnets/shufflenet.jl @@ -0,0 +1,160 @@ +using Flux + +""" +Channelshuffle(channels, groups) + +Channel shuffle operation from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices +([reference](https://arxiv.org/abs/1707.01083)). + +# Arguments + + - `channels`: number of channels + - `groups`: number of groups +""" +function ChannelShuffle(x::Array{Float32, 4}, g::Int) + width, height, channels, batch = size(x) + channels_per_group = channels÷g + if (channels % g) == 0 + x = reshape(x, (width, height, g, channels_per_group, batch)) + x = permutedims(x,(1,2,4,3,5)) + x = reshape(x, (width, height, channels, batch)) + end + return x +end + +""" +ShuffleUnit(in_channels::Integer, out_channels::Integer, grps::Integer, downsample::Bool, ignore_group::Bool) + +Shuffle Unit from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices +([reference](https://arxiv.org/abs/1707.01083)). + +# Arguments + + - `in_channels`: number of input channels + - `out_channels`: number of output channels + - `groups`: number of groups + - `downsample`: apply downsaple if true + - `ignore_group`: ignore group convolution if true +""" +function ShuffleUnit(in_channels::Integer, out_channels::Integer, groups::Integer, downsample::Bool, ignore_group::Bool) + mid_channels = out_channels ÷ 4 + groups = ignore_group ? 1 : groups + strd = downsample ? 2 : 1 + + if downsample + out_channels -= in_channels + end + + m = Chain(Conv((1,1), in_channels => mid_channels; groups,pad=SamePad()), + BatchNorm(mid_channels), + NNlib.relu, + x -> ChannelShuffle(x, groups), + DepthwiseConv((3,3), mid_channels => mid_channels; bias=false, stride=strd, pad=SamePad()), + BatchNorm(mid_channels), + NNlib.relu, + Conv((1,1), mid_channels => out_channels; groups, pad=SamePad()), + BatchNorm(out_channels), + NNlib.relu) + + if downsample + m = Parallel((mx, x) -> cat(mx, x, dims=3),m, MeanPool((3,3); pad=SamePad(), stride=2)) + else + m = SkipConnection(m, +) + end + return m +end + +""" +ShuffleNet(channels, init_block_channels::Integer, groups, num_classes; in_channels=3) + +ShuffleNet model from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices +([reference](https://arxiv.org/abs/1707.01083)). + +# Arguments + + - `channels`: list of channels per layer + - `init_block_channels`: number of output channels from the first layer + - `groups`: number of groups + - `num_classes`: number of classes + - `in_channels`: number of input channels +""" +function ShuffleNet(channels, init_block_channels::Integer, groups, num_classes; in_channels=3) + features = [] + + append!(features, [Conv((3,3), in_channels => init_block_channels; stride=2, pad=SamePad()), + BatchNorm(init_block_channels), + NNlib.relu, + MaxPool((3,3); stride=2, pad=SamePad())]) + + in_channels::Integer = init_block_channels + + for (i, num_channels) in enumerate(channels) + stage = [] + for (j, out_channels) in enumerate(num_channels) + downsample = j==1 + ignore_group = i==1 && j==1 + out_ch::Integer = trunc(out_channels) + push!(stage, ShuffleUnit(in_channels, out_ch, groups, downsample, ignore_group)) + in_channels = out_ch + end + append!(features, stage) + end + + model = Chain(features...) + + return Chain(model, GlobalMeanPool(), Flux.flatten, Dense(in_channels => num_classes)) +end + +""" +shufflenet(groups, width_scale, num_classes; in_channels=3) + +Wrapper for ShuffleNet. Create a ShuffleNet model from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices +([reference](https://arxiv.org/abs/1707.01083)). + +# Arguments + + - `groups`: number of groups + - `width_scale`: scaling factor for number of channels + - `num_classes`: number of classes + - `in_channels`: number of input channels +""" +function shufflenet(groups, width_scale, num_classes; in_channels=3) + init_block_channels = 24 + layers = [4, 8, 4] + + if groups == 1 + channels_per_layers = [144, 288, 576] + elseif groups == 2 + channels_per_layers = [200, 400, 800] + elseif groups == 3 + channels_per_layers = [240, 480, 960] + elseif groups == 4 + channels_per_layers = [272, 544, 1088] + elseif groups == 8 + channels_per_layers = [384, 768, 1536] + else + return error("The number of groups is not supported. Groups = ", groups) + end + + channels = [] + for i in eachindex(layers) + char = [channels_per_layers[i]] + new = repeat(char, layers[i]) + push!(channels, new) + end + + if width_scale != 1.0 + channels = channels*width_scale + + init_block_channels::Integer = trunc(init_block_channels * width_scale) + end + + net = ShuffleNet( + channels, + init_block_channels, + groups; + in_channels=in_channels, + num_classes=num_classes) + + return net +end From 08fb6b98b99a74fd59affb85207e81bde75b5b3a Mon Sep 17 00:00:00 2001 From: RafaelT00 <92472445+RafaelT00@users.noreply.github.com> Date: Fri, 3 Nov 2023 05:47:41 +0100 Subject: [PATCH 02/12] Update shufflenet.jl --- src/convnets/shufflenet.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/convnets/shufflenet.jl b/src/convnets/shufflenet.jl index d118f96a..0232268b 100644 --- a/src/convnets/shufflenet.jl +++ b/src/convnets/shufflenet.jl @@ -153,8 +153,8 @@ function shufflenet(groups, width_scale, num_classes; in_channels=3) channels, init_block_channels, groups; - in_channels=in_channels, - num_classes=num_classes) + in_channels, + num_classes) return net end From 9d01bf272063b871b6be53b41118945ab2383669 Mon Sep 17 00:00:00 2001 From: RafaelT00 Date: Tue, 7 May 2024 01:16:01 +0200 Subject: [PATCH 03/12] applied julia formatter --- src/convnets/shufflenet.jl | 63 ++++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/src/convnets/shufflenet.jl b/src/convnets/shufflenet.jl index 0232268b..6b91deeb 100644 --- a/src/convnets/shufflenet.jl +++ b/src/convnets/shufflenet.jl @@ -13,10 +13,10 @@ Channel shuffle operation from 'ShuffleNet: An Extremely Efficient Convolutional """ function ChannelShuffle(x::Array{Float32, 4}, g::Int) width, height, channels, batch = size(x) - channels_per_group = channels÷g + channels_per_group = channels ÷ g if (channels % g) == 0 x = reshape(x, (width, height, g, channels_per_group, batch)) - x = permutedims(x,(1,2,4,3,5)) + x = permutedims(x, (1, 2, 4, 3, 5)) x = reshape(x, (width, height, channels, batch)) end return x @@ -36,7 +36,8 @@ Shuffle Unit from 'ShuffleNet: An Extremely Efficient Convolutional Neural Netwo - `downsample`: apply downsaple if true - `ignore_group`: ignore group convolution if true """ -function ShuffleUnit(in_channels::Integer, out_channels::Integer, groups::Integer, downsample::Bool, ignore_group::Bool) +function ShuffleUnit(in_channels::Integer, out_channels::Integer, + groups::Integer, downsample::Bool, ignore_group::Bool) mid_channels = out_channels ÷ 4 groups = ignore_group ? 1 : groups strd = downsample ? 2 : 1 @@ -45,19 +46,21 @@ function ShuffleUnit(in_channels::Integer, out_channels::Integer, groups::Intege out_channels -= in_channels end - m = Chain(Conv((1,1), in_channels => mid_channels; groups,pad=SamePad()), - BatchNorm(mid_channels), - NNlib.relu, - x -> ChannelShuffle(x, groups), - DepthwiseConv((3,3), mid_channels => mid_channels; bias=false, stride=strd, pad=SamePad()), - BatchNorm(mid_channels), - NNlib.relu, - Conv((1,1), mid_channels => out_channels; groups, pad=SamePad()), - BatchNorm(out_channels), - NNlib.relu) - + m = Chain(Conv((1, 1), in_channels => mid_channels; groups, pad = SamePad()), + BatchNorm(mid_channels), + NNlib.relu, + x -> ChannelShuffle(x, groups), + DepthwiseConv((3, 3), mid_channels => mid_channels; + bias = false, stride = strd, pad = SamePad()), + BatchNorm(mid_channels), + NNlib.relu, + Conv((1, 1), mid_channels => out_channels; groups, pad = SamePad()), + BatchNorm(out_channels), + NNlib.relu) + if downsample - m = Parallel((mx, x) -> cat(mx, x, dims=3),m, MeanPool((3,3); pad=SamePad(), stride=2)) + m = Parallel((mx, x) -> cat(mx, x; dims = 3), m, + MeanPool((3, 3); pad = SamePad(), stride = 2)) else m = SkipConnection(m, +) end @@ -78,30 +81,32 @@ ShuffleNet model from 'ShuffleNet: An Extremely Efficient Convolutional Neural N - `num_classes`: number of classes - `in_channels`: number of input channels """ -function ShuffleNet(channels, init_block_channels::Integer, groups, num_classes; in_channels=3) +function ShuffleNet( + channels, init_block_channels::Integer, groups, num_classes; in_channels = 3) features = [] - append!(features, [Conv((3,3), in_channels => init_block_channels; stride=2, pad=SamePad()), - BatchNorm(init_block_channels), - NNlib.relu, - MaxPool((3,3); stride=2, pad=SamePad())]) + append!(features, + [Conv((3, 3), in_channels => init_block_channels; stride = 2, pad = SamePad()), + BatchNorm(init_block_channels), + NNlib.relu, + MaxPool((3, 3); stride = 2, pad = SamePad())]) in_channels::Integer = init_block_channels - + for (i, num_channels) in enumerate(channels) stage = [] for (j, out_channels) in enumerate(num_channels) - downsample = j==1 - ignore_group = i==1 && j==1 + downsample = j == 1 + ignore_group = i == 1 && j == 1 out_ch::Integer = trunc(out_channels) push!(stage, ShuffleUnit(in_channels, out_ch, groups, downsample, ignore_group)) in_channels = out_ch end append!(features, stage) end - + model = Chain(features...) - + return Chain(model, GlobalMeanPool(), Flux.flatten, Dense(in_channels => num_classes)) end @@ -118,7 +123,7 @@ Wrapper for ShuffleNet. Create a ShuffleNet model from 'ShuffleNet: An Extremely - `num_classes`: number of classes - `in_channels`: number of input channels """ -function shufflenet(groups, width_scale, num_classes; in_channels=3) +function shufflenet(groups, width_scale, num_classes; in_channels = 3) init_block_channels = 24 layers = [4, 8, 4] @@ -144,7 +149,7 @@ function shufflenet(groups, width_scale, num_classes; in_channels=3) end if width_scale != 1.0 - channels = channels*width_scale + channels = channels * width_scale init_block_channels::Integer = trunc(init_block_channels * width_scale) end @@ -152,8 +157,8 @@ function shufflenet(groups, width_scale, num_classes; in_channels=3) net = ShuffleNet( channels, init_block_channels, - groups; - in_channels, + groups; + in_channels, num_classes) return net From 894ae7a9de9655bb3c516053cd6af5b2eede9d29 Mon Sep 17 00:00:00 2001 From: RafaelT00 Date: Tue, 7 May 2024 03:42:33 +0200 Subject: [PATCH 04/12] deleted () from if --- src/convnets/shufflenet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/convnets/shufflenet.jl b/src/convnets/shufflenet.jl index 6b91deeb..b4ba3d03 100644 --- a/src/convnets/shufflenet.jl +++ b/src/convnets/shufflenet.jl @@ -14,7 +14,7 @@ Channel shuffle operation from 'ShuffleNet: An Extremely Efficient Convolutional function ChannelShuffle(x::Array{Float32, 4}, g::Int) width, height, channels, batch = size(x) channels_per_group = channels ÷ g - if (channels % g) == 0 + if channels % g == 0 x = reshape(x, (width, height, g, channels_per_group, batch)) x = permutedims(x, (1, 2, 4, 3, 5)) x = reshape(x, (width, height, channels, batch)) From 597aa2fc94fdb6d807ae71d914c1af5ee7ae1cfe Mon Sep 17 00:00:00 2001 From: RafaelT00 Date: Tue, 7 May 2024 22:27:28 +0200 Subject: [PATCH 05/12] fused relu into BatchNorm --- src/convnets/shufflenet.jl | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/convnets/shufflenet.jl b/src/convnets/shufflenet.jl index b4ba3d03..be00056a 100644 --- a/src/convnets/shufflenet.jl +++ b/src/convnets/shufflenet.jl @@ -1,7 +1,7 @@ using Flux """ -Channelshuffle(channels, groups) +channel_shuffle(channels, groups) Channel shuffle operation from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices ([reference](https://arxiv.org/abs/1707.01083)). @@ -11,7 +11,7 @@ Channel shuffle operation from 'ShuffleNet: An Extremely Efficient Convolutional - `channels`: number of channels - `groups`: number of groups """ -function ChannelShuffle(x::Array{Float32, 4}, g::Int) +function channel_shuffle(x::AbstractArray{Float32, 4}, g::Int) width, height, channels, batch = size(x) channels_per_group = channels ÷ g if channels % g == 0 @@ -47,16 +47,13 @@ function ShuffleUnit(in_channels::Integer, out_channels::Integer, end m = Chain(Conv((1, 1), in_channels => mid_channels; groups, pad = SamePad()), - BatchNorm(mid_channels), - NNlib.relu, - x -> ChannelShuffle(x, groups), + BatchNorm(mid_channels, NNlib.relu), + x -> channel_shuffle(x, groups), DepthwiseConv((3, 3), mid_channels => mid_channels; bias = false, stride = strd, pad = SamePad()), - BatchNorm(mid_channels), - NNlib.relu, + BatchNorm(mid_channels, NNlib.relu), Conv((1, 1), mid_channels => out_channels; groups, pad = SamePad()), - BatchNorm(out_channels), - NNlib.relu) + BatchNorm(out_channels, NNlib.relu)) if downsample m = Parallel((mx, x) -> cat(mx, x; dims = 3), m, @@ -87,8 +84,7 @@ function ShuffleNet( append!(features, [Conv((3, 3), in_channels => init_block_channels; stride = 2, pad = SamePad()), - BatchNorm(init_block_channels), - NNlib.relu, + BatchNorm(init_block_channels, NNlib.relu), MaxPool((3, 3); stride = 2, pad = SamePad())]) in_channels::Integer = init_block_channels From 873dc51d2fdb9451b3c0a462229ae39b6709063d Mon Sep 17 00:00:00 2001 From: RafaelT00 Date: Tue, 7 May 2024 22:34:08 +0200 Subject: [PATCH 06/12] replaced anonymous functions by fix2 --- src/convnets/shufflenet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/convnets/shufflenet.jl b/src/convnets/shufflenet.jl index be00056a..5f37dd81 100644 --- a/src/convnets/shufflenet.jl +++ b/src/convnets/shufflenet.jl @@ -48,7 +48,7 @@ function ShuffleUnit(in_channels::Integer, out_channels::Integer, m = Chain(Conv((1, 1), in_channels => mid_channels; groups, pad = SamePad()), BatchNorm(mid_channels, NNlib.relu), - x -> channel_shuffle(x, groups), + Base.Fix2(channel_shuffle, groups), DepthwiseConv((3, 3), mid_channels => mid_channels; bias = false, stride = strd, pad = SamePad()), BatchNorm(mid_channels, NNlib.relu), From 9d91f81c4d3e9f69f0eb7b38679a1f22b1da1182 Mon Sep 17 00:00:00 2001 From: RafaelT00 Date: Tue, 7 May 2024 22:37:27 +0200 Subject: [PATCH 07/12] using cat_channels instead of an annonymous function --- src/convnets/shufflenet.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/convnets/shufflenet.jl b/src/convnets/shufflenet.jl index 5f37dd81..68ff5843 100644 --- a/src/convnets/shufflenet.jl +++ b/src/convnets/shufflenet.jl @@ -56,8 +56,7 @@ function ShuffleUnit(in_channels::Integer, out_channels::Integer, BatchNorm(out_channels, NNlib.relu)) if downsample - m = Parallel((mx, x) -> cat(mx, x; dims = 3), m, - MeanPool((3, 3); pad = SamePad(), stride = 2)) + m = Parallel(cat_channels, m, MeanPool((3,3); pad=SamePad(), stride=2)) else m = SkipConnection(m, +) end From f366f779c4bc5c857535996b3f884845fdd990d1 Mon Sep 17 00:00:00 2001 From: RafaelT00 Date: Tue, 7 May 2024 22:40:23 +0200 Subject: [PATCH 08/12] applied JuliaFormatter --- src/convnets/shufflenet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/convnets/shufflenet.jl b/src/convnets/shufflenet.jl index 68ff5843..146a2145 100644 --- a/src/convnets/shufflenet.jl +++ b/src/convnets/shufflenet.jl @@ -102,7 +102,7 @@ function ShuffleNet( model = Chain(features...) - return Chain(model, GlobalMeanPool(), Flux.flatten, Dense(in_channels => num_classes)) + return Chain(model, GlobalMeanPool(), MLUtils.flatten, Dense(in_channels => num_classes)) end """ From 526df7ae426247f3f36448a705ee874fb815e407 Mon Sep 17 00:00:00 2001 From: RafaelT00 Date: Fri, 14 Jun 2024 02:38:11 +0200 Subject: [PATCH 09/12] added test for ShuffleNet --- src/convnets/shufflenet.jl | 13 +++++++------ test/convnet_tests.jl | 11 +++++++++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/convnets/shufflenet.jl b/src/convnets/shufflenet.jl index 146a2145..19799f21 100644 --- a/src/convnets/shufflenet.jl +++ b/src/convnets/shufflenet.jl @@ -1,4 +1,4 @@ -using Flux +using Flux, Metalhead, MLUtils """ channel_shuffle(channels, groups) @@ -56,7 +56,7 @@ function ShuffleUnit(in_channels::Integer, out_channels::Integer, BatchNorm(out_channels, NNlib.relu)) if downsample - m = Parallel(cat_channels, m, MeanPool((3,3); pad=SamePad(), stride=2)) + m = Parallel(Metalhead.cat_channels, m, MeanPool((3, 3); pad = SamePad(), stride = 2)) else m = SkipConnection(m, +) end @@ -102,7 +102,8 @@ function ShuffleNet( model = Chain(features...) - return Chain(model, GlobalMeanPool(), MLUtils.flatten, Dense(in_channels => num_classes)) + return Chain( + model, GlobalMeanPool(), MLUtils.flatten, Dense(in_channels => num_classes)) end """ @@ -152,9 +153,9 @@ function shufflenet(groups, width_scale, num_classes; in_channels = 3) net = ShuffleNet( channels, init_block_channels, - groups; - in_channels, - num_classes) + groups, + num_classes; + in_channels) return net end diff --git a/test/convnet_tests.jl b/test/convnet_tests.jl index 630de19e..010a0860 100644 --- a/test/convnet_tests.jl +++ b/test/convnet_tests.jl @@ -374,3 +374,14 @@ end @test size(model(x_256)) == (256, 256, 3, 1) _gc() end + + +@testitem "ShuffleNet" setup=[TestModels] begin + configs = TEST_FAST ? [(1, 1)] : [(1, 1), (2, 1), (3, 1), (4, 1), (8, 1), (1, 0.75), (3, 0.75), (1, 0.5), (3, 0.5), (1, 0.25), (3, 0.25)] + @testset for (groups, width_scale) in configs + m = shufflenet(groups, width_scale, 1000) |> gpu + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + end +end From 92612181558dd1327479db82b02cd3f33c8c0021 Mon Sep 17 00:00:00 2001 From: RafaelT00 Date: Fri, 14 Jun 2024 06:42:34 +0200 Subject: [PATCH 10/12] created SheffleNet structure, better matching the code style of the rest of the repo --- src/convnets/shufflenet.jl | 61 ++++++++++++++++++++++++++++++-------- test/convnet_tests.jl | 11 +++---- 2 files changed, 54 insertions(+), 18 deletions(-) diff --git a/src/convnets/shufflenet.jl b/src/convnets/shufflenet.jl index 19799f21..428f796e 100644 --- a/src/convnets/shufflenet.jl +++ b/src/convnets/shufflenet.jl @@ -1,4 +1,4 @@ -using Flux, Metalhead, MLUtils +using Flux, Metalhead, MLUtils, Functors """ channel_shuffle(channels, groups) @@ -56,7 +56,8 @@ function ShuffleUnit(in_channels::Integer, out_channels::Integer, BatchNorm(out_channels, NNlib.relu)) if downsample - m = Parallel(Metalhead.cat_channels, m, MeanPool((3, 3); pad = SamePad(), stride = 2)) + m = Parallel( + Metalhead.cat_channels, m, MeanPool((3, 3); pad = SamePad(), stride = 2)) else m = SkipConnection(m, +) end @@ -64,7 +65,7 @@ function ShuffleUnit(in_channels::Integer, out_channels::Integer, end """ -ShuffleNet(channels, init_block_channels::Integer, groups, num_classes; in_channels=3) +create_shufflenet(channels, init_block_channels::Integer, groups, num_classes; in_channels=3) ShuffleNet model from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices ([reference](https://arxiv.org/abs/1707.01083)). @@ -77,8 +78,9 @@ ShuffleNet model from 'ShuffleNet: An Extremely Efficient Convolutional Neural N - `num_classes`: number of classes - `in_channels`: number of input channels """ -function ShuffleNet( - channels, init_block_channels::Integer, groups, num_classes; in_channels = 3) +function create_shufflenet( + channels, init_block_channels::Integer, groups::Integer, + num_classes::Integer; in_channels::Integer = 3) features = [] append!(features, @@ -101,15 +103,15 @@ function ShuffleNet( end model = Chain(features...) + classifier = Chain(GlobalMeanPool(), MLUtils.flatten, Dense(in_channels => num_classes)) - return Chain( - model, GlobalMeanPool(), MLUtils.flatten, Dense(in_channels => num_classes)) + return Chain(model, classifier) end """ shufflenet(groups, width_scale, num_classes; in_channels=3) -Wrapper for ShuffleNet. Create a ShuffleNet model from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices +Create a ShuffleNet model from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices ([reference](https://arxiv.org/abs/1707.01083)). # Arguments @@ -119,9 +121,11 @@ Wrapper for ShuffleNet. Create a ShuffleNet model from 'ShuffleNet: An Extremely - `num_classes`: number of classes - `in_channels`: number of input channels """ -function shufflenet(groups, width_scale, num_classes; in_channels = 3) + +function shufflenet(groups::Integer = 1, width_scale::Real = 1; + num_classes::Integer = 1000, in_channels::Integer = 3) init_block_channels = 24 - layers = [4, 8, 4] + nlayers = [4, 8, 4] if groups == 1 channels_per_layers = [144, 288, 576] @@ -138,9 +142,9 @@ function shufflenet(groups, width_scale, num_classes; in_channels = 3) end channels = [] - for i in eachindex(layers) + for i in eachindex(nlayers) char = [channels_per_layers[i]] - new = repeat(char, layers[i]) + new = repeat(char, nlayers[i]) push!(channels, new) end @@ -150,7 +154,7 @@ function shufflenet(groups, width_scale, num_classes; in_channels = 3) init_block_channels::Integer = trunc(init_block_channels * width_scale) end - net = ShuffleNet( + net = create_shufflenet( channels, init_block_channels, groups, @@ -159,3 +163,34 @@ function shufflenet(groups, width_scale, num_classes; in_channels = 3) return net end + +""" +ShuffleNet(groups, width_scale, num_classes; in_channels=3) + +Create a ShuffleNet model from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices +([reference](https://arxiv.org/abs/1707.01083)). + +# Arguments + + - `groups`: number of groups + - `width_scale`: scaling factor for number of channels + - `num_classes`: number of classes + - `in_channels`: number of input channels +""" +struct ShuffleNet + layers::Any +endstructures +@functor ShuffleNet + +function ShuffleNet(groups::Integer = 1, width_scale::Real = 1; + num_classes::Integer = 1000, in_channels::Integer = 3) + layers = shufflenet(groups, width_scale; num_classes, in_channels) + model = ShuffleNet(layers) + + return model +end + +(m::ShuffleNet)(x) = m.layers(x) + +backbone(m::ShuffleNet) = m.layers[1] +classifier(m::ShuffleNet) = m.layers[2:end] diff --git a/test/convnet_tests.jl b/test/convnet_tests.jl index 010a0860..9eaace6d 100644 --- a/test/convnet_tests.jl +++ b/test/convnet_tests.jl @@ -40,13 +40,13 @@ end [2, 2, 2, 2], [3, 4, 6, 3], [3, 4, 23, 3], - [3, 8, 36, 3], + [3, 8, 36, 3] ] @testset for layers in layer_list drop_list = [ (dropout_prob = 0.1, stochastic_depth_prob = 0.1, dropblock_prob = 0.1), (dropout_prob = 0.5, stochastic_depth_prob = 0.5, dropblock_prob = 0.5), - (dropout_prob = 0.8, stochastic_depth_prob = 0.8, dropblock_prob = 0.8), + (dropout_prob = 0.8, stochastic_depth_prob = 0.8, dropblock_prob = 0.8) ] @testset for drop_probs in drop_list m = Metalhead.resnet(block_fn, layers; drop_probs...) |> gpu @@ -375,11 +375,12 @@ end _gc() end - @testitem "ShuffleNet" setup=[TestModels] begin - configs = TEST_FAST ? [(1, 1)] : [(1, 1), (2, 1), (3, 1), (4, 1), (8, 1), (1, 0.75), (3, 0.75), (1, 0.5), (3, 0.5), (1, 0.25), (3, 0.25)] + configs = TEST_FAST ? [(1, 1)] : + [(1, 1), (2, 1), (3, 1), (4, 1), (8, 1), (1, 0.75), + (3, 0.75), (1, 0.5), (3, 0.5), (1, 0.25), (3, 0.25)] @testset for (groups, width_scale) in configs - m = shufflenet(groups, width_scale, 1000) |> gpu + m = ShuffleNet(groups, width_scale) |> gpu @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) _gc() From 73f73ca686aecc65f16693ed44893060c73e8111 Mon Sep 17 00:00:00 2001 From: RafaelT00 Date: Fri, 14 Jun 2024 21:06:08 +0200 Subject: [PATCH 11/12] corrected typo --- src/convnets/shufflenet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/convnets/shufflenet.jl b/src/convnets/shufflenet.jl index 428f796e..ea1df4d5 100644 --- a/src/convnets/shufflenet.jl +++ b/src/convnets/shufflenet.jl @@ -179,7 +179,7 @@ Create a ShuffleNet model from 'ShuffleNet: An Extremely Efficient Convolutional """ struct ShuffleNet layers::Any -endstructures +end @functor ShuffleNet function ShuffleNet(groups::Integer = 1, width_scale::Real = 1; From 7093459b59bd9802fc8d517c4cf8820c2d2cc48f Mon Sep 17 00:00:00 2001 From: RafaelT00 Date: Sat, 15 Jun 2024 00:03:42 +0200 Subject: [PATCH 12/12] added missing includes --- src/Metalhead.jl | 7 +++++-- src/convnets/shufflenet.jl | 4 ++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 7eb81282..1efa7b9f 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -69,6 +69,9 @@ include("mixers/gmlp.jl") # ViTs include("vit-based/vit.jl") +## ShuffleNet +include("convnets/shufflenet.jl") + # Load pretrained weights include("pretrain.jl") @@ -81,7 +84,7 @@ export AlexNet, VGG, ResNet, WideResNet, ResNeXt, DenseNet, SEResNet, SEResNeXt, Res2Net, Res2NeXt, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, MNASNet, EfficientNet, EfficientNetv2, ConvMixer, ConvNeXt, - MLPMixer, ResMLP, gMLP, ViT, UNet + MLPMixer, ResMLP, gMLP, ViT, UNet, ShuffleNet # useful for feature extraction export backbone, classifier @@ -92,7 +95,7 @@ for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, :MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet, :EfficientNet, :EfficientNetv2, :ConvMixer, :ConvNeXt, - :MLPMixer, :ResMLP, :gMLP, :ViT, :UNet) + :MLPMixer, :ResMLP, :gMLP, :ViT, :UNet, :ShuffleNet) @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end diff --git a/src/convnets/shufflenet.jl b/src/convnets/shufflenet.jl index ea1df4d5..ee7fa62a 100644 --- a/src/convnets/shufflenet.jl +++ b/src/convnets/shufflenet.jl @@ -194,3 +194,7 @@ end backbone(m::ShuffleNet) = m.layers[1] classifier(m::ShuffleNet) = m.layers[2:end] + +im = rand32(224, 224, 3, 50); # a batch of 50 RGB images +m = ShuffleNet(1, 1;num_classes=10) +println(m(im) |> size) \ No newline at end of file