From 4f3f3bb8f888e363e875deb5d6399db42d0a2395 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Fri, 28 Jul 2023 11:29:06 +0200 Subject: [PATCH 1/3] Accept left offsets when applying position encodings --- CMakeLists.txt | 3 + include/ctranslate2/layers/common.h | 8 +- include/ctranslate2/ops/ops.h | 1 + .../ctranslate2/ops/position_encodings_add.h | 26 +++++++ src/layers/common.cc | 34 ++------- src/layers/transformer.cc | 4 +- src/layers/whisper.cc | 2 +- src/ops/position_encodings_add.cc | 48 ++++++++++++ src/ops/position_encodings_add_cpu.cc | 48 ++++++++++++ src/ops/position_encodings_add_gpu.cu | 73 +++++++++++++++++++ tests/layers_test.cc | 34 ++++++++- 11 files changed, 246 insertions(+), 35 deletions(-) create mode 100644 include/ctranslate2/ops/position_encodings_add.h create mode 100644 src/ops/position_encodings_add.cc create mode 100644 src/ops/position_encodings_add_cpu.cc create mode 100644 src/ops/position_encodings_add_gpu.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 79717045b..babae3f67 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -155,6 +155,8 @@ set(SOURCES src/ops/mul.cc src/ops/multinomial.cc src/ops/multinomial_cpu.cc + src/ops/position_encodings_add.cc + src/ops/position_encodings_add_cpu.cc src/ops/quantize.cc src/ops/quantize_cpu.cc src/ops/relu.cc @@ -511,6 +513,7 @@ if (WITH_CUDA) src/ops/layer_norm_gpu.cu src/ops/mean_gpu.cu src/ops/multinomial_gpu.cu + src/ops/position_encodings_add_gpu.cu src/ops/rms_norm_gpu.cu src/ops/rotary_gpu.cu src/ops/softmax_gpu.cu diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h index 7624e1a72..cb54ef1f3 100644 --- a/include/ctranslate2/layers/common.h +++ b/include/ctranslate2/layers/common.h @@ -89,10 +89,14 @@ namespace ctranslate2 { // Base class for position encoders. class PositionEncoder : public Layer { public: - void operator()(StorageView& input, dim_t index = 0); - void operator()(const StorageView& input, StorageView& output, dim_t index = 0); + void operator()(const StorageView& input, + StorageView& output, + dim_t step = 0, + const StorageView* offsets = nullptr); protected: virtual const StorageView& get_position_encoding(dim_t max_time) = 0; + private: + const ops::PositionEncodingsAdd _add_op; }; // Concrete position encoder loading encoding vectors from the model. diff --git a/include/ctranslate2/ops/ops.h b/include/ctranslate2/ops/ops.h index 78e667848..b11383d3f 100644 --- a/include/ctranslate2/ops/ops.h +++ b/include/ctranslate2/ops/ops.h @@ -39,3 +39,4 @@ #include "median_filter.h" #include "rotary.h" #include "alibi_add.h" +#include "position_encodings_add.h" diff --git a/include/ctranslate2/ops/position_encodings_add.h b/include/ctranslate2/ops/position_encodings_add.h new file mode 100644 index 000000000..de255690f --- /dev/null +++ b/include/ctranslate2/ops/position_encodings_add.h @@ -0,0 +1,26 @@ +#pragma once + +#include "op.h" + +namespace ctranslate2 { + namespace ops { + + class PositionEncodingsAdd : public Op { + public: + void operator()(const StorageView& input, + const StorageView& encodings, + StorageView& output, + const StorageView* offsets = nullptr, + const dim_t step = 0) const; + + private: + template + void compute(const dim_t step, + const StorageView* offsets, + const StorageView& input, + const StorageView& encodings, + StorageView& output) const; + }; + + } +} diff --git a/src/layers/common.cc b/src/layers/common.cc index 162012e9a..7659aa240 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -4,7 +4,6 @@ #include "ctranslate2/ops/activation.h" #include "cpu/backend.h" -#include "dispatch.h" namespace ctranslate2 { namespace layers { @@ -148,34 +147,13 @@ namespace ctranslate2 { } - void PositionEncoder::operator()(StorageView& input, dim_t index) { + void PositionEncoder::operator()(const StorageView& input, + StorageView& output, + dim_t step, + const StorageView* offsets) { const dim_t time = input.dim(1); - const dim_t depth = input.dim(-1); - const dim_t max_time = std::max(time, index + 1); - const StorageView& encodings = get_position_encoding(max_time); - const dim_t num_encodings = encodings.dim(0); - - if (max_time > num_encodings) - throw std::runtime_error("No position encodings are defined for positions >= " - + std::to_string(num_encodings) - + ", but got position " - + std::to_string(max_time - 1)); - if (depth != encodings.dim(1)) - throw std::invalid_argument("Shape mismatch: position encodings have depth " - + std::to_string(encodings.dim(1)) - + ", but the input has depth " - + std::to_string(depth)); - - DEVICE_AND_TYPE_DISPATCH(input.device(), input.dtype(), - primitives::add_batch_broadcast(encodings.data() + index * depth, - input.data(), - time * depth, - input.size())); - } - - void PositionEncoder::operator()(const StorageView& input, StorageView& output, dim_t index) { - output = input; - operator()(output, index); + const StorageView& encodings = get_position_encoding(step + time); + _add_op(input, encodings, output, offsets, step); } diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc index d97c5f20b..885663569 100644 --- a/src/layers/transformer.cc +++ b/src/layers/transformer.cc @@ -259,7 +259,7 @@ namespace ctranslate2 { if (_embeddings_scale) ops::Mul()(input, *_embeddings_scale, input); if (_position_encoder) - (*_position_encoder)(input); + (*_position_encoder)(input, input); if (_layernorm_embedding) (*_layernorm_embedding)(input, input); @@ -460,7 +460,7 @@ namespace ctranslate2 { if (layer_in.rank() == 2) layer_in.expand_dims(1); if (_position_encoder) - (*_position_encoder)(layer_in, std::max(step, dim_t(0))); + (*_position_encoder)(layer_in, layer_in, std::max(step, dim_t(0))); if (_layernorm_embedding) (*_layernorm_embedding)(layer_in, layer_in); diff --git a/src/layers/whisper.cc b/src/layers/whisper.cc index 2bad8b669..3b15224e8 100644 --- a/src/layers/whisper.cc +++ b/src/layers/whisper.cc @@ -52,7 +52,7 @@ namespace ctranslate2 { _gelu(output, output); _transpose(output, input); - _position_embedding(input); + _position_embedding(input, input); for (const auto& layer : _layers) { (*layer)(input, nullptr, output); diff --git a/src/ops/position_encodings_add.cc b/src/ops/position_encodings_add.cc new file mode 100644 index 000000000..4896f41a6 --- /dev/null +++ b/src/ops/position_encodings_add.cc @@ -0,0 +1,48 @@ +#include "ctranslate2/ops/position_encodings_add.h" + +#include "dispatch.h" + +namespace ctranslate2 { + namespace ops { + + void PositionEncodingsAdd::operator()(const StorageView& input, + const StorageView& encodings, + StorageView& output, + const StorageView* offsets, + const dim_t step) const { + PROFILE("PositionEncodingsAdd"); + + const dim_t time = input.dim(1); + const dim_t depth = input.dim(2); + const dim_t max_time = time + step; + + if (max_time > encodings.dim(0)) + throw std::runtime_error("No position encodings are defined for positions >= " + + std::to_string(encodings.dim(0)) + + ", but got position " + + std::to_string(max_time - 1)); + + if (depth != encodings.dim(1)) + throw std::invalid_argument("Shape mismatch: position encodings have depth " + + std::to_string(encodings.dim(1)) + + ", but the input has depth " + + std::to_string(depth)); + + output.resize_as(input); + + DEVICE_AND_FLOAT_DISPATCH( + "PositionEncodingsAdd", input.device(), input.dtype(), + ({ + if (offsets) + compute(step, offsets, input, encodings, output); + else + primitives::add_batch_broadcast(encodings.data() + step * depth, + input.data(), + output.data(), + time * depth, + input.size()); + })); + } + + } +} diff --git a/src/ops/position_encodings_add_cpu.cc b/src/ops/position_encodings_add_cpu.cc new file mode 100644 index 000000000..ef03b6dde --- /dev/null +++ b/src/ops/position_encodings_add_cpu.cc @@ -0,0 +1,48 @@ +#include "ctranslate2/ops/position_encodings_add.h" + +#include "cpu/parallel.h" + +namespace ctranslate2 { + namespace ops { + + template + void PositionEncodingsAdd::compute(const dim_t step, + const StorageView* offsets, + const StorageView& input, + const StorageView& encodings, + StorageView& output) const { + const dim_t batch_size = input.dim(0); + const dim_t time = input.dim(1); + const dim_t depth = input.dim(2); + + cpu::parallel_for(0, batch_size * time, 1, [&](const dim_t begin, const dim_t end) { + for (dim_t i = begin; i < end; ++i) { + const dim_t b = i / time; + const dim_t t = i % time; + + const dim_t offset = offsets ? offsets->at(b) : 0; + const dim_t encoding_offset = t - offset + step; + + if (encoding_offset < 0) + continue; + + primitives::add(encodings.index({encoding_offset, 0}), + input.index({b, t, 0}), + output.index({b, t, 0}), + depth); + } + }); + } + +#define DECLARE_IMPL(T) \ + template void \ + PositionEncodingsAdd::compute(const dim_t, \ + const StorageView*, \ + const StorageView&, \ + const StorageView&, \ + StorageView&) const; + + DECLARE_IMPL(float) + + } +} diff --git a/src/ops/position_encodings_add_gpu.cu b/src/ops/position_encodings_add_gpu.cu new file mode 100644 index 000000000..e5ad79194 --- /dev/null +++ b/src/ops/position_encodings_add_gpu.cu @@ -0,0 +1,73 @@ +#include "ctranslate2/ops/position_encodings_add.h" + +#include "type_dispatch.h" +#include "cuda/helpers.h" + +namespace ctranslate2 { + namespace ops { + + template + __global__ void position_encodings_add_kernel(const T* input, + const T* encodings, + T* output, + const int32_t* offsets, + cuda::index_t step, + cuda::index_t max_time, + cuda::index_t depth, + const AddFunc& add_func) { + const cuda::index_t batch = blockIdx.x / max_time; + const cuda::index_t time = blockIdx.x % max_time; + + const int32_t offset = offsets ? offsets[batch] : 0; + const int32_t encoding_offset = time - offset + step; + + if (encoding_offset < 0) + return; + + input += blockIdx.x * depth; + output += blockIdx.x * depth; + encodings += encoding_offset * depth; + + for (cuda::index_t i = threadIdx.x; i < depth; i += blockDim.x) { + output[i] = add_func(input[i], encodings[i]); + } + } + + template + void PositionEncodingsAdd::compute(const dim_t step, + const StorageView* offsets, + const StorageView& input, + const StorageView& encodings, + StorageView& output) const { + const dim_t batch_size = input.dim(0); + const dim_t time = input.dim(1); + const dim_t depth = input.dim(2); + + const dim_t blocks = std::min(batch_size * time, cuda::max_blocks); + const dim_t threads = std::min(depth, cuda::max_threads); + + position_encodings_add_kernel<<>>( + cuda::device_cast(input.data()), + cuda::device_cast(encodings.data()), + cuda::device_cast(output.data()), + offsets ? offsets->data() : nullptr, + step, + time, + depth, + cuda::plus>()); + } + +#define DECLARE_IMPL(T) \ + template void \ + PositionEncodingsAdd::compute(const dim_t, \ + const StorageView*, \ + const StorageView&, \ + const StorageView&, \ + StorageView&) const; + + DECLARE_IMPL(float) + DECLARE_IMPL(float16_t) + DECLARE_IMPL(bfloat16_t) + + } +} diff --git a/tests/layers_test.cc b/tests/layers_test.cc index f1359bc05..c62f362f8 100644 --- a/tests/layers_test.cc +++ b/tests/layers_test.cc @@ -223,6 +223,36 @@ TEST(LayerTest, PadderIgnore) { expect_storage_eq(x, original); } +TEST_P(LayerDeviceFPTest, PositionEncoderOffset) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + const float error = GetParam().error; + + layers::SinusoidalPositionEncoder position_encoder(4, dtype, device); + + StorageView offsets({2}, std::vector{3, 1}, device); + dim_t step = 5; + + StorageView expected_encodings(dtype, device); + + { + StorageView zero({2, 5, 4}, 0.f, device); + StorageView encodings(dtype, device); + position_encoder(zero.to(dtype), encodings); + + StorageView position_ids({2, 1}, std::vector{2, 4}, device); + ops::Gather(/*axis=*/1, /*batch_dims=*/1)(encodings, position_ids, expected_encodings); + } + + { + StorageView zero({2, 1, 4}, 0.f, device); + StorageView encodings(dtype, device); + position_encoder(zero.to(dtype), encodings, step, &offsets); + + expect_storage_eq(encodings.to_float32(), expected_encodings.to_float32(), error); + } +} + TEST(LayerTest, PositionEncoderNoSharedState) { // Test case for issue: http://forum.opennmt.net/t/ctranslate2-c-api-returns-strange-results-when-initializing-2-models/3208 layers::SinusoidalPositionEncoder position_encoder_1(4); @@ -233,7 +263,7 @@ TEST(LayerTest, PositionEncoderNoSharedState) { {1, 1, 4}, std::vector{0.1, -2.3, 0.5, 1.2}); StorageView expected( {1, 1, 4}, std::vector{0.941471, -2.2999, 1.0403, 2.2}); - position_encoder_1(input); + position_encoder_1(input, input); expect_storage_eq(input, expected, 1e-5); } @@ -242,7 +272,7 @@ TEST(LayerTest, PositionEncoderNoSharedState) { {1, 1, 6}, std::vector{-0.2, -1.3, 0.1, -0.6, 2.0, 1.1}); StorageView expected( {1, 1, 6}, std::vector{0.641471, -1.29, 0.1001, -0.0596977, 2.99995, 2.1}); - position_encoder_2(input); + position_encoder_2(input, input); expect_storage_eq(input, expected, 1e-5); } } From a954445979b6cf5359bb6a45cf18906fe2d32f7b Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Fri, 28 Jul 2023 11:54:12 +0200 Subject: [PATCH 2/3] Remove const for attribute --- include/ctranslate2/layers/common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h index cb54ef1f3..a3a65bc29 100644 --- a/include/ctranslate2/layers/common.h +++ b/include/ctranslate2/layers/common.h @@ -96,7 +96,7 @@ namespace ctranslate2 { protected: virtual const StorageView& get_position_encoding(dim_t max_time) = 0; private: - const ops::PositionEncodingsAdd _add_op; + ops::PositionEncodingsAdd _add_op; }; // Concrete position encoder loading encoding vectors from the model. From c71bd47229d5ef90d705b9fef32fd5c52acfe2e1 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Fri, 28 Jul 2023 15:32:07 +0200 Subject: [PATCH 3/3] Fix MSVC compilation --- src/ops/position_encodings_add.cc | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/ops/position_encodings_add.cc b/src/ops/position_encodings_add.cc index 4896f41a6..280f3f05f 100644 --- a/src/ops/position_encodings_add.cc +++ b/src/ops/position_encodings_add.cc @@ -30,18 +30,20 @@ namespace ctranslate2 { output.resize_as(input); - DEVICE_AND_FLOAT_DISPATCH( - "PositionEncodingsAdd", input.device(), input.dtype(), - ({ - if (offsets) - compute(step, offsets, input, encodings, output); - else - primitives::add_batch_broadcast(encodings.data() + step * depth, - input.data(), - output.data(), - time * depth, - input.size()); - })); + if (offsets) { + DEVICE_AND_FLOAT_DISPATCH( + "PositionEncodingsAdd", input.device(), input.dtype(), + (compute(step, offsets, input, encodings, output))); + + } else { + DEVICE_AND_FLOAT_DISPATCH( + "PositionEncodingsAdd", input.device(), input.dtype(), + (primitives::add_batch_broadcast(encodings.data() + step * depth, + input.data(), + output.data(), + time * depth, + input.size()))); + } } }