Possible way to implement a LoopVectorization extension for conv2d & meanpool2d & activations #540
Open · jonas208 wants to merge 37 commits into FluxML:master from jonas208:lv-ext2
ext/NNlibLoopVectorizationExt/NNlibLoopVectorizationExt.jl (10 additions, 0 deletions)
module NNlibLoopVectorizationExt

using NNlib
using LoopVectorization
using Random, Statistics

include("conv.jl")
include("pooling.jl")

end # module
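For reference, on Julia 1.9+ an extension module like this is declared under [weakdeps]/[extensions] in NNlib's Project.toml and is loaded automatically once the weak dependency is available. A hedged usage sketch, assuming that wiring is in place:

using NNlib
using LoopVectorization   # loading this alongside NNlib triggers NNlibLoopVectorizationExt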
ext/NNlibLoopVectorizationExt/conv.jl (181 additions, 0 deletions)
#=
Accelerated convolution for 2d-images using the power of LoopVectorization.
The acceleration is usually greatest when the inputs have a large spatial size and few channels.
Using stride > 1, dilation > 1 or groups > 1 can slow things down a bit.

Since the current state of LoopVectorization ∇conv_filter! isn't really faster than the
original implementation in some situations, it is left out for the moment.

Implementation copied from here (Jonas Steinebach, MIT):
https://github.com/jonas208/GradValley.jl/blob/main/src/functional/gv_convolution.jl
=#

function NNlib.conv!(output::Array{T,4}, input::Array{T,4}, weight::Array{T,4}, cdims::ConvDims; kw...) where {T<:Real}

    # fix for groupcount > 1 (NNlib.check_dims would throw an error otherwise)
    size_weight_check_dims = (size(weight)[1:2]..., size(weight)[3]*cdims.groupcount, size(weight)[4])
    cdims_check_dims = DenseConvDims(size(input), size_weight_check_dims, stride=cdims.stride, padding=cdims.padding, dilation=cdims.dilation, groups=1, flipkernel=cdims.flipkernel)
    NNlib.check_dims(size(input), size_weight_check_dims, size(output), cdims_check_dims)

    # padding is done naively at the moment
    if cdims.padding != (0, 0, 0, 0)
        input = NNlib.pad_zeros(input, cdims.padding, dims=(1, 2))
    end

    output_width, output_height, _ = size(output)
    input_width, input_height, in_channels, batches = size(input)
    weight_width, weight_height, in_channels_weight, out_channels = size(weight)

    # it's necessary to flip the kernel if real convolution is performed (flipkernel=false)
    if !NNlib.flipkernel(cdims)
        weight = reverse(weight, dims=(1, 2))
    end

    groups = cdims.groupcount
    x_stride, y_stride = cdims.stride
    x_dilation, y_dilation = cdims.dilation
    out_channels_per_group = out_channels ÷ groups

    if cdims.groupcount == 1 && cdims.stride == (1, 1) && cdims.dilation == (1, 1) # very specialized case for maximum performance
        # println("forward: very specialized case for maximum performance")

        @tturbo for index_batch in 1:batches
            for out_channel in 1:out_channels, y_out in 1:output_height, x_out in 1:output_width
                value = zero(T)
                for in_channel in 1:in_channels, y_w in 1:weight_height, x_w in 1:weight_width
                    value += input[x_out + x_w - 1, y_out + y_w - 1, in_channel, index_batch] * weight[x_w, y_w, in_channel, out_channel]
                end
                output[x_out, y_out, out_channel, index_batch] = value
            end
        end

    elseif groups == 1 # second specialized case for better performance
        # println("forward: second specialized case for better performance")

        @tturbo for index_batch in 1:batches
            for out_channel in 1:out_channels, y_out in 1:output_height, x_out in 1:output_width
                m = y_out + (y_stride - 1) * (y_out - 1)
                n = x_out + (x_stride - 1) * (x_out - 1)
                value = zero(T)
                for in_channel in 1:in_channels, y_w in 1:weight_height, x_w in 1:weight_width
                    y_in = m + (y_w - 1) * y_dilation
                    x_in = n + (x_w - 1) * x_dilation
                    value += input[x_in, y_in, in_channel, index_batch] * weight[x_w, y_w, in_channel, out_channel]
                end
                output[x_out, y_out, out_channel, index_batch] = value
            end
        end

    else # general case for any convolution
        # println("forward: general case for any convolution")

        @tturbo for index_batch in 1:batches
            for group in 1:groups, out_channel_per_group in 1:out_channels_per_group, y_out in 1:output_height, x_out in 1:output_width
                m = y_out + (y_stride - 1) * (y_out - 1)
                n = x_out + (x_stride - 1) * (x_out - 1)
                out_channel = (group * out_channels_per_group + 1) - out_channel_per_group
                value = zero(T)
                for in_channel_weight in 1:in_channels_weight, y_w in 1:weight_height, x_w in 1:weight_width
                    y_in = m + (y_w - 1) * y_dilation
                    x_in = n + (x_w - 1) * x_dilation
                    in_channel_input = in_channel_weight + (group - 1) * in_channels_weight
                    value += input[x_in, y_in, in_channel_input, index_batch] * weight[x_w, y_w, in_channel_weight, out_channel]
                end
                output[x_out, y_out, out_channel, index_batch] = value
            end
        end

    end

    return output
end

function NNlib.∇conv_data!(input_gradient::Array{T,4}, output_gradient::Array{T,4}, weight::Array{T,4}, cdims::ConvDims; kw...) where {T<:Real}

    # fix for groupcount > 1 (NNlib.check_dims would throw an error otherwise)
    size_weight_check_dims = (size(weight)[1:2]..., size(weight)[3]*cdims.groupcount, size(weight)[4])
    cdims_check_dims = DenseConvDims(size(input_gradient), size_weight_check_dims, stride=cdims.stride, padding=cdims.padding, dilation=cdims.dilation, groups=1, flipkernel=cdims.flipkernel)
    NNlib.check_dims(size(input_gradient), size_weight_check_dims, size(output_gradient), cdims_check_dims)

    # storing all the necessary shapes
    output_width, output_height, out_channels, current_batch_size = size(output_gradient)
    weight_width, weight_height, in_channels_weight, out_channels = size(weight)

    # because values are accumulated in the computation section below, it's safer to reset the given input_gradient first
    input_gradient .= zero(T)
    # check if input_gradient must be padded (padding is done naively at the moment)
    if cdims.padding != (0, 0, 0, 0)
        input_gradient_padded = NNlib.pad_zeros(input_gradient, cdims.padding, dims=(1, 2))
    else
        input_gradient_padded = input_gradient
    end

    # store the size of input after padding
    input_width, input_height, in_channels, current_batch_size = size(input_gradient_padded) # size after padding

    # it's necessary to flip the kernel if real convolution is performed (flipkernel=false)
    if !NNlib.flipkernel(cdims)
        weight = reverse(weight, dims=(1, 2))
    end

    groups = cdims.groupcount
    x_stride, y_stride = cdims.stride
    x_dilation, y_dilation = cdims.dilation
    out_channels_per_group = out_channels ÷ groups

    # actual computation (using @tturbo instead of Threads.@threads + @turbo can produce wrong results here)
    if groups == 1 && cdims.stride == (1, 1) && cdims.dilation == (1, 1) # very specialized case for maximum performance
        # println("backward: very specialized case for maximum performance")

        Threads.@threads for index_batch in 1:current_batch_size
            @turbo for out_channel in 1:out_channels, y_out in 1:output_height, x_out in 1:output_width
                for in_channel in 1:in_channels, y_w in 1:weight_height, x_w in 1:weight_width
                    input_gradient_padded[x_out + x_w - 1, y_out + y_w - 1, in_channel, index_batch] += weight[x_w, y_w, in_channel, out_channel] * output_gradient[x_out, y_out, out_channel, index_batch]
                end
            end
        end

    elseif groups == 1 # second specialized case for better performance
        # println("backward: second specialized case for better performance")

        Threads.@threads for index_batch in 1:current_batch_size
            @turbo for out_channel in 1:out_channels, y_out in 1:output_height, x_out in 1:output_width
                m = y_out + (y_stride - 1) * (y_out - 1)
                n = x_out + (x_stride - 1) * (x_out - 1)
                for in_channel in 1:in_channels, y_w in 1:weight_height, x_w in 1:weight_width
                    y_in = m + (y_w - 1) * y_dilation
                    x_in = n + (x_w - 1) * x_dilation
                    input_gradient_padded[x_in, y_in, in_channel, index_batch] += weight[x_w, y_w, in_channel, out_channel] * output_gradient[x_out, y_out, out_channel, index_batch]
                end
            end
        end

    else # general case for any convolution
        # println("backward: general case for any convolution")

        Threads.@threads for index_batch in 1:current_batch_size
            for out_channel_per_group in 1:out_channels_per_group # putting @turbo here can produce wrong results
                @turbo for group in 1:groups, y_out in 1:output_height, x_out in 1:output_width
                    m = y_out + (y_stride - 1) * (y_out - 1)
                    n = x_out + (x_stride - 1) * (x_out - 1)
                    out_channel = (group * out_channels_per_group + 1) - out_channel_per_group
                    for in_channel_weight in 1:in_channels_weight, y_w in 1:weight_height, x_w in 1:weight_width
                        y_in = m + (y_w - 1) * y_dilation
                        x_in = n + (x_w - 1) * x_dilation
                        in_channel_input = in_channel_weight + (group - 1) * in_channels_weight
                        input_gradient_padded[x_in, y_in, in_channel_input, index_batch] += weight[x_w, y_w, in_channel_weight, out_channel] * output_gradient[x_out, y_out, out_channel, index_batch]
                    end
                end
            end
        end

    end

    # depad
    if cdims.padding != (0, 0, 0, 0)
        x_pad1, x_pad2, y_pad1, y_pad2 = cdims.padding
        input_gradient .= input_gradient_padded[x_pad1+1:input_width-x_pad2, y_pad1+1:input_height-y_pad2, :, :]
    end

    return input_gradient
end
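To make the dispatch concrete, here is a small, hedged usage sketch (array sizes are hypothetical; it assumes LoopVectorization is loaded so the conv!/∇conv_data! methods above are selected):

using NNlib, LoopVectorization

x = rand(Float32, 64, 64, 3, 8)                # WHCN input
w = rand(Float32, 5, 5, 3, 16)                 # kernel: width × height × in_channels × out_channels
cdims = DenseConvDims(x, w; stride=(1, 1), padding=(2, 2), dilation=(1, 1))

y  = NNlib.conv(x, w, cdims)                   # forward pass, hits the @tturbo fast path
dy = ones(Float32, size(y))
dx = NNlib.∇conv_data(dy, w, cdims)            # input gradient via the LV kernel above
size(dx) == size(x)                            # true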
ext/NNlibLoopVectorizationExt/pooling.jl (109 additions, 0 deletions)
#=
Accelerated mean pooling for 2d-images using the power of LoopVectorization.
The speed-up is usually lower compared to conv but can be approximately up to 2x.

Since the current state of LoopVectorization ∇meanpool! isn't really faster than the
original implementation in some situations, it is left out for the moment.

Implementation inspired by (Jonas Steinebach, MIT):
https://github.com/jonas208/GradValley.jl/blob/main/src/functional/gv_pooling.jl
=#

function NNlib.meanpool!(output::Array{T,4}, input::Array{T,4}, pdims::PoolDims; kw...) where {T<:Real}
    NNlib.check_dims(size(input), size(output), pdims)

    # storing all the necessary shapes
    input_width, input_height, channels, current_batch_size = size(input)
    output_width, output_height, channels, current_batch_size = size(output)
    kernel_width, kernel_height = pdims.kernel_size

    x_stride, y_stride = pdims.stride
    x_dilation, y_dilation = pdims.dilation
    x_pad1, x_pad2, y_pad1, y_pad2 = pdims.padding

    # A helper function to project from output (w, h) to input (input_w, input_h)
    @inline project(idx, stride, pad) = (idx - 1) * stride - pad + 1

    # We use calc_padding_regions to split ourselves up into separate regions that may or
    # may not need to worry about padding:
    pdims_3d = PoolDims((input_width, input_height, 1, channels, current_batch_size), (kernel_width, kernel_height, 1), stride=(x_stride, y_stride, 1), padding=(x_pad1, x_pad2, y_pad1, y_pad2, 0, 0), dilation=(x_dilation, y_dilation, 1))
    # println(pdims_3d.padding)
    padded_regions, central_region = NNlib.calc_padding_regions(pdims_3d)

    # We represent division by kernel size by rolling it
    # into the `alpha` multiplier.
    _alpha = T(1 / prod(pdims.kernel_size))

    # Start with the central region
    w_region, h_region, _ = central_region

    if pdims.stride == (1, 1) && pdims.dilation == (1, 1) # specialized case for better performance
        # println("specialized case for better performance")

        @tturbo for index_batch in 1:current_batch_size
            # compute pooling for each channel separately
            for channel in 1:channels, y_out in h_region, x_out in w_region
                kernel_sum = zero(T)
                for y_w in 1:kernel_height, x_w in 1:kernel_width
                    # kernel_sum += input[x_out + x_w - 1, y_out + y_w - 1, channel, index_batch]
                    kernel_sum += input[x_out + x_w - 1 - x_pad1, y_out + y_w - 1 - y_pad1, channel, index_batch]
                end
                output[x_out, y_out, channel, index_batch] = kernel_sum * _alpha
            end
        end

    else # general case for any meanpooling
        # println("general case for any meanpooling")

        @tturbo for index_batch in 1:current_batch_size
            # compute pooling for each channel separately
            for channel in 1:channels, y_out in h_region, x_out in w_region
                m = y_out + (y_stride - 1) * (y_out - 1) - y_pad1
                n = x_out + (x_stride - 1) * (x_out - 1) - x_pad1
                kernel_sum = zero(T)
                for y_w in 1:kernel_height, x_w in 1:kernel_width
                    y_in = m + (y_w - 1) * y_dilation # - y_pad1
                    x_in = n + (x_w - 1) * x_dilation # - x_pad1
                    kernel_sum += input[x_in, y_in, channel, index_batch]
                end
                output[x_out, y_out, channel, index_batch] = kernel_sum * _alpha
            end
        end

    end

    # Next, the padded regions
    @inbounds for (w_region, h_region, d_region) in padded_regions
        for index_batch in 1:current_batch_size, channel in 1:channels
            for d in d_region # for skipping the d_regions
                for h in h_region
                    ph = project(h, y_stride, y_pad1)
                    for w in w_region
                        pw = project(w, x_stride, x_pad1)
                        m = zero(T)

                        for kh in 1:kernel_height
                            input_kh = ph + (kh - 1) * y_dilation
                            if input_kh <= 0 || input_kh > input_height
                                continue
                            end

                            for kw in 1:kernel_width
                                input_kw = pw + (kw - 1) * x_dilation
                                if input_kw <= 0 || input_kw > input_width
                                    continue
                                end

                                m += input[input_kw, input_kh, channel, index_batch]
                            end
                        end

                        output[w, h, channel, index_batch] = _alpha * m
                    end
                end
            end
        end
    end

    return output
end
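Similarly, a hedged sketch of how the pooling path is exercised (sizes are hypothetical; stride 2 deliberately takes the general @tturbo branch rather than the stride-1 fast path):

using NNlib, LoopVectorization

x = rand(Float32, 32, 32, 4, 2)
pdims = PoolDims(x, (2, 2); stride=(2, 2))   # 2×2 mean pooling, stride 2
y = NNlib.meanpool(x, pdims)                 # dispatches to the meanpool! method above
size(y)                                      # (16, 16, 4, 2)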
Review comment:
I'm not sure if it makes a difference, but LV's handling of indices written within an index expression might be better. It has separate code for building indices by parsing index expressions vs. parsing the loop body. I'd have to look at the details to see the degree to which they're different; in theory, they should do the same thing.
LV would definitely benefit from the special case of x_dilation being 1. Might be worth branching over it (and also special-casing -1 or some other common small factors).
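For illustration only, the branching suggested here might look roughly like the following inside the forward pass; the two kernel helpers are hypothetical names, not functions from this PR or from NNlib:

# hypothetical sketch: give LoopVectorization the simpler index expression when there is no dilation
if x_dilation == 1 && y_dilation == 1
    # indices reduce to n + x_w - 1 and m + y_w - 1, which LV can analyze more easily
    lv_conv_kernel_nodilation!(output, input, weight, cdims)   # hypothetical specialized kernel
else
    lv_conv_kernel_general!(output, input, weight, cdims)      # hypothetical general kernel
end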