From 966a8779997d5159c4142efed51c8fb8e04a4a16 Mon Sep 17 00:00:00 2001
From: Pratik Nayak
Date: Fri, 28 Jul 2023 11:44:33 +0200
Subject: [PATCH] Review updates.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Thomas Grützmacher
Co-authored-by: Yu-Hsiang Tsai
Co-authored-by: Marcel Koch
---
 .../base/batch_multi_vector_kernels.hpp.inc   | 80 +++++++++++--------
 core/base/batch_multi_vector.cpp              | 80 +++++++++++++++++++
 core/test/base/batch_multi_vector.cpp         | 13 +++
 cuda/base/batch_multi_vector_kernels.cu       |  2 +-
 cuda/base/batch_struct.hpp                    |  2 +-
 dpcpp/base/batch_struct.hpp                   |  2 +-
 hip/base/batch_multi_vector_kernels.hip.cpp   |  2 +-
 hip/base/batch_struct.hip.hpp                 |  2 +-
 .../ginkgo/core/base/batch_lin_op_helpers.hpp |  1 +
 .../ginkgo/core/base/batch_multi_vector.hpp   | 77 +++---------------
 .../test/base/batch_multi_vector_kernels.cpp  | 14 +---
 test/base/batch_multi_vector_kernels.cpp      | 23 +++++-
 12 files changed, 176 insertions(+), 122 deletions(-)

diff --git a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc
index 3df2bc14c84..5e63f451d19 100644
--- a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc
+++ b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc
@@ -47,10 +47,15 @@ __device__ __forceinline__ void scale(
 }
 
 template <typename ValueType, typename Mapping>
-__global__
-__launch_bounds__(default_block_size, sm_multiplier) void scale_kernel(
-    const gko::batch_multi_vector::uniform_batch<const ValueType> alpha,
-    const gko::batch_multi_vector::uniform_batch<ValueType> x, Mapping map)
+__global__ __launch_bounds__(
+    default_block_size,
+    sm_oversubscription) void scale_kernel(const gko::batch_multi_vector::
+                                               uniform_batch<const ValueType>
+                                                   alpha,
+                                           const gko::batch_multi_vector::
+                                               uniform_batch<ValueType>
+                                                   x,
+                                           Mapping map)
 {
     for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_entries;
          batch_id += gridDim.x) {
@@ -78,11 +83,20 @@ __device__ __forceinline__ void add_scaled(
 }
 
 template <typename ValueType, typename Mapping>
-__global__
-__launch_bounds__(default_block_size, sm_multiplier) void add_scaled_kernel(
-    const gko::batch_multi_vector::uniform_batch<const ValueType> alpha,
-    const gko::batch_multi_vector::uniform_batch<const ValueType> x,
-    const gko::batch_multi_vector::uniform_batch<ValueType> y, Mapping map)
+__global__ __launch_bounds__(
+    default_block_size,
+    sm_oversubscription) void add_scaled_kernel(const gko::batch_multi_vector::
+                                                    uniform_batch<
+                                                        const ValueType>
+                                                        alpha,
+                                                const gko::batch_multi_vector::
+                                                    uniform_batch<
+                                                        const ValueType>
+                                                        x,
+                                                const gko::batch_multi_vector::
+                                                    uniform_batch<ValueType>
+                                                    y,
+                                                Mapping map)
 {
     for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_entries;
          batch_id += gridDim.x) {
@@ -139,24 +153,12 @@ __device__ __forceinline__ void compute_gen_dot_product(
 
 template <typename ValueType, typename Mapping>
-__global__ __launch_bounds__(
-    default_block_size,
-    sm_multiplier) void compute_gen_dot_product_kernel(const gko::
-                                                           batch_multi_vector::
-                                                               uniform_batch<
-                                                                   const ValueType>
-                                                           x,
-                                                       const gko::
-                                                           batch_multi_vector::
-                                                               uniform_batch<
-                                                                   const ValueType>
-                                                           y,
-                                                       const gko::
-                                                           batch_multi_vector::
-                                                               uniform_batch<
-                                                                   ValueType>
-                                                           result,
-                                                       Mapping map)
+__global__
+    __launch_bounds__(default_block_size, sm_oversubscription) void compute_gen_dot_product_kernel(
+        const gko::batch_multi_vector::uniform_batch<const ValueType> x,
+        const gko::batch_multi_vector::uniform_batch<const ValueType> y,
+        const gko::batch_multi_vector::uniform_batch<ValueType> result,
+        Mapping map)
 {
     for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_entries;
          batch_id += gridDim.x) {
@@ -218,11 +220,19 @@ __device__ __forceinline__ void compute_norm2(
 
 template <typename ValueType>
-__global__
-__launch_bounds__(default_block_size, sm_multiplier) void compute_norm2_kernel(
-    const gko::batch_multi_vector::uniform_batch<const ValueType> x,
-    const gko::batch_multi_vector::uniform_batch<remove_complex<ValueType>>
-        result)
+__global__ __launch_bounds__(
+    default_block_size,
+    sm_oversubscription) void compute_norm2_kernel(const gko::
+                                                       batch_multi_vector::
+                                                           uniform_batch<
+                                                               const ValueType>
+                                                       x,
+                                                   const gko::
+                                                       batch_multi_vector::
+                                                           uniform_batch<
+                                                               remove_complex<
+                                                                   ValueType>>
+                                                       result)
 {
     for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_entries;
          batch_id += gridDim.x) {
@@ -255,9 +265,9 @@ __device__ __forceinline__ void copy(
 
 template <typename ValueType>
 __global__
-__launch_bounds__(default_block_size, sm_multiplier) void copy_kernel(
-    const gko::batch_multi_vector::uniform_batch<const ValueType> src,
-    const gko::batch_multi_vector::uniform_batch<ValueType> dst)
+    __launch_bounds__(default_block_size, sm_oversubscription) void copy_kernel(
+        const gko::batch_multi_vector::uniform_batch<const ValueType> src,
+        const gko::batch_multi_vector::uniform_batch<ValueType> dst)
 {
     for (size_type batch_id = blockIdx.x; batch_id < src.num_batch_entries;
          batch_id += gridDim.x) {
diff --git a/core/base/batch_multi_vector.cpp b/core/base/batch_multi_vector.cpp
index 9b5b908f5d1..559ff4478fd 100644
--- a/core/base/batch_multi_vector.cpp
+++ b/core/base/batch_multi_vector.cpp
@@ -65,6 +65,86 @@ GKO_REGISTER_OPERATION(copy, batch_multi_vector::copy);
 }  // namespace
 }  // namespace batch_multi_vector
 
+namespace detail {
+
+
+template <typename ValueType>
+batch_dim<2> compute_batch_size(
+    const std::vector<matrix::Dense<ValueType>*>& matrices)
+{
+    auto common_size = matrices[0]->get_size();
+    for (size_type i = 1; i < matrices.size(); ++i) {
+        GKO_ASSERT_EQUAL_DIMENSIONS(common_size, matrices[i]->get_size());
+    }
+    return batch_dim<2>{matrices.size(), common_size};
+}
+
+
+}  // namespace detail
+
+
+template <typename ValueType>
+BatchMultiVector<ValueType>::BatchMultiVector(
+    std::shared_ptr<const Executor> exec, const batch_dim<2>& size)
+    : EnablePolymorphicObject<BatchMultiVector<ValueType>>(exec),
+      batch_size_(size),
+      values_(exec, compute_num_elems(size))
+{}
+
+
+template <typename ValueType>
+BatchMultiVector<ValueType>::BatchMultiVector(
+    std::shared_ptr<const Executor> exec,
+    const std::vector<matrix::Dense<ValueType>*>& matrices)
+    : EnablePolymorphicObject<BatchMultiVector<ValueType>>(exec),
+      batch_size_{detail::compute_batch_size(matrices)},
+      values_(exec, compute_num_elems(batch_size_))
+{
+    for (size_type i = 0; i < this->get_num_batch_entries(); ++i) {
+        auto local_exec = matrices[i]->get_executor();
+        exec->copy_from(
+            local_exec.get(), matrices[i]->get_num_stored_elements(),
+            matrices[i]->get_const_values(),
+            this->get_values() + this->get_size().get_cumulative_offset(i));
+    }
+}
+
+
+template <typename ValueType>
+BatchMultiVector<ValueType>::BatchMultiVector(
+    std::shared_ptr<const Executor> exec, size_type num_duplications,
+    const matrix::Dense<ValueType>* input)
+    : BatchMultiVector(
+          exec, gko::batch_dim<2>(num_duplications, input->get_size()))
+{
+    size_type offset = 0;
+    for (size_type i = 0; i < num_duplications; ++i) {
+        exec->copy_from(input->get_executor().get(),
+                        input->get_num_stored_elements(),
+                        input->get_const_values(), this->get_values() + offset);
+        offset += input->get_num_stored_elements();
+    }
+}
+
+
+template <typename ValueType>
+BatchMultiVector<ValueType>::BatchMultiVector(
+    std::shared_ptr<const Executor> exec, size_type num_duplications,
+    const BatchMultiVector<ValueType>* input)
+    : BatchMultiVector(
+          exec,
+          gko::batch_dim<2>(input->get_num_batch_entries() * num_duplications,
+                            input->get_common_size()))
+{
+    size_type offset = 0;
+    for (size_type i = 0; i < num_duplications; ++i) {
+        exec->copy_from(input->get_executor().get(),
+                        input->get_num_stored_elements(),
+                        input->get_const_values(), this->get_values() + offset);
+        offset += input->get_num_stored_elements();
+    }
+}
+
 
 template <typename ValueType>
 std::unique_ptr<BatchMultiVector<ValueType>>
diff --git a/core/test/base/batch_multi_vector.cpp b/core/test/base/batch_multi_vector.cpp
index a201a80f741..486a8301cf6 100644
--- a/core/test/base/batch_multi_vector.cpp
+++ b/core/test/base/batch_multi_vector.cpp
@@ -97,6 +97,7 @@ TYPED_TEST_SUITE(BatchMultiVector, gko::test::ValueTypes);
 TYPED_TEST(BatchMultiVector, CanBeEmpty)
 {
     auto empty = gko::BatchMultiVector<TypeParam>::create(this->exec);
+
     this->assert_empty(empty.get());
 }
 
@@ -104,6 +105,7 @@ TYPED_TEST(BatchMultiVector, CanBeEmpty)
 TYPED_TEST(BatchMultiVector, KnowsItsSizeAndValues)
 {
     ASSERT_NE(this->mtx->get_const_values(), nullptr);
+
     this->assert_equal_to_original_mtx(this->mtx.get());
 }
 
@@ -119,7 +121,9 @@ TYPED_TEST(BatchMultiVector, CanGetValuesForEntry)
 TYPED_TEST(BatchMultiVector, CanBeCopied)
 {
     auto mtx_copy = gko::BatchMultiVector<TypeParam>::create(this->exec);
+
     mtx_copy->copy_from(this->mtx.get());
+
     this->assert_equal_to_original_mtx(this->mtx.get());
     this->mtx->at(0, 0, 0) = 7;
     this->mtx->at(0, 1) = 7;
@@ -130,7 +134,9 @@ TYPED_TEST(BatchMultiVector, CanBeCopied)
 TYPED_TEST(BatchMultiVector, CanBeMoved)
 {
     auto mtx_copy = gko::BatchMultiVector<TypeParam>::create(this->exec);
+
     this->mtx->move_to(mtx_copy.get());
+
     this->assert_equal_to_original_mtx(mtx_copy.get());
 }
 
@@ -138,6 +144,7 @@ TYPED_TEST(BatchMultiVector, CanBeMoved)
 TYPED_TEST(BatchMultiVector, CanBeCloned)
 {
     auto mtx_clone = this->mtx->clone();
+
     this->assert_equal_to_original_mtx(
         dynamic_cast<decltype(this->mtx.get())>(mtx_clone.get()));
 }
 
@@ -146,6 +153,7 @@ TYPED_TEST(BatchMultiVector, CanBeCloned)
 TYPED_TEST(BatchMultiVector, CanBeCleared)
 {
     this->mtx->clear();
+
     this->assert_empty(this->mtx.get());
 }
 
@@ -153,6 +161,7 @@ TYPED_TEST(BatchMultiVector, CanBeCleared)
 TYPED_TEST(BatchMultiVector, CanBeConstructedWithSize)
 {
     using size_type = gko::size_type;
+
     auto m = gko::BatchMultiVector<TypeParam>::create(
         this->exec, gko::batch_dim<2>(2, gko::dim<2>(2, 4)));
 
@@ -281,6 +290,7 @@ TYPED_TEST(BatchMultiVector, CanBeConstructedFromBatchMultiVectorMatrices)
 TYPED_TEST(BatchMultiVector, CanBeListConstructed)
 {
     using value_type = typename TestFixture::value_type;
+
     auto m = gko::batch_initialize<gko::BatchMultiVector<value_type>>(
         {{1.0, 2.0}, {1.0, 3.0}}, this->exec);
 
@@ -296,6 +306,7 @@ TYPED_TEST(BatchMultiVector, CanBeListConstructed)
 TYPED_TEST(BatchMultiVector, CanBeListConstructedByCopies)
 {
     using value_type = typename TestFixture::value_type;
+
     auto m = gko::batch_initialize<gko::BatchMultiVector<value_type>>(
         2, I<value_type>({1.0, 2.0}), this->exec);
 
@@ -312,6 +323,7 @@ TYPED_TEST(BatchMultiVector, CanBeDoubleListConstructed)
 {
     using value_type = typename TestFixture::value_type;
     using T = value_type;
+
     auto m = gko::batch_initialize<gko::BatchMultiVector<T>>(
         {{I<T>{1.0, 1.0, 0.0}, I<T>{2.0, 4.0, 3.0}, I<T>{3.0, 6.0, 1.0}},
          {I<T>{1.0, 2.0, -1.0}, I<T>{3.0, 4.0, -2.0}, I<T>{5.0, 6.0, -3.0}}},
@@ -401,6 +413,7 @@ TYPED_TEST(BatchMultiVector, CanBeReadFromSparseMatrixData)
 {
     using value_type = typename TestFixture::value_type;
     auto m = gko::BatchMultiVector<value_type>::create(this->exec);
+
     // clang-format off
     m->read({gko::matrix_data<value_type>{{2, 2},
                                           {{0, 0, 1.0},
diff --git a/cuda/base/batch_multi_vector_kernels.cu b/cuda/base/batch_multi_vector_kernels.cu
index 3fd80a2aa41..3e44b006552 100644
--- a/cuda/base/batch_multi_vector_kernels.cu
+++ b/cuda/base/batch_multi_vector_kernels.cu
@@ -65,7 +65,7 @@ namespace batch_multi_vector {
 
 
 constexpr auto default_block_size = 256;
-constexpr int sm_multiplier = 4;
+constexpr int sm_oversubscription = 4;
 
 // clang-format off
diff --git a/cuda/base/batch_struct.hpp b/cuda/base/batch_struct.hpp
index d9907b41531..70bc42aecac 100644
--- a/cuda/base/batch_struct.hpp
+++ b/cuda/base/batch_struct.hpp
@@ -51,7 +51,7 @@ namespace cuda {
 /** @file batch_struct.hpp
  *
  * Helper functions to generate a batch struct from a batch LinOp,
- * while also shallow-casting to the requried CUDA scalar type.
+ * while also shallow-casting to the required CUDA scalar type.
  *
  * A specialization is needed for every format of every kind of linear algebra
  * object. These are intended to be called on the host.
diff --git a/dpcpp/base/batch_struct.hpp b/dpcpp/base/batch_struct.hpp
index c9ee5800b3e..4f8d8aa0350 100644
--- a/dpcpp/base/batch_struct.hpp
+++ b/dpcpp/base/batch_struct.hpp
@@ -50,7 +50,7 @@ namespace dpcpp {
 /** @file batch_struct.hpp
  *
  * Helper functions to generate a batch struct from a batch LinOp,
- * while also shallow-casting to the requried DPCPP scalar type.
+ * while also shallow-casting to the required DPCPP scalar type.
  *
  * A specialization is needed for every format of every kind of linear algebra
  * object. These are intended to be called on the host.
diff --git a/hip/base/batch_multi_vector_kernels.hip.cpp b/hip/base/batch_multi_vector_kernels.hip.cpp
index 40e828b5d45..bb465ac7709 100644
--- a/hip/base/batch_multi_vector_kernels.hip.cpp
+++ b/hip/base/batch_multi_vector_kernels.hip.cpp
@@ -66,7 +66,7 @@ namespace batch_multi_vector {
 
 
 constexpr auto default_block_size = 256;
-constexpr int sm_multiplier = 4;
+constexpr int sm_oversubscription = 4;
 
 // clang-format off
diff --git a/hip/base/batch_struct.hip.hpp b/hip/base/batch_struct.hip.hpp
index 3171e7e1df8..55f81f7eaff 100644
--- a/hip/base/batch_struct.hip.hpp
+++ b/hip/base/batch_struct.hip.hpp
@@ -51,7 +51,7 @@ namespace hip {
 /** @file batch_struct.hpp
  *
  * Helper functions to generate a batch struct from a batch LinOp,
- * while also shallow-casting to the requried Hip scalar type.
+ * while also shallow-casting to the required Hip scalar type.
  *
  * A specialization is needed for every format of every kind of linear algebra
  * object. These are intended to be called on the host.
diff --git a/include/ginkgo/core/base/batch_lin_op_helpers.hpp b/include/ginkgo/core/base/batch_lin_op_helpers.hpp
index 6dd9297614a..5d1a2f8ed0d 100644
--- a/include/ginkgo/core/base/batch_lin_op_helpers.hpp
+++ b/include/ginkgo/core/base/batch_lin_op_helpers.hpp
@@ -37,6 +37,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #include
 #include
 #include
+#include
 #include
diff --git a/include/ginkgo/core/base/batch_multi_vector.hpp b/include/ginkgo/core/base/batch_multi_vector.hpp
index f7c8258121f..40f551dbd5f 100644
--- a/include/ginkgo/core/base/batch_multi_vector.hpp
+++ b/include/ginkgo/core/base/batch_multi_vector.hpp
@@ -89,6 +89,7 @@ class BatchMultiVector
     friend class EnableCreateMethod<BatchMultiVector>;
     friend class EnablePolymorphicObject<BatchMultiVector>;
     friend class BatchMultiVector<to_complex<ValueType>>;
+    friend class BatchMultiVector<next_precision<ValueType>>;
 
 public:
     using BatchReadableFromMatrixData<ValueType, int32>::read;
@@ -107,8 +108,6 @@ class BatchMultiVector
     using absolute_type = remove_complex<BatchMultiVector>;
     using complex_type = to_complex<BatchMultiVector>;
 
-    using row_major_range = gko::range<accessor::row_major<ValueType, 2>>;
-
     /**
      * Creates a BatchMultiVector with the configuration of another
     * BatchMultiVector.
@@ -118,8 +117,6 @@ class BatchMultiVector
     static std::unique_ptr<BatchMultiVector> create_with_config_of(
         ptr_param<const BatchMultiVector> other);
 
-    friend class BatchMultiVector<next_precision<ValueType>>;
-
     void convert_to(
         BatchMultiVector<next_precision<ValueType>>* result) const override;
 
@@ -175,13 +172,10 @@ class BatchMultiVector
      *
      * @return the pointer to the array of values
      */
-    value_type* get_values(size_type batch_id = 0) noexcept
-    {
-        return values_.get_data();
-    }
+    value_type* get_values() noexcept { return values_.get_data(); }
 
     /**
-     * @copydoc get_values(size_type)
+     * @copydoc get_values()
      *
      * @note This is the constant version of the function, which can be
      *       significantly more memory efficient than the non-constant version,
@@ -374,22 +368,11 @@ class BatchMultiVector
     void fill(ValueType value);
 
 private:
-    inline batch_dim<2> compute_batch_size(
-        const std::vector<matrix::Dense<ValueType>*>& matrices)
-    {
-        auto common_size = matrices[0]->get_size();
-        for (int i = 1; i < matrices.size(); ++i) {
-            GKO_ASSERT_EQUAL_DIMENSIONS(common_size, matrices[i]->get_size());
-        }
-        return batch_dim<2>{matrices.size(), common_size};
-    }
-
     inline size_type compute_num_elems(const batch_dim<2>& size)
     {
         return size.get_cumulative_offset(size.get_num_batch_entries());
     }
-
 protected:
     /**
      * Sets the size of the BatchMultiVector.
@@ -403,14 +386,10 @@ class BatchMultiVector
      * size.
      *
      * @param exec  Executor associated to the vector
-     * @param size  size of the vector
+     * @param size  size of the batch multi vector
      */
     BatchMultiVector(std::shared_ptr<const Executor> exec,
-                     const batch_dim<2>& size = batch_dim<2>{})
-        : EnablePolymorphicObject<BatchMultiVector<ValueType>>(exec),
-          batch_size_(size),
-          values_(exec, compute_num_elems(size))
-    {}
+                     const batch_dim<2>& size = batch_dim<2>{});
 
     /**
      * Creates a BatchMultiVector from an already allocated (and
@@ -446,24 +425,12 @@ class BatchMultiVector
      *
      * @note This is a utility function that can serve as a first step to port
      *       to batched data-structures and solvers. Even if the matrices are in
-     *       device memory, this method can have siginificant overhead, as new
+     *       device memory, this method can have significant overhead, as new
      *       allocations and deep copies are necessary and hence this constructor must
      *       not be used in performance sensitive applications
      */
     BatchMultiVector(std::shared_ptr<const Executor> exec,
-                     const std::vector<matrix::Dense<ValueType>*>& matrices)
-        : EnablePolymorphicObject<BatchMultiVector<ValueType>>(exec),
-          batch_size_{compute_batch_size(matrices)},
-          values_(exec, compute_num_elems(batch_size_))
-    {
-        for (size_type i = 0; i < this->get_num_batch_entries(); ++i) {
-            auto local_exec = matrices[i]->get_executor();
-            exec->copy_from(
-                local_exec.get(), matrices[i]->get_num_stored_elements(),
-                matrices[i]->get_const_values(),
-                this->get_values() + this->get_size().get_cumulative_offset(i));
-        }
-    }
+                     const std::vector<matrix::Dense<ValueType>*>& matrices);
 
     /**
      * Creates a BatchMultiVector matrix by duplicating BatchMultiVector object
     *
@@ -474,26 +441,13 @@ class BatchMultiVector
      *
      * @note This is a utility function that can serve as a first step to port
      *       to batched data-structures and solvers. Even if the matrices are in
-     *       device memory, this method can have siginificant overhead, as new
+     *       device memory, this method can have significant overhead, as new
      *       allocations and deep copies are necessary and hence this constructor must
     *       not be used in performance sensitive applications.
     */
    BatchMultiVector(std::shared_ptr<const Executor> exec,
                     size_type num_duplications,
-                    const BatchMultiVector* input)
-        : BatchMultiVector(
-              exec, gko::batch_dim<2>(
-                        input->get_num_batch_entries() * num_duplications,
-                        input->get_common_size()))
-    {
-        size_type offset = 0;
-        for (size_type i = 0; i < num_duplications; ++i) {
-            exec->copy_from(
-                input->get_executor().get(), input->get_num_stored_elements(),
-                input->get_const_values(), this->get_values() + offset);
-            offset += input->get_num_stored_elements();
-        }
-    }
+                    const BatchMultiVector* input);
 
     /**
      * Creates a BatchMultiVector matrix by a duplicating a matrix::Dense object
      *
      * @param exec  Executor associated to the vector
      * @param num_duplications  The number of times to duplicate
      * @param input  the vector to be duplicated.
      */
    BatchMultiVector(std::shared_ptr<const Executor> exec,
                     size_type num_duplications,
-                    const matrix::Dense<ValueType>* input)
-        : BatchMultiVector(
-              exec, gko::batch_dim<2>(num_duplications, input->get_size()))
-    {
-        size_type offset = 0;
-        for (size_type i = 0; i < num_duplications; ++i) {
-            exec->copy_from(
-                input->get_executor().get(), input->get_num_stored_elements(),
-                input->get_const_values(), this->get_values() + offset);
-            offset += input->get_num_stored_elements();
-        }
-    }
+                    const matrix::Dense<ValueType>* input);
 
     /**
      * Creates a BatchMultiVector with the same configuration as the
diff --git a/reference/test/base/batch_multi_vector_kernels.cpp b/reference/test/base/batch_multi_vector_kernels.cpp
index f6ae66d8249..506695c8d4f 100644
--- a/reference/test/base/batch_multi_vector_kernels.cpp
+++ b/reference/test/base/batch_multi_vector_kernels.cpp
@@ -59,7 +59,6 @@ class BatchMultiVector : public ::testing::Test {
     using Mtx = gko::BatchMultiVector<value_type>;
     using DenseMtx = gko::matrix::Dense<value_type>;
     using ComplexMtx = gko::to_complex<Mtx>;
-    using RealMtx = gko::remove_complex<Mtx>;
     BatchMultiVector()
         : exec(gko::ReferenceExecutor::create()),
           mtx_0(gko::batch_initialize<Mtx>(
@@ -165,7 +164,7 @@ TYPED_TEST(BatchMultiVector, ScalesDataWithScalar)
 }
 
 
-TYPED_TEST(BatchMultiVector, ScalesDataWithStride)
+TYPED_TEST(BatchMultiVector, ScalesDataWithMultipleScalars)
 {
     using Mtx = typename TestFixture::Mtx;
     using T = typename TestFixture::value_type;
@@ -261,15 +260,12 @@ TYPED_TEST(BatchMultiVector, ComputeDotFailsOnWrongInputSize)
 
 TYPED_TEST(BatchMultiVector, ComputeDotFailsOnWrongResultSize)
 {
     using Mtx = typename TestFixture::Mtx;
+
     auto result =
         Mtx::create(this->exec, gko::batch_dim<2>(2, gko::dim<2>{1, 2}));
-    auto result2 =
-        Mtx::create(this->exec, gko::batch_dim<2>(2, gko::dim<2>{1, 2}));
 
     ASSERT_THROW(this->mtx_0->compute_dot(this->mtx_1.get(), result.get()),
                  gko::DimensionMismatch);
-    ASSERT_THROW(this->mtx_0->compute_dot(this->mtx_1.get(), result2.get()),
-                 gko::DimensionMismatch);
 }
 
@@ -305,16 +301,12 @@ TYPED_TEST(BatchMultiVector, ComputeConjDotFailsOnWrongInputSize)
 
 TYPED_TEST(BatchMultiVector, ComputeConjDotFailsOnWrongResultSize)
 {
     using Mtx = typename TestFixture::Mtx;
+
     auto result =
         Mtx::create(this->exec, gko::batch_dim<2>(2, gko::dim<2>{1, 2}));
-    auto result2 =
-        Mtx::create(this->exec, gko::batch_dim<2>(2, gko::dim<2>{1, 2}));
 
     ASSERT_THROW(this->mtx_0->compute_conj_dot(this->mtx_1.get(), result.get()),
                  gko::DimensionMismatch);
-    ASSERT_THROW(
-        this->mtx_0->compute_conj_dot(this->mtx_1.get(), result2.get()),
-        gko::DimensionMismatch);
 }
 
diff --git a/test/base/batch_multi_vector_kernels.cpp b/test/base/batch_multi_vector_kernels.cpp
index 015adbce798..b397ef3b1af 100644
--- a/test/base/batch_multi_vector_kernels.cpp
+++ b/test/base/batch_multi_vector_kernels.cpp
@@ -40,7 +40,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #include
 
-#include
 
 #include
 
@@ -60,11 +59,11 @@ class BatchMultiVector : public CommonTestFixture {
     BatchMultiVector() : rand_engine(15) {}
 
     template <typename MtxType>
-    std::unique_ptr<MtxType> gen_mtx(const size_t batch_size, int num_rows,
-                                     int num_cols)
+    std::unique_ptr<MtxType> gen_mtx(const size_t num_batch_entries,
+                                     int num_rows, int num_cols)
     {
         return gko::test::generate_uniform_batch_random_matrix<MtxType>(
-            batch_size, num_rows, num_cols,
+            num_batch_entries, num_rows, num_cols,
             std::uniform_int_distribution<>(num_cols, num_cols),
             std::normal_distribution<>(-1.0, 1.0), rand_engine, false, ref);
     }
 
@@ -75,6 +74,8 @@ class BatchMultiVector : public CommonTestFixture {
         const int num_rows = 252;
         x = gen_mtx<Mtx>(batch_size, num_rows, num_vecs);
         y = gen_mtx<Mtx>(batch_size, num_rows, num_vecs);
+        c_x = gen_mtx<ComplexMtx>(batch_size, num_rows, num_vecs);
+        c_y = gen_mtx<ComplexMtx>(batch_size, num_rows, num_vecs);
         if (different_alpha) {
             alpha = gen_mtx<Mtx>(batch_size, 1, num_vecs);
             beta = gen_mtx<Mtx>(batch_size, 1, num_vecs);
         } else {
         }
         dx = gko::clone(exec, x);
         dy = gko::clone(exec, y);
+        dc_x = gko::clone(exec, c_x);
+        dc_y = gko::clone(exec, c_y);
         dalpha = gko::clone(exec, alpha);
         dbeta = gko::clone(exec, beta);
         expected = Mtx::create(
@@ -97,6 +100,7 @@ class BatchMultiVector : public CommonTestFixture {
     const size_t batch_size = 11;
     std::unique_ptr<Mtx> x;
     std::unique_ptr<ComplexMtx> c_x;
+    std::unique_ptr<ComplexMtx> c_y;
     std::unique_ptr<Mtx> y;
     std::unique_ptr<Mtx> alpha;
     std::unique_ptr<Mtx> beta;
@@ -105,6 +109,7 @@ class BatchMultiVector : public CommonTestFixture {
     std::unique_ptr<Mtx> dresult;
     std::unique_ptr<Mtx> dx;
     std::unique_ptr<ComplexMtx> dc_x;
+    std::unique_ptr<ComplexMtx> dc_y;
     std::unique_ptr<Mtx> dy;
     std::unique_ptr<Mtx> dalpha;
     std::unique_ptr<Mtx> dbeta;
@@ -216,11 +221,16 @@ TEST_F(BatchMultiVector, ComputeDotIsEquivalentToRef)
         gko::batch_dim<2>(batch_size, gko::dim<2>{1, x->get_common_size()[1]});
     auto dot_expected = Mtx::create(this->ref, dot_size);
     auto ddot = Mtx::create(this->exec, dot_size);
+    auto cdot_expected = ComplexMtx::create(this->ref, dot_size);
+    auto dc_dot = ComplexMtx::create(this->exec, dot_size);
 
     x->compute_dot(y.get(), dot_expected.get());
     dx->compute_dot(dy.get(), ddot.get());
+    c_x->compute_dot(c_y.get(), cdot_expected.get());
+    dc_x->compute_dot(dc_y.get(), dc_dot.get());
 
     GKO_ASSERT_BATCH_MTX_NEAR(dot_expected, ddot, 5 * r<value_type>::value);
+    GKO_ASSERT_BATCH_MTX_NEAR(cdot_expected, dc_dot, 5 * r<value_type>::value);
 }
 
@@ -246,11 +256,16 @@ TEST_F(BatchMultiVector, ComputeConjDotIsEquivalentToRef)
         gko::batch_dim<2>(batch_size, gko::dim<2>{1, x->get_common_size()[1]});
     auto dot_expected = Mtx::create(this->ref, dot_size);
     auto ddot = Mtx::create(this->exec, dot_size);
+    auto cdot_expected = ComplexMtx::create(this->ref, dot_size);
+    auto dc_dot = ComplexMtx::create(this->exec, dot_size);
 
     x->compute_conj_dot(y.get(), dot_expected.get());
     dx->compute_conj_dot(dy.get(), ddot.get());
+    c_x->compute_conj_dot(c_y.get(), cdot_expected.get());
+    dc_x->compute_conj_dot(dc_y.get(), dc_dot.get());
 
     GKO_ASSERT_BATCH_MTX_NEAR(dot_expected, ddot, 5 * r<value_type>::value);
+    GKO_ASSERT_BATCH_MTX_NEAR(cdot_expected, dc_dot, 5 * r<value_type>::value);
 }
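
A note for readers on the kernel hunks at the top of this patch: the second
argument of CUDA's/HIP's __launch_bounds__ is minBlocksPerMultiprocessor,
i.e. the minimum number of blocks the compiler should assume can be resident
on one SM at a time, which is why the constant is renamed from sm_multiplier
to sm_oversubscription. A minimal sketch of the pattern, using a made-up
axpy kernel in place of the batched kernels above:

    // Hypothetical kernel illustrating the launch-bounds pattern of this
    // patch: cap blocks at 256 threads and ask the compiler to budget
    // registers so that at least 4 blocks can be resident per SM.
    constexpr int default_block_size = 256;
    constexpr int sm_oversubscription = 4;

    template <typename ValueType>
    __global__ __launch_bounds__(default_block_size, sm_oversubscription) void
    axpy_kernel(int n, ValueType alpha, const ValueType* x, ValueType* y)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            y[i] += alpha * x[i];
        }
    }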
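
And a usage sketch for the constructors this commit moves out of the header.
This is illustrative only: it assumes the create(...) calls forward through
EnableCreateMethod to the protected constructors declared in
batch_multi_vector.hpp, and all sizes and values are made up; the
batch_initialize call mirrors the CanBeListConstructed test above.

    #include <ginkgo/ginkgo.hpp>

    int main()
    {
        auto exec = gko::ReferenceExecutor::create();

        // Two batch entries, each a 2x1 multi-vector
        // (as in the CanBeListConstructed test).
        auto b = gko::batch_initialize<gko::BatchMultiVector<double>>(
            {{1.0, 2.0}, {1.0, 3.0}}, exec);

        // Duplicate one Dense vector three times into a batch. Per the
        // @note in the header, this deep-copies the input and is not
        // meant for performance-sensitive code paths.
        auto vec = gko::initialize<gko::matrix::Dense<double>>(
            {1.0, 2.0, 3.0}, exec);
        auto dup = gko::BatchMultiVector<double>::create(exec, 3, vec.get());

        // Duplicate an existing batch: 2 entries * 3 duplications = 6.
        auto bigger = gko::BatchMultiVector<double>::create(exec, 3, b.get());
    }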