Possible way to implement a LoopVectorization extension for conv2d & meanpool2d & activations #540

Open: wants to merge 37 commits into base: master

Changes from 1 commit (37 commits total)
6d32073 Add files via upload (jonas208, Sep 26, 2023)
2a3cf3e Delete ext/NNlibLoopVectorizationExt/conv_old.jl (jonas208, Sep 26, 2023)
28a027a Delete ext/NNlibLoopVectorizationExt/pooling_old.jl (jonas208, Sep 26, 2023)
5339aaf Add files via upload (jonas208, Sep 26, 2023)
9e0dc6d Add files via upload (jonas208, Sep 26, 2023)
dd0f0ed Add files via upload (jonas208, Sep 26, 2023)
b341d1c Add files via upload (jonas208, Sep 26, 2023)
94f7964 Update runtests.jl (jonas208, Sep 26, 2023)
6cc2e75 Add files via upload (jonas208, Sep 27, 2023)
c5c79ee Add files via upload (jonas208, Sep 27, 2023)
132e35c Add files via upload (jonas208, Sep 27, 2023)
aa019e9 Add files via upload (jonas208, Sep 27, 2023)
52e2a78 Add files via upload (jonas208, Sep 27, 2023)
d63f8a5 Add files via upload (jonas208, Sep 28, 2023)
ae86d13 Add files via upload (jonas208, Sep 28, 2023)
776835d Add files via upload (jonas208, Sep 28, 2023)
13205da Add files via upload (jonas208, Sep 28, 2023)
5850341 Add files via upload (jonas208, Sep 28, 2023)
00b28f2 Add files via upload (jonas208, Sep 28, 2023)
af04cc6 Delete runtests.jl (jonas208, Sep 28, 2023)
990a34c Delete Project.toml (jonas208, Sep 28, 2023)
db0ad66 Add files via upload (jonas208, Sep 28, 2023)
a4e18e6 Add files via upload (jonas208, Sep 29, 2023)
f584377 Add files via upload (jonas208, Sep 30, 2023)
274db10 Add files via upload (jonas208, Sep 30, 2023)
6c33d5c Add files via upload (jonas208, Sep 30, 2023)
7affd46 Add files via upload (jonas208, Oct 3, 2023)
3130f8a Add files via upload (jonas208, Oct 3, 2023)
d87f909 Add files via upload (jonas208, Oct 7, 2023)
82abca8 Add files via upload (jonas208, Oct 7, 2023)
c5ec713 Delete bench_torch.py (jonas208, Oct 8, 2023)
3f1c6dc Add files via upload (jonas208, Oct 8, 2023)
8dde5f9 Add files via upload (jonas208, Oct 8, 2023)
5505157 Add files via upload (jonas208, Oct 8, 2023)
0aa3a3f Add files via upload (jonas208, Oct 8, 2023)
35f2b77 Add files via upload (jonas208, Oct 8, 2023)
07943d7 Add files via upload (jonas208, Oct 9, 2023)
Add files via upload
jonas208 authored Oct 3, 2023
commit 7affd464c174413d8d24ff05f5ef7b30041a0862
455 changes: 327 additions & 128 deletions ext/NNlibLoopVectorizationExt/conv.jl
@@ -1,18 +1,273 @@
#=
Accelerated convolution for 2d-images using the power of LoopVectorization.
The acceleration is usually greatest when the inputs have a large spatial size and few channels.
Using stride > 1, dilation > 1 or groups > 1 can slow things down a bit.

Since the current state of LoopVectorization ∇conv_filter! isn't really faster than the
original implementation in some situations, it is left out for the moment.

Implementation for the forward pass mostly copied from here (Jonas Steinebach, MIT license):
https://github.com/jonas208/GradValley.jl/blob/main/src/functional/gv_convolution.jl

Implementation for the backward pass mostly copied from here (Chris Elrod, MIT license):
https://github.com/PumasAI/SimpleChains.jl/blob/main/src/conv.jl
=#

# pad naively, static iters
function NNlib.conv!(output::Array{T,4}, input::Array{T,4}, weight::Array{T,4}, cdims::ConvDims) where {T<:Real}

# fix for groupcount > 1 (NNlib.check_dims would throw an error otherwise)
size_weight_check_dims = (size(weight)[1:2]..., size(weight)[3]*cdims.groupcount, size(weight)[4])
cdims_check_dims = DenseConvDims(size(input), size_weight_check_dims, stride=cdims.stride, padding=cdims.padding, dilation=cdims.dilation, groups=1, flipkernel=cdims.flipkernel)
NNlib.check_dims(size(input), size_weight_check_dims, size(output), cdims_check_dims)

# padding is done naively at the moment
if cdims.padding != (0, 0, 0, 0)
input = NNlib.pad_zeros(input, cdims.padding, dims=(1, 2))
end
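# e.g. (illustrative sizes, not part of the PR): NNlib.pad_zeros(rand(Float32, 5, 5, 3, 1), (1, 1, 2, 2), dims=(1, 2)) has size (7, 9, 3, 1)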

output_width, output_height, _ = size(output)
input_width, input_height, in_channels, batch_size = size(input)
weight_width, weight_height, in_channels_weight, out_channels = size(weight)

# it's necessary to flip the kernel if real convolution is performed (flipkernel=false)
if !NNlib.flipkernel(cdims)
weight = reverse(weight, dims=(1, 2))
end
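# reverse(weight, dims=(1, 2)) flips the kernel spatially, so the correlation loops below compute a true convolution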

groups = cdims.groupcount
x_stride, y_stride = cdims.stride
x_dilation, y_dilation = cdims.dilation
out_channels_per_group = out_channels ÷ groups

if cdims.groupcount == 1 && cdims.stride == (1, 1) && cdims.dilation == (1, 1) # very specialized case for maximum performance
# println("forward: very specialized case for maximum performance")

@tturbo for index_batch in 1:batch_size
for out_channel in 1:out_channels, y_out in 1:output_height, x_out in 1:output_width
value = zero(T)
for in_channel in static(1):static(in_channels), y_w in static(1):static(weight_height), x_w in static(1):static(weight_width)
value += input[x_out + x_w - 1, y_out + y_w - 1, in_channel, index_batch] * weight[x_w, y_w, in_channel, out_channel]
end
output[x_out, y_out, out_channel, index_batch] = value
end
end

elseif groups == 1 && cdims.dilation == (1, 1) # second specialized case for better performance
# println("forward: second specialized case for better performance")

@tturbo for index_batch in 1:batch_size
for out_channel in 1:out_channels, y_out in 1:output_height, x_out in 1:output_width
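# m/n are the first input row/column read for this output position: y_out + (y_stride - 1) * (y_out - 1) == (y_out - 1) * y_stride + 1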
m = y_out + static((y_stride - 1)) * (y_out - 1)
n = x_out + static((x_stride - 1)) * (x_out - 1)
value = zero(T)
for in_channel in static(1):static(in_channels), y_w in static(1):static(weight_height), x_w in static(1):static(weight_width)
# y_in = m + (y_w - 1)
# x_in = n + (x_w - 1)
value += input[n + (x_w - 1), m + (y_w - 1), in_channel, index_batch] * weight[x_w, y_w, in_channel, out_channel]
end
output[x_out, y_out, out_channel, index_batch] = value
end
end

elseif groups == 1 # third specialized case for better performance
# println("forward: third specialized case for better performance")

@tturbo for index_batch in 1:batch_size
for out_channel in 1:out_channels, y_out in 1:output_height, x_out in 1:output_width
m = y_out + static((y_stride - 1)) * (y_out - 1)
n = x_out + static((x_stride - 1)) * (x_out - 1)
value = zero(T)
for in_channel in static(1):static(in_channels), y_w in static(1):static(weight_height), x_w in static(1):static(weight_width)
# y_in = m + (y_w - 1) * y_dilation
# x_in = n + (x_w - 1) * x_dilation
value += input[n + (x_w - 1) * x_dilation, m + (y_w - 1) * y_dilation, in_channel, index_batch] * weight[x_w, y_w, in_channel, out_channel]
end
output[x_out, y_out, out_channel, index_batch] = value
end
end

else # general case for any convolution
# println("forward: general case for any convolution")

@tturbo for index_batch in 1:batch_size
for group in 1:groups, out_channel_per_group in 1:out_channels_per_group, y_out in 1:output_height, x_out in 1:output_width
m = y_out + static((y_stride - 1)) * (y_out - 1)
n = x_out + static((x_stride - 1)) * (x_out - 1)
out_channel = (group * out_channels_per_group + 1) - out_channel_per_group
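# maps the within-group index to the global output channel; the matching input channels of this group start at (group - 1) * in_channels_weight + 1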
value = zero(T)
for in_channel_weight in static(1):static(in_channels_weight), y_w in static(1):static(weight_height), x_w in static(1):static(weight_width)
# y_in = m + (y_w - 1) * y_dilation
# x_in = n + (x_w - 1) * x_dilation
in_channel_input = in_channel_weight + (group - 1) * in_channels_weight
value += input[n + (x_w - 1) * x_dilation, m + (y_w - 1) * y_dilation, in_channel_input, index_batch] * weight[x_w, y_w, in_channel_weight, out_channel]
end
output[x_out, y_out, out_channel, index_batch] = value
end
end

end

return output
end
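
# --- illustrative only, not part of this PR: a minimal smoke test for the forward kernel above ---
# (hypothetical sizes; assumes NNlib plus the LoopVectorization extension are loaded)
# using NNlib, LoopVectorization
# x = rand(Float32, 32, 32, 3, 2)                      # WHCN input
# w = rand(Float32, 5, 5, 3, 8)                        # 5×5 kernel, 3 -> 8 channels
# cdims = DenseConvDims(x, w; stride=1, padding=2)     # stride 1, symmetric padding 2
# y = zeros(Float32, NNlib.output_size(cdims)..., 8, 2)
# NNlib.conv!(y, x, w, cdims)                          # dispatches to the method above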

#= # pad bounds check
function NNlib.conv!(output::Array{T,4}, input::Array{T,4}, weight::Array{T,4}, cdims::ConvDims; kw...) where {T<:Real}
# fix for groupcount > 1 (NNlib.check_dims would throw an error otherwise)
size_weight_check_dims = (size(weight)[1:2]..., size(weight)[3]*cdims.groupcount, size(weight)[4])
cdims_check_dims = DenseConvDims(size(input), size_weight_check_dims, stride=cdims.stride, padding=cdims.padding, dilation=cdims.dilation, groups=1, flipkernel=cdims.flipkernel)
NNlib.check_dims(size(input), size_weight_check_dims, size(output), cdims_check_dims)
output_width, output_height, _ = size(output)
input_width, input_height, in_channels, batch_size = size(input)
weight_width, weight_height, in_channels_weight, out_channels = size(weight)
# it's necessary to flip the kernel if real convolution is performed (flipkernel=false)
if !NNlib.flipkernel(cdims)
weight = reverse(weight, dims=(1, 2))
end
groups = cdims.groupcount
x_stride, y_stride = cdims.stride
x_dilation, y_dilation = cdims.dilation
x_pad1, x_pad2, y_pad1, y_pad2 = cdims.padding
out_channels_per_group = out_channels ÷ groups
# We use calc_padding_regions to split ourselves up into separate regions that may or
# may not need to worry about padding:
cdims_3d = DenseConvDims((input_width, input_height, 1, in_channels, batch_size), (weight_width, weight_height, 1, in_channels_weight, out_channels), stride=(x_stride, y_stride, 1), padding=(x_pad1, x_pad2, y_pad1, y_pad2, 0, 0), dilation=(x_dilation, y_dilation, 1))
# println(pdims_3d.padding)
padded_regions, central_region = NNlib.calc_padding_regions(cdims_3d)
# Start with the central region
w_region, h_region, _ = central_region
if cdims.groupcount == 1 && cdims.stride == (1, 1) && cdims.dilation == (1, 1) # very specialized case for maximum performance
# println("forward: very specialized case for maximum performance")
@tturbo for index_batch in 1:batch_size
for out_channel in 1:out_channels, y_out in h_region, x_out in w_region
value = zero(T)
for in_channel in static(1):static(in_channels), y_w in static(1):static(weight_height), x_w in static(1):static(weight_width)
value += input[x_out + x_w - 1 - x_pad1, y_out + y_w - 1 - y_pad1, in_channel, index_batch] * weight[x_w, y_w, in_channel, out_channel]
end
output[x_out, y_out, out_channel, index_batch] = value
end
end
elseif groups == 1 && cdims.dilation == (1, 1) # second specialized case for better performance
# println("forward: second specialized case for better performance")
@tturbo for index_batch in 1:batch_size
for out_channel in 1:out_channels, y_out in h_region, x_out in w_region
m = y_out + static((y_stride - 1)) * (y_out - 1) - static(y_pad1)
n = x_out + static((x_stride - 1)) * (x_out - 1) - static(x_pad1)
value = zero(T)
for in_channel in static(1):static(in_channels), y_w in static(1):static(weight_height), x_w in static(1):static(weight_width)
# y_in = m + (y_w - 1)
# x_in = n + (x_w - 1)
value += input[n + (x_w - 1), m + (y_w - 1), in_channel, index_batch] * weight[x_w, y_w, in_channel, out_channel]
end
output[x_out, y_out, out_channel, index_batch] = value
end
end
elseif groups == 1 # third specialized case for better performance
# println("forward: third specialized case for better performance")
@tturbo for index_batch in 1:batch_size
for out_channel in 1:out_channels, y_out in h_region, x_out in w_region
m = y_out + static((y_stride - 1)) * (y_out - 1) - static(y_pad1)
n = x_out + static((x_stride - 1)) * (x_out - 1) - static(x_pad1)
value = zero(T)
for in_channel in static(1):static(in_channels), y_w in static(1):static(weight_height), x_w in static(1):static(weight_width)
# y_in = m + (y_w - 1) * y_dilation
# x_in = n + (x_w - 1) * x_dilation
value += input[n + (x_w - 1) * x_dilation, m + (y_w - 1) * y_dilation, in_channel, index_batch] * weight[x_w, y_w, in_channel, out_channel]
end
output[x_out, y_out, out_channel, index_batch] = value
end
end
else # general case for any convolution
# println("forward: general case for any convolution")
@tturbo for index_batch in 1:batch_size
for group in 1:groups, out_channel_per_group in 1:out_channels_per_group, y_out in h_region, x_out in w_region
m = y_out + static((y_stride - 1)) * (y_out - 1) - static(y_pad1)
n = x_out + static((x_stride - 1)) * (x_out - 1) - static(x_pad1)
out_channel = (group * out_channels_per_group + 1) - out_channel_per_group
value = zero(T)
for in_channel_weight in static(1):static(in_channels_weight), y_w in static(1):static(weight_height), x_w in static(1):static(weight_width)
# y_in = m + (y_w - 1) * y_dilation
# x_in = n + (x_w - 1) * x_dilation
in_channel_input = in_channel_weight + (group - 1) * in_channels_weight
value += input[n + (x_w - 1) * x_dilation, m + (y_w - 1) * y_dilation, in_channel_input, index_batch] * weight[x_w, y_w, in_channel_weight, out_channel]
end
output[x_out, y_out, out_channel, index_batch] = value
end
end
end
# @show w_region # 1:0 warning: when padding is unequal and one of x_pad1 or y_pad1 is 0, empty collections are possible
# @show h_region # 1:0 if isempty(1:0) -> true
# println()
if cdims.padding != (0, 0, 0, 0)
# Next, the padded regions
for (w_region, h_region, d_region) in padded_regions # @inbounds
for z_out in d_region # for skipping the d_regions
if cdims.groupcount == 1
@tturbo for index_batch in 1:batch_size # @turbo
for out_channel in 1:out_channels, y_out in h_region, x_out in w_region
m = y_out + static((y_stride - 1)) * (y_out - 1) - static(y_pad1)
n = x_out + static((x_stride - 1)) * (x_out - 1) - static(x_pad1)
value = zero(T)
for in_channel in static(1):static(in_channels), y_w in static(1):static(weight_height), x_w in static(1):static(weight_width)
y_in = m + (y_w - 1) * y_dilation
x_in = n + (x_w - 1) * x_dilation
is_in_bound_x = (x_in >= 1) & (x_in <= input_width)
is_in_bound_y = (y_in >= 1) & (y_in <= input_height)
input_value = (is_in_bound_x & is_in_bound_y) ? input[x_in, y_in, in_channel, index_batch] : zero(T)
value += input_value * weight[x_w, y_w, in_channel, out_channel]
# value += (is_in_bound_x & is_in_bound_y) ? input[x_in, y_in, in_channel, index_batch] * weight[x_w, y_w, in_channel, out_channel] : zero(T)
end
output[x_out, y_out, out_channel, index_batch] = value
end
end
else
@tturbo for index_batch in 1:batch_size # @turbo
for group in 1:groups, out_channel_per_group in 1:out_channels_per_group, y_out in h_region, x_out in w_region
m = y_out + static((y_stride - 1)) * (y_out - 1) - static(y_pad1)
n = x_out + static((x_stride - 1)) * (x_out - 1) - static(x_pad1)
out_channel = (group * out_channels_per_group + 1) - out_channel_per_group
value = zero(T)
for in_channel_weight in static(1):static(in_channels_weight), y_w in static(1):static(weight_height), x_w in static(1):static(weight_width)
y_in = m + (y_w - 1) * y_dilation
x_in = n + (x_w - 1) * x_dilation
in_channel_input = in_channel_weight + (group - 1) * in_channels_weight
is_in_bound_x = (x_in >= 1) & (x_in <= input_width)
is_in_bound_y = (y_in >= 1) & (y_in <= input_height)
input_value = (is_in_bound_x & is_in_bound_y) ? input[x_in, y_in, in_channel_input, index_batch] : zero(T)
value += input_value * weight[x_w, y_w, in_channel_weight, out_channel]
# value += (ib0 & ib1) ? input[x_in, y_in, in_channel_input, index_batch] * weight[x_w, y_w, in_channel_weight, out_channel] : zero(T)
end
output[x_out, y_out, out_channel, index_batch] = value
end
end
end
end
end
end
return output
end
=#

#=
function NNlib.conv!(output::Array{T,4}, input::Array{T,4}, weight::Array{T,4}, cdims::ConvDims; kw...) where {T<:Real}
# fix for groupcount > 1 (NNlib.check_dims would throw an error otherwise)
@@ -92,66 +347,81 @@ function NNlib.conv!(output::Array{T,4}, input::Array{T,4}, weight::Array{T,4},
return output
end
=#
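# The two helpers below delegate to NNlib's im2col-based implementations: inserting a singleton spatial
# dimension turns the 4d arrays into 5d ones, so dispatch falls back to NNlib's generic methods instead of
# the LoopVectorization kernels in this file (used e.g. in the general case of ∇conv_data! further below).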

function ∇conv_data_im2col_grouped!(input_gradient::Array{T,4}, output_gradient::Array{T,4}, weight::Array{T,4}, cdims::ConvDims) where {T<:Real}

∇conv_data!(
NNlib.insert_singleton_spatial_dimension(input_gradient, 1),
NNlib.insert_singleton_spatial_dimension(output_gradient, 1),
NNlib.insert_singleton_spatial_dimension(weight, 1),
NNlib.insert_singleton_spatial_dimension(cdims, 1)
)

return input_gradient
end

function ∇conv_filter_im2col_grouped!(weight_gradient::Array{T,4}, input::Array{T,4}, output_gradient::Array{T,4}, cdims::ConvDims) where {T<:Real}

∇conv_filter!(
NNlib.insert_singleton_spatial_dimension(weight_gradient, 1),
NNlib.insert_singleton_spatial_dimension(input, 1),
NNlib.insert_singleton_spatial_dimension(output_gradient, 1),
NNlib.insert_singleton_spatial_dimension(cdims, 1)
)

return weight_gradient
end

function NNlib.∇conv_data!(input_gradient::Array{T,4}, output_gradient::Array{T,4}, weight::Array{T,4}, cdims::ConvDims; kw...) where {T<:Real}

# fix for groupcount > 1 (NNlib.check_dims would throw an error otherwise)
size_weight_check_dims = (size(weight)[1:2]..., size(weight)[3]*cdims.groupcount, size(weight)[4])
cdims_check_dims = DenseConvDims(size(input_gradient), size_weight_check_dims, stride=cdims.stride, padding=cdims.padding, dilation=cdims.dilation, groups=1, flipkernel=cdims.flipkernel)
NNlib.check_dims(size(input_gradient), size_weight_check_dims, size(output_gradient), cdims_check_dims)

# storing all the necessary shapes
output_width, output_height, out_channels, current_batch_size = size(output_gradient)
weight_width, weight_height, in_channels_weight, out_channels = size(weight)

# because values are added in the actual computation section, it's safer to reset the given input_gradient first
input_gradient .= zero(T)
# check if input_gradient must be padded (padding is done naively at the moment)
if cdims.padding != (0, 0, 0, 0)
input_gradient_padded = NNlib.pad_zeros(input_gradient, cdims.padding, dims=(1, 2))
else
input_gradient_padded = input_gradient
end

# store the size of input after padding
input_width, input_height, in_channels, current_batch_size = size(input_gradient_padded) # size after padding

# it's necessary to flip the kernel if real convolution is performed (flipkernel=false)
if !NNlib.flipkernel(cdims)
weight = reverse(weight, dims=(1, 2))
end
if cdims.groupcount == 1 && cdims.stride == (1, 1) && cdims.dilation == (1, 1) # very specialized case for maximum performance
# println("backward: very specialized case for maximum performance")

groups = cdims.groupcount
x_stride, y_stride = cdims.stride
x_dilation, y_dilation = cdims.dilation
out_channels_per_group = out_channels ÷ groups
# storing all the necessary shapes
output_width, output_height, out_channels, batch_size = size(output_gradient)
weight_width, weight_height, in_channels_weight, out_channels = size(weight)

# because values are added in the actual computation section, it's safer to reset the given input_gradient first
input_gradient .= zero(T)
# check if input_gradient must be padded (padding is done naively at the moment)
if cdims.padding != (0, 0, 0, 0)
input_gradient_padded = NNlib.pad_zeros(input_gradient, cdims.padding, dims=(1, 2))
else
input_gradient_padded = input_gradient
end

@inline static_size(x::AbstractArray{T, N}) where {T, N} = static.(size(x))
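# static_size returns the array sizes as compile-time StaticInt values, so LoopVectorization can specialize on the loop bounds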
# store the size of input after padding
input_width, input_height, in_channels, batch_size = size(input_gradient_padded) # size after padding

# actual computation (using @tturbo instead of Threads.@threads + @turbo may give wrong results)
if groups == 1 && cdims.stride == (1, 1) && cdims.dilation == (1, 1) # very specialized case for maximum performance
# println("backward: very specialized case for maximum performance")
# it's necessary to flip the kernel if real convolution is performed (flipkernel=false)
if !NNlib.flipkernel(cdims)
weight = reverse(weight, dims=(1, 2))
end

output_gradient = OffsetArray(output_gradient, OffsetArrays.Origin(0, 0, 0, 0))
input_gradient_padded = OffsetArray(input_gradient_padded, OffsetArrays.Origin(0, 0, 0, 0))
weight = OffsetArray(weight, OffsetArrays.Origin(0, 0, 0, 0))
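# zero-based indexing (OffsetArrays.Origin(0, 0, 0, 0)) lets the kernel below index output_gradient with x_in - x_w and y_in - y_w directly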

input_width, input_height, in_channels, batch_size = static_size(input_gradient_padded)
weight_width, weight_height, in_channels_weight, out_channels = static_size(weight)

y_upper_bound = static(output_height) # == input_height - weight_height + static(1)
x_upper_bound = static(output_width) # == input_width - weight_width + static(1)
input_width, input_height, in_channels, batch_size = size(input_gradient_padded)
weight_width, weight_height, in_channels_weight, out_channels = size(weight)

@tturbo for index_batch in 0:batch_size-1
for x_in in 0:input_width-1, y_in in 0:input_height-1, in_channel in 0:in_channels-1 # @tturbo unroll = (2, 1)
for x_in in 0:input_width-1, y_in in 0:input_height-1, in_channel in 0:in_channels-1

value = zero(T)
for x_w in 0:weight_width-1, y_w in 0:weight_height-1, out_channel in 0:out_channels-1
ib0 = (x_in - x_w >= 0) & (x_in - x_w < x_upper_bound)
ib1 = (y_in - y_w >= 0) & (y_in - y_w < y_upper_bound)
output_gradient_value = (ib0 & ib1) ? output_gradient[x_in-x_w, y_in-y_w, out_channel, index_batch] : zero(T)
for x_w in static(0):static(weight_width-1), y_w in static(0):static(weight_height-1), out_channel in static(0):static(out_channels-1)
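# branchless bounds handling: out-of-range taps contribute zero instead of branching, which keeps the body vectorizable under @tturbo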

is_in_bound_x = (x_in - x_w >= 0) & (x_in - x_w < output_width)
is_in_bound_y = (y_in - y_w >= 0) & (y_in - y_w < output_height)
output_gradient_value = (is_in_bound_x & is_in_bound_y) ? output_gradient[x_in - x_w, y_in - y_w, out_channel, index_batch] : zero(T)
value += weight[x_w, y_w, in_channel, out_channel] * output_gradient_value
# value += (ib0 & ib1) ? output_gradient[x_in-x_w, y_in-y_w, out_channel, index_batch] * weight[x_w, y_w, in_channel, out_channel] : zero(T)

end
input_gradient_padded[x_in, y_in, in_channel, index_batch] = value

@@ -160,86 +430,15 @@ function NNlib.∇conv_data!(input_gradient::Array{T,4}, output_gradient::Array{

input_gradient_padded = input_gradient_padded.parent

elseif groups == 1 && cdims.dilation == (1, 1) # second specialized case for better performance
# println("backward: second specialized case for better performance")

#=
Threads.@threads for index_batch in 1:current_batch_size
@turbo for out_channel in 1:out_channels, y_out in 1:output_height, x_out in 1:output_width
m = y_out + (y_stride - 1) * (y_out - 1)
n = x_out + (x_stride - 1) * (x_out - 1)
for in_channel in 1:in_channels, y_w in 1:weight_height, x_w in 1:weight_width
y_in = m + (y_w - 1) * y_dilation
x_in = n + (x_w - 1) * x_dilation
input_gradient_padded[x_in, y_in, in_channel, index_batch] += weight[x_w, y_w, in_channel, out_channel] * output_gradient[x_out, y_out, out_channel, index_batch]
end
end
end
=#

y_out_indices = Array{Int, 3}(undef, weight_width, weight_height, input_height)
x_out_indices = Array{Int, 3}(undef, weight_width, weight_height, input_width)
x_out_indices .= -1
y_out_indices .= -1

@turbo for y_out in 1:output_height, x_out in 1:output_width #
m = y_out + (y_stride - 1) * (y_out - 1)
n = x_out + (x_stride - 1) * (x_out - 1)
for y_w in 1:weight_height, x_w in 1:weight_width
y_in = m + (y_w - 1) * y_dilation
x_in = n + (x_w - 1) * x_dilation
y_out_indices[x_w, y_w, y_in] = y_out
x_out_indices[x_w, y_w, x_in] = x_out
end
end

@tturbo for index_batch in 1:current_batch_size
for x_in in 1:input_width, y_in in 1:input_height, in_channel in 1:in_channels # @tturbo unroll = (2, 1)

value = zero(T)
for x_w in 1:weight_width, y_w in 1:weight_height, out_channel in 1:out_channels

x_out = x_out_indices[x_w, y_w, x_in]
y_out = y_out_indices[x_w, y_w, y_in]

ib0 = x_out > -1 # !=
ib1 = y_out > -1 # !=

output_gradient_value = (ib0 & ib1) ? output_gradient[x_out, y_out, out_channel, index_batch] : zero(T)
# output_gradient_value = T(2.0) # output_gradient[x_out, y_out, out_channel, index_batch]
value += weight[x_w, y_w, in_channel, out_channel] * output_gradient_value
end
input_gradient[x_in, y_in, in_channel, index_batch] = value

end
# depad
if cdims.padding != (0, 0, 0, 0)
x_pad1, x_pad2, y_pad1, y_pad2 = cdims.padding
input_gradient .= input_gradient_padded[x_pad1+1:input_width-x_pad2, y_pad1+1:input_height-y_pad2, :, :]
end

else # general case for any convolution
# println("backward: general case for any convolution")

Threads.@threads for index_batch in 1:current_batch_size
for out_channel_per_group in 1:out_channels_per_group # putting @turbo here may give wrong results
@turbo for group in 1:groups, y_out in 1:output_height, x_out in 1:output_width
m = y_out + (y_stride - 1) * (y_out - 1)
n = x_out + (x_stride - 1) * (x_out - 1)
out_channel = (group * out_channels_per_group + 1) - out_channel_per_group
for in_channel_weight in 1:in_channels_weight, y_w in 1:weight_height, x_w in 1:weight_width
y_in = m + (y_w - 1) * y_dilation
x_in = n + (x_w - 1) * x_dilation
in_channel_input = in_channel_weight + (group - 1) * in_channels_weight
input_gradient_padded[x_in, y_in, in_channel_input, index_batch] += weight[x_w, y_w, in_channel_weight, out_channel] * output_gradient[x_out, y_out, out_channel, index_batch]
end
end
end
end

end

# depad
if cdims.padding != (0, 0, 0, 0)
x_pad1, x_pad2, y_pad1, y_pad2 = cdims.padding
input_gradient .= input_gradient_padded[x_pad1+1:input_width-x_pad2, y_pad1+1:input_height-y_pad2, :, :]
end
input_gradient = ∇conv_data_im2col_grouped!(input_gradient, output_gradient, weight, cdims)
end

return input_gradient
end
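
# --- illustrative only, not part of this PR: a minimal call of the backward kernel above ---
# (hypothetical sizes; with stride 1, dilation 1 and groups 1 this hits the specialized @tturbo path)
# using NNlib, LoopVectorization
# x  = rand(Float32, 16, 16, 4, 1)
# w  = rand(Float32, 3, 3, 4, 6)
# cdims = DenseConvDims(x, w)                             # stride 1, no padding, no dilation
# dy = rand(Float32, NNlib.output_size(cdims)..., 6, 1)   # 14×14×6×1
# dx = zeros(Float32, size(x))
# NNlib.∇conv_data!(dx, dy, w, cdims)                     # writes the input gradient into dx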