diff --git a/ci/docker/Dockerfile.base b/ci/docker/Dockerfile.base new file mode 100644 index 0000000..f1c85e8 --- /dev/null +++ b/ci/docker/Dockerfile.base @@ -0,0 +1,6 @@ +FROM nvcr.io/nvidia/pytorch:23.10-py3 + +RUN apt-get update + +# install the Boost test framework +RUN apt-get install -y libboost-test-dev
diff --git a/ci/pipeline.yml b/ci/pipeline.yml new file mode 100644 index 0000000..21625eb --- /dev/null +++ b/ci/pipeline.yml @@ -0,0 +1,36 @@ +include: + - remote: 'https://gitlab.com/cscs-ci/recipes/-/raw/master/templates/v2/.ci-ext.yml' + +stages: + - build + - test + +build_base_image_job: + stage: build + extends: .container-builder-dynamic-name + timeout: 2h + variables: + DOCKERFILE: ci/docker/Dockerfile.base + WATCH_FILECHANGES: $DOCKERFILE + PERSIST_IMAGE_NAME: $CSCS_REGISTRY_PATH/base/public/mops + +test_job: + stage: test + extends: .container-runner-daint-gpu + image: $BASE_IMAGE + timeout: 2h + script: + - export CUDA_HOME="/usr/local/cuda" + - python3 -m pip install --upgrade pip + - echo "Install Tox" + - python3 -m pip install tox + - echo "Run the Tox Script" + - tox + - echo "Tox script completed" + + variables: + SLURM_JOB_NUM_NODES: 1 + SLURM_PARTITION: normal + SLURM_NTASKS: 1 + SLURM_TIMELIMIT: '00:40:00' + GIT_STRATEGY: fetch
diff --git a/mops-torch/src/hpe.cpp b/mops-torch/src/hpe.cpp index dc74031..08f5664 100644 --- a/mops-torch/src/hpe.cpp +++ b/mops-torch/src/hpe.cpp @@ -1,3 +1,8 @@ +#ifdef MOPS_CUDA_ENABLED +#include <c10/cuda/CUDAGuard.h> +#include <c10/cuda/CUDAStream.h> +#endif + #include "mops/torch/hpe.hpp" #include "mops/torch/utils.hpp" @@ -38,15 +43,25 @@ torch::Tensor HomogeneousPolynomialEvaluation::forward( }); } else if (A.device().is_cuda()) { +#ifndef MOPS_CUDA_ENABLED + C10_THROW_ERROR(ValueError, "MOPS was not compiled with CUDA support " + A.device().str()); +#else + c10::cuda::CUDAGuard deviceGuard{A.device()}; + cudaStream_t currstream = c10::cuda::getCurrentCUDAStream(); + void* stream = reinterpret_cast<void*>(currstream); + AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "homogeneous_polynomial_evaluation", [&]() { mops::cuda::homogeneous_polynomial_evaluation<scalar_t>( details::torch_to_mops_1d<scalar_t>(output), details::torch_to_mops_2d<scalar_t>(A), details::torch_to_mops_1d<scalar_t>(C), - details::torch_to_mops_2d<int32_t>(indices_A) + details::torch_to_mops_2d<int32_t>(indices_A), + stream ); }); +#endif + } else { C10_THROW_ERROR( ValueError, @@ -108,6 +123,12 @@ torch::Tensor HomogeneousPolynomialEvaluationBackward::forward( ); }); } else if (A.device().is_cuda()) { +#ifndef MOPS_CUDA_ENABLED + C10_THROW_ERROR(ValueError, "MOPS was not compiled with CUDA support " + A.device().str()); +#else + c10::cuda::CUDAGuard deviceGuard{A.device()}; + cudaStream_t currstream = c10::cuda::getCurrentCUDAStream(); + void* stream = reinterpret_cast<void*>(currstream); AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "homogeneous_polynomial_evaluation_vjp", [&]() { auto mops_grad_A = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}}; @@ -121,9 +142,11 @@ torch::Tensor HomogeneousPolynomialEvaluationBackward::forward( details::torch_to_mops_1d<scalar_t>(grad_output), details::torch_to_mops_2d<scalar_t>(A), details::torch_to_mops_1d<scalar_t>(C), - details::torch_to_mops_2d<int32_t>(indices_A) + details::torch_to_mops_2d<int32_t>(indices_A), + stream ); }); +#endif } else { C10_THROW_ERROR( ValueError,
diff --git a/mops-torch/src/opsa.cpp b/mops-torch/src/opsa.cpp index 0ff4daa..43f04bc 100644 --- a/mops-torch/src/opsa.cpp +++ b/mops-torch/src/opsa.cpp @@ -1,3 +1,8 @@ +#ifdef MOPS_CUDA_ENABLED +#include <c10/cuda/CUDAGuard.h> +#include <c10/cuda/CUDAStream.h> +#endif + #include "mops/torch/opsa.hpp" #include "mops/torch/utils.hpp"
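Every torch wrapper in this diff repeats the same three steps: switch to the device that owns the input tensor, fetch the stream PyTorch is currently enqueuing work on, and hand it across the library boundary as an opaque pointer. A minimal sketch of that pattern in isolation (`mops_cuda_entry_point` is a placeholder, not a real mops symbol):

```cpp
#include <torch/torch.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

// Stand-in for any mops::cuda::* function that now takes a void* stream.
void mops_cuda_entry_point(void* cuda_stream);

void call_on_current_stream(const torch::Tensor& A) {
    // Work on the device that owns A's memory...
    c10::cuda::CUDAGuard deviceGuard{A.device()};
    // ...and on the stream PyTorch is currently queuing work on, so the
    // mops kernel stays ordered with the surrounding torch operations.
    cudaStream_t currstream = c10::cuda::getCurrentCUDAStream();
    // mops keeps CUDA types out of its public headers, so the stream
    // crosses the library boundary as an opaque void*.
    mops_cuda_entry_point(reinterpret_cast<void*>(currstream));
}
```

Passing `nullptr` instead selects the default stream, which is what the `void* cuda_stream = nullptr` defaults added in the headers below fall back to.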
@@ -48,6 +53,10 @@ torch::Tensor OuterProductScatterAdd::forward( #ifndef MOPS_CUDA_ENABLED C10_THROW_ERROR(ValueError, "MOPS was not compiled with CUDA support " + A.device().str()); #else + c10::cuda::CUDAGuard deviceGuard{A.device()}; + cudaStream_t currstream = c10::cuda::getCurrentCUDAStream(); + void* stream = reinterpret_cast<void*>(currstream); + output = torch::empty( {output_size, A.size(1), B.size(1)}, torch::TensorOptions().dtype(A.scalar_type()).device(A.device()) @@ -58,7 +67,8 @@ torch::Tensor OuterProductScatterAdd::forward( details::torch_to_mops_3d<scalar_t>(output), details::torch_to_mops_2d<scalar_t>(A), details::torch_to_mops_2d<scalar_t>(B), - details::torch_to_mops_1d<int32_t>(indices_output) + details::torch_to_mops_1d<int32_t>(indices_output), + stream ); }); @@ -130,6 +140,10 @@ std::vector<torch::Tensor> OuterProductScatterAddBackward::forward( #ifndef MOPS_CUDA_ENABLED C10_THROW_ERROR(ValueError, "MOPS was not compiled with CUDA support " + A.device().str()); #else + c10::cuda::CUDAGuard deviceGuard{A.device()}; + cudaStream_t currstream = c10::cuda::getCurrentCUDAStream(); + void* stream = reinterpret_cast<void*>(currstream); + AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "outer_product_scatter_add_vjp", [&]() { auto mops_grad_A = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}}; @@ -150,7 +164,8 @@ std::vector<torch::Tensor> OuterProductScatterAddBackward::forward( details::torch_to_mops_3d<scalar_t>(grad_output), details::torch_to_mops_2d<scalar_t>(A), details::torch_to_mops_2d<scalar_t>(B), - details::torch_to_mops_1d<int32_t>(indices_output) + details::torch_to_mops_1d<int32_t>(indices_output), + stream ); }); #endif @@ -228,9 +243,52 @@ std::vector<torch::Tensor> OuterProductScatterAddBackward::backward( #ifndef MOPS_CUDA_ENABLED C10_THROW_ERROR(ValueError, "MOPS was not compiled with CUDA support " + A.device().str()); #else - C10_THROW_ERROR( - ValueError, "outer_product_scatter_add_vjp_vjp is not implemented for CUDA yet" - ); + c10::cuda::CUDAGuard deviceGuard{A.device()}; + cudaStream_t currstream = c10::cuda::getCurrentCUDAStream(); + void* stream = reinterpret_cast<void*>(currstream); + + AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "outer_product_scatter_add_vjp_vjp", [&]() { + auto mops_grad_grad_output = mops::Tensor<scalar_t, 3>{nullptr, {0, 0, 0}}; + if (grad_output.requires_grad()) { + grad_grad_output = torch::empty_like(grad_output); + mops_grad_grad_output = details::torch_to_mops_3d<scalar_t>(grad_grad_output); + } + + auto mops_grad_A_2 = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}}; + if (A.requires_grad()) { + grad_A_2 = torch::empty_like(A); + mops_grad_A_2 = details::torch_to_mops_2d<scalar_t>(grad_A_2); + } + + auto mops_grad_B_2 = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}}; + if (B.requires_grad()) { + grad_B_2 = torch::empty_like(B); + mops_grad_B_2 = details::torch_to_mops_2d<scalar_t>(grad_B_2); + } + + auto mops_grad_grad_A = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}}; + if (grad_grad_A.defined()) { + mops_grad_grad_A = details::torch_to_mops_2d<scalar_t>(grad_grad_A); + } + + auto mops_grad_grad_B = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}}; + if (grad_grad_B.defined()) { + mops_grad_grad_B = details::torch_to_mops_2d<scalar_t>(grad_grad_B); + } + + mops::cuda::outer_product_scatter_add_vjp_vjp<scalar_t>( + mops_grad_grad_output, + mops_grad_A_2, + mops_grad_B_2, + mops_grad_grad_A, + mops_grad_grad_B, + details::torch_to_mops_3d<scalar_t>(grad_output), + details::torch_to_mops_2d<scalar_t>(A), + details::torch_to_mops_2d<scalar_t>(B), + details::torch_to_mops_1d<int32_t>(indices_output), + stream + ); + }); #endif } else { C10_THROW_ERROR( ValueError,
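The rewritten double-backward above wires up to five optional outputs. The convention, visible in the `mops::Tensor{nullptr, ...}` initializers, is that a null data pointer marks a gradient autograd did not request, so the CUDA kernel can skip computing it. A sketch of that convention as a helper (`maybe_gradient` is hypothetical, assuming the `torch_to_mops_2d<scalar_t>` utility from `mops/torch/utils.hpp`):

```cpp
#include <torch/torch.h>
#include "mops/torch/utils.hpp"

// A mops::Tensor whose data pointer is nullptr tells the kernel that this
// gradient was not requested and can be skipped entirely.
template <typename scalar_t>
mops::Tensor<scalar_t, 2> maybe_gradient(torch::Tensor& storage, const torch::Tensor& like) {
    auto result = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}}; // "not requested"
    if (like.requires_grad()) {
        storage = torch::empty_like(like); // allocate only when actually needed
        result = details::torch_to_mops_2d<scalar_t>(storage);
    }
    return result;
}
```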
"mops/torch/utils.hpp" diff --git a/mops-torch/src/sap.cpp b/mops-torch/src/sap.cpp index a68cca5..c7ae354 100644 --- a/mops-torch/src/sap.cpp +++ b/mops-torch/src/sap.cpp @@ -1,3 +1,8 @@ +#ifdef MOPS_CUDA_ENABLED +#include +#include +#endif + #include "mops/torch/sap.hpp" #include "mops/torch/utils.hpp" @@ -59,6 +64,14 @@ torch::Tensor SparseAccumulationOfProducts::forward( ); }); } else if (A.device().is_cuda()) { + +#ifndef MOPS_CUDA_ENABLED + C10_THROW_ERROR(ValueError, "MOPS was not compiled with CUDA support " + A.device().str()); +#else + c10::cuda::CUDAGuard deviceGuard{A.device()}; + cudaStream_t currstream = c10::cuda::getCurrentCUDAStream(); + void* stream = reinterpret_cast(currstream); + output = torch::empty( {A.size(0), output_size}, torch::TensorOptions().dtype(A.scalar_type()).device(A.device()) @@ -72,9 +85,11 @@ torch::Tensor SparseAccumulationOfProducts::forward( details::torch_to_mops_1d(C), details::torch_to_mops_1d(indices_A), details::torch_to_mops_1d(indices_B), - details::torch_to_mops_1d(indices_output) + details::torch_to_mops_1d(indices_output), + stream ); }); +#endif } else { C10_THROW_ERROR( ValueError, @@ -170,6 +185,14 @@ std::vector SparseAccumulationOfProductsBackward::forward( ); }); } else if (A.device().is_cuda()) { + +#ifndef MOPS_CUDA_ENABLED + C10_THROW_ERROR(ValueError, "MOPS was not compiled with CUDA support " + A.device().str()); +#else + c10::cuda::CUDAGuard deviceGuard{A.device()}; + cudaStream_t currstream = c10::cuda::getCurrentCUDAStream(); + void* stream = reinterpret_cast(currstream); + AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "sparse_accumulation_of_products_vjp", [&]() { auto mops_grad_A = mops::Tensor{nullptr, {0, 0}}; if (A.requires_grad()) { @@ -192,9 +215,11 @@ std::vector SparseAccumulationOfProductsBackward::forward( details::torch_to_mops_1d(C), details::torch_to_mops_1d(indices_A), details::torch_to_mops_1d(indices_B), - details::torch_to_mops_1d(indices_output) + details::torch_to_mops_1d(indices_output), + stream ); }); +#endif } else { C10_THROW_ERROR( ValueError, @@ -276,6 +301,10 @@ std::vector SparseAccumulationOfProductsBackward::backward( #ifndef MOPS_CUDA_ENABLED C10_THROW_ERROR(ValueError, "MOPS was not compiled with CUDA support " + A.device().str()); #else + c10::cuda::CUDAGuard deviceGuard{A.device()}; + cudaStream_t currstream = c10::cuda::getCurrentCUDAStream(); + void* stream = reinterpret_cast(currstream); + AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "sparse_accumulation_of_products_vjp_vjp", [&]() { auto mops_grad_grad_output = mops::Tensor{nullptr, {0, 0}}; if (grad_output.requires_grad()) { @@ -317,7 +346,8 @@ std::vector SparseAccumulationOfProductsBackward::backward( details::torch_to_mops_1d(C), details::torch_to_mops_1d(indices_A), details::torch_to_mops_1d(indices_B), - details::torch_to_mops_1d(indices_output) + details::torch_to_mops_1d(indices_output), + stream ); }); #endif diff --git a/mops-torch/src/sasaw.cpp b/mops-torch/src/sasaw.cpp index 2868e23..ee8ea76 100644 --- a/mops-torch/src/sasaw.cpp +++ b/mops-torch/src/sasaw.cpp @@ -1,3 +1,8 @@ +#ifdef MOPS_CUDA_ENABLED +#include +#include +#endif + #include "mops/torch/sasaw.hpp" #include "mops/torch/utils.hpp" diff --git a/mops/CMakeLists.txt b/mops/CMakeLists.txt index 846eb47..f75a638 100644 --- a/mops/CMakeLists.txt +++ b/mops/CMakeLists.txt @@ -124,6 +124,8 @@ if(CMAKE_CUDA_COMPILER AND MOPS_CUDA) "src/opsa/opsa.cu" "src/hpe/hpe.cu" "src/sap/sap.cu" + "src/sasaw/sasaw.cu" + "src/opsaw/opsaw.cu" ) endif() diff --git a/mops/include/mops/hpe.h 
b/mops/include/mops/hpe.h index ca18489..d1cf15b 100644 --- a/mops/include/mops/hpe.h +++ b/mops/include/mops/hpe.h @@ -69,7 +69,8 @@ int MOPS_EXPORT mops_cuda_homogeneous_polynomial_evaluation_f32( mops_tensor_1d_f32_t output, mops_tensor_2d_f32_t A, mops_tensor_1d_f32_t C, - mops_tensor_2d_i32_t indices_A + mops_tensor_2d_i32_t indices_A, + void* cuda_stream ); /// CUDA version of mops::homogeneous_polynomial_evaluation for 64-bit floats @@ -77,7 +78,8 @@ int MOPS_EXPORT mops_cuda_homogeneous_polynomial_evaluation_f64( mops_tensor_1d_f64_t output, mops_tensor_2d_f64_t A, mops_tensor_1d_f64_t C, - mops_tensor_2d_i32_t indices_A + mops_tensor_2d_i32_t indices_A, + void* cuda_stream ); /// CUDA version of mops::homogeneous_polynomial_evaluation_vjp for 32-bit floats @@ -86,7 +88,8 @@ int MOPS_EXPORT mops_cuda_homogeneous_polynomial_evaluation_vjp_f32( mops_tensor_1d_f32_t grad_output, mops_tensor_2d_f32_t A, mops_tensor_1d_f32_t C, - mops_tensor_2d_i32_t indices_A + mops_tensor_2d_i32_t indices_A, + void* cuda_stream ); /// CUDA version of mops::homogeneous_polynomial_evaluation_vjp for 64-bit floats @@ -95,7 +98,8 @@ int MOPS_EXPORT mops_cuda_homogeneous_polynomial_evaluation_vjp_f64( mops_tensor_1d_f64_t grad_output, mops_tensor_2d_f64_t A, mops_tensor_1d_f64_t C, - mops_tensor_2d_i32_t indices_A + mops_tensor_2d_i32_t indices_A, + void* cuda_stream ); /// CUDA version of mops::homogeneous_polynomial_evaluation_vjp_vjp for 32-bit floats @@ -106,7 +110,8 @@ int MOPS_EXPORT mops_cuda_homogeneous_polynomial_evaluation_vjp_vjp_f32( mops_tensor_1d_f32_t grad_output, mops_tensor_2d_f32_t A, mops_tensor_1d_f32_t C, - mops_tensor_2d_i32_t indices_A + mops_tensor_2d_i32_t indices_A, + void* cuda_stream ); /// CUDA version of mops::homogeneous_polynomial_evaluation_vjp_vjp for 64-bit floats @@ -117,7 +122,8 @@ int MOPS_EXPORT mops_cuda_homogeneous_polynomial_evaluation_vjp_vjp_f64( mops_tensor_1d_f64_t grad_output, mops_tensor_2d_f64_t A, mops_tensor_1d_f64_t C, - mops_tensor_2d_i32_t indices_A + mops_tensor_2d_i32_t indices_A, + void* cuda_stream ); #ifdef __cplusplus diff --git a/mops/include/mops/hpe.hpp b/mops/include/mops/hpe.hpp index 472417f..1a66edd 100644 --- a/mops/include/mops/hpe.hpp +++ b/mops/include/mops/hpe.hpp @@ -122,15 +122,27 @@ namespace cuda { /// CUDA version of mops::homogeneous_polynomial_evaluation template void MOPS_EXPORT homogeneous_polynomial_evaluation( - Tensor output, Tensor A, Tensor C, Tensor indices_A + Tensor output, + Tensor A, + Tensor C, + Tensor indices_A, + void* cuda_stream = nullptr ); extern template void homogeneous_polynomial_evaluation( - Tensor output, Tensor A, Tensor C, Tensor indices_A + Tensor output, + Tensor A, + Tensor C, + Tensor indices_A, + void* cuda_stream ); extern template void homogeneous_polynomial_evaluation( - Tensor output, Tensor A, Tensor C, Tensor indices_A + Tensor output, + Tensor A, + Tensor C, + Tensor indices_A, + void* cuda_stream ); template @@ -139,7 +151,8 @@ void MOPS_EXPORT homogeneous_polynomial_evaluation_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* cuda_stream = nullptr ); extern template void homogeneous_polynomial_evaluation_vjp( @@ -147,7 +160,8 @@ extern template void homogeneous_polynomial_evaluation_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* cuda_stream ); extern template void homogeneous_polynomial_evaluation_vjp( @@ -155,7 +169,8 @@ extern template void homogeneous_polynomial_evaluation_vjp( Tensor grad_output, 
Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* cuda_stream ); template @@ -166,7 +181,8 @@ void MOPS_EXPORT homogeneous_polynomial_evaluation_vjp_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* cuda_stream = nullptr ); extern template void homogeneous_polynomial_evaluation_vjp_vjp( @@ -176,7 +192,8 @@ extern template void homogeneous_polynomial_evaluation_vjp_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* cuda_stream ); extern template void homogeneous_polynomial_evaluation_vjp_vjp( @@ -186,7 +203,8 @@ extern template void homogeneous_polynomial_evaluation_vjp_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* cuda_stream ); } // namespace cuda diff --git a/mops/include/mops/opsa.h b/mops/include/mops/opsa.h index 7f608ed..e1ade1a 100644 --- a/mops/include/mops/opsa.h +++ b/mops/include/mops/opsa.h @@ -75,7 +75,8 @@ int MOPS_EXPORT mops_cuda_outer_product_scatter_add_f32( mops_tensor_3d_f32_t output, mops_tensor_2d_f32_t A, mops_tensor_2d_f32_t B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::outer_product_scatter_add for 64-bit floats @@ -83,7 +84,8 @@ int MOPS_EXPORT mops_cuda_outer_product_scatter_add_f64( mops_tensor_3d_f64_t output, mops_tensor_2d_f64_t A, mops_tensor_2d_f64_t B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::outer_product_scatter_add_vjp for 32-bit floats @@ -93,7 +95,8 @@ int MOPS_EXPORT mops_cuda_outer_product_scatter_add_vjp_f32( mops_tensor_3d_f32_t grad_output, mops_tensor_2d_f32_t A, mops_tensor_2d_f32_t B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::outer_product_scatter_add_vjp for 64-bit floats @@ -103,7 +106,8 @@ int MOPS_EXPORT mops_cuda_outer_product_scatter_add_vjp_f64( mops_tensor_3d_f64_t grad_output, mops_tensor_2d_f64_t A, mops_tensor_2d_f64_t B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::outer_product_scatter_add_vjp_vjp for 32-bit floats @@ -116,7 +120,8 @@ int MOPS_EXPORT mops_cuda_outer_product_scatter_add_vjp_vjp_f32( mops_tensor_3d_f32_t grad_output, mops_tensor_2d_f32_t A, mops_tensor_2d_f32_t B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::outer_product_scatter_add_vjp_vjp for 64-bit floats @@ -129,7 +134,8 @@ int MOPS_EXPORT mops_cuda_outer_product_scatter_add_vjp_vjp_f64( mops_tensor_3d_f64_t grad_output, mops_tensor_2d_f64_t A, mops_tensor_2d_f64_t B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); #ifdef __cplusplus diff --git a/mops/include/mops/opsa.hpp b/mops/include/mops/opsa.hpp index 9e7543f..eb6fbe9 100644 --- a/mops/include/mops/opsa.hpp +++ b/mops/include/mops/opsa.hpp @@ -149,15 +149,24 @@ void MOPS_EXPORT outer_product_scatter_add( Tensor output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream = nullptr ); extern template void outer_product_scatter_add( - Tensor output, Tensor A, Tensor B, Tensor indices_output + Tensor output, + Tensor A, + Tensor B, + Tensor indices_output, + void* cuda_stream ); extern template void outer_product_scatter_add( - Tensor output, Tensor A, Tensor B, Tensor 
indices_output + Tensor output, + Tensor A, + Tensor B, + Tensor indices_output, + void* cuda_stream ); template @@ -167,7 +176,8 @@ void MOPS_EXPORT outer_product_scatter_add_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream = nullptr ); // these templates will be precompiled and provided in the mops library @@ -177,7 +187,8 @@ extern template void outer_product_scatter_add_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); extern template void outer_product_scatter_add_vjp( @@ -186,7 +197,8 @@ extern template void outer_product_scatter_add_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); /// TODO @@ -200,7 +212,8 @@ void MOPS_EXPORT outer_product_scatter_add_vjp_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream = nullptr ); // these templates will be precompiled and provided in the mops library @@ -213,7 +226,8 @@ extern template void outer_product_scatter_add_vjp_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); extern template void outer_product_scatter_add_vjp_vjp( @@ -225,7 +239,8 @@ extern template void outer_product_scatter_add_vjp_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); } // namespace cuda diff --git a/mops/include/mops/opsaw.h b/mops/include/mops/opsaw.h index feafb11..ffa8566 100644 --- a/mops/include/mops/opsaw.h +++ b/mops/include/mops/opsaw.h @@ -95,7 +95,8 @@ int MOPS_EXPORT mops_cuda_outer_product_scatter_add_with_weights_f32( mops_tensor_2d_f32_t B, mops_tensor_2d_f32_t W, mops_tensor_1d_i32_t indices_W, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::outer_product_scatter_add_with_weights for 64-bit floats @@ -105,7 +106,8 @@ int MOPS_EXPORT mops_cuda_outer_product_scatter_add_with_weights_f64( mops_tensor_2d_f64_t B, mops_tensor_2d_f64_t W, mops_tensor_1d_i32_t indices_W, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::outer_product_scatter_add_with_weights_vjp for 32-bit floats @@ -118,7 +120,8 @@ int MOPS_EXPORT mops_cuda_outer_product_scatter_add_with_weights_vjp_f32( mops_tensor_2d_f32_t B, mops_tensor_2d_f32_t W, mops_tensor_1d_i32_t indices_W, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::outer_product_scatter_add_with_weights_vjp for 64-bit floats @@ -131,7 +134,8 @@ int MOPS_EXPORT mops_cuda_outer_product_scatter_add_with_weights_vjp_f64( mops_tensor_2d_f64_t B, mops_tensor_2d_f64_t W, mops_tensor_1d_i32_t indices_W, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::outer_product_scatter_add_with_weights_vjp_vjp for 32-bit floats @@ -148,7 +152,8 @@ int MOPS_EXPORT mops_cuda_outer_product_scatter_add_with_weights_vjp_vjp_f32( mops_tensor_2d_f32_t B, mops_tensor_2d_f32_t W, mops_tensor_1d_i32_t indices_W, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::outer_product_scatter_add_with_weights_vjp_vjp for 64-bit floats @@ -165,7 +170,8 @@ int MOPS_EXPORT 
mops_cuda_outer_product_scatter_add_with_weights_vjp_vjp_f64( mops_tensor_2d_f64_t B, mops_tensor_2d_f64_t W, mops_tensor_1d_i32_t indices_W, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); #ifdef __cplusplus diff --git a/mops/include/mops/opsaw.hpp b/mops/include/mops/opsaw.hpp index 656bdde..3f71db6 100644 --- a/mops/include/mops/opsaw.hpp +++ b/mops/include/mops/opsaw.hpp @@ -194,7 +194,8 @@ void MOPS_EXPORT outer_product_scatter_add_with_weights( Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream = nullptr ); extern template void outer_product_scatter_add_with_weights( @@ -203,7 +204,8 @@ extern template void outer_product_scatter_add_with_weights( Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); extern template void outer_product_scatter_add_with_weights( @@ -212,7 +214,8 @@ extern template void outer_product_scatter_add_with_weights( Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); /// CUDA version of mops::outer_product_scatter_add_with_weights_vjp @@ -226,7 +229,8 @@ void MOPS_EXPORT outer_product_scatter_add_with_weights_vjp( Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream = nullptr ); extern template void outer_product_scatter_add_with_weights_vjp( @@ -238,7 +242,8 @@ extern template void outer_product_scatter_add_with_weights_vjp( Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); extern template void outer_product_scatter_add_with_weights_vjp( @@ -250,7 +255,8 @@ extern template void outer_product_scatter_add_with_weights_vjp( Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); /// TODO @@ -268,7 +274,8 @@ void MOPS_EXPORT outer_product_scatter_add_with_weights_vjp_vjp( Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream = nullptr ); // these templates will be precompiled and provided in the mops library @@ -285,7 +292,8 @@ extern template void outer_product_scatter_add_with_weights_vjp_vjp( Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); extern template void outer_product_scatter_add_with_weights_vjp_vjp( @@ -301,7 +309,8 @@ extern template void outer_product_scatter_add_with_weights_vjp_vjp( Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); } // namespace cuda diff --git a/mops/include/mops/sap.h b/mops/include/mops/sap.h index 5990e57..210b9bf 100644 --- a/mops/include/mops/sap.h +++ b/mops/include/mops/sap.h @@ -96,7 +96,8 @@ int MOPS_EXPORT mops_cuda_sparse_accumulation_of_products_f32( mops_tensor_1d_f32_t C, mops_tensor_1d_i32_t indices_A, mops_tensor_1d_i32_t indices_B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::sparse_accumulation_of_products for 64-bit floats @@ -107,7 +108,8 @@ int MOPS_EXPORT mops_cuda_sparse_accumulation_of_products_f64( mops_tensor_1d_f64_t C, mops_tensor_1d_i32_t indices_A, mops_tensor_1d_i32_t indices_B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::sparse_accumulation_of_products_vjp for 32-bit floats @@ -120,7 
+122,8 @@ int MOPS_EXPORT mops_cuda_sparse_accumulation_of_products_vjp_f32( mops_tensor_1d_f32_t C, mops_tensor_1d_i32_t indices_A, mops_tensor_1d_i32_t indices_B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::sparse_accumulation_of_products_vjp for 64-bit floats @@ -133,7 +136,8 @@ int MOPS_EXPORT mops_cuda_sparse_accumulation_of_products_vjp_f64( mops_tensor_1d_f64_t C, mops_tensor_1d_i32_t indices_A, mops_tensor_1d_i32_t indices_B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::sparse_accumulation_of_products_vjp_vjp for 32-bit floats @@ -149,7 +153,8 @@ int MOPS_EXPORT mops_cuda_sparse_accumulation_of_products_vjp_vjp_f32( mops_tensor_1d_f32_t C, mops_tensor_1d_i32_t indices_A, mops_tensor_1d_i32_t indices_B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); /// CUDA version of mops::sparse_accumulation_of_products_vjp_vjp for 64-bit floats @@ -165,7 +170,8 @@ int MOPS_EXPORT mops_cuda_sparse_accumulation_of_products_vjp_vjp_f64( mops_tensor_1d_f64_t C, mops_tensor_1d_i32_t indices_A, mops_tensor_1d_i32_t indices_B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ); #ifdef __cplusplus diff --git a/mops/include/mops/sap.hpp b/mops/include/mops/sap.hpp index a27090a..4568297 100644 --- a/mops/include/mops/sap.hpp +++ b/mops/include/mops/sap.hpp @@ -191,7 +191,8 @@ void MOPS_EXPORT sparse_accumulation_of_products( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream = nullptr ); extern template void sparse_accumulation_of_products( @@ -201,7 +202,8 @@ extern template void sparse_accumulation_of_products( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); extern template void sparse_accumulation_of_products( @@ -211,7 +213,8 @@ extern template void sparse_accumulation_of_products( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); /// CUDA version of mops::sparse_accumulation_of_products_vjp @@ -225,7 +228,8 @@ void MOPS_EXPORT sparse_accumulation_of_products_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream = nullptr ); extern template void sparse_accumulation_of_products_vjp( @@ -237,7 +241,8 @@ extern template void sparse_accumulation_of_products_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); extern template void sparse_accumulation_of_products_vjp( @@ -249,7 +254,8 @@ extern template void sparse_accumulation_of_products_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); /// TODO @@ -266,7 +272,8 @@ void MOPS_EXPORT sparse_accumulation_of_products_vjp_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream = nullptr ); // these templates will be precompiled and provided in the mops library @@ -282,7 +289,8 @@ extern template void sparse_accumulation_of_products_vjp_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); extern template void sparse_accumulation_of_products_vjp_vjp( @@ -297,7 +305,8 
@@ extern template void sparse_accumulation_of_products_vjp_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); } // namespace cuda diff --git a/mops/include/mops/sasaw.h b/mops/include/mops/sasaw.h index 75170d7..d87560d 100644 --- a/mops/include/mops/sasaw.h +++ b/mops/include/mops/sasaw.h @@ -127,7 +127,8 @@ int MOPS_EXPORT mops_cuda_sparse_accumulation_scatter_add_with_weights_f32( mops_tensor_1d_i32_t indices_W_1, mops_tensor_1d_i32_t indices_W_2, mops_tensor_1d_i32_t indices_output_1, - mops_tensor_1d_i32_t indices_output_2 + mops_tensor_1d_i32_t indices_output_2, + void* cuda_stream ); /// CUDA version of mops::sparse_accumulation_scatter_add_with for 64-bit floats @@ -141,7 +142,8 @@ int MOPS_EXPORT mops_cuda_sparse_accumulation_scatter_add_with_weights_f64( mops_tensor_1d_i32_t indices_W_1, mops_tensor_1d_i32_t indices_W_2, mops_tensor_1d_i32_t indices_output_1, - mops_tensor_1d_i32_t indices_output_2 + mops_tensor_1d_i32_t indices_output_2, + void* cuda_stream ); /// CUDA version of mops::sparse_accumulation_scatter_add_with_weights_vjp for @@ -159,7 +161,8 @@ int MOPS_EXPORT mops_cuda_sparse_accumulation_scatter_add_with_weights_vjp_f32( mops_tensor_1d_i32_t indices_W_1, mops_tensor_1d_i32_t indices_W_2, mops_tensor_1d_i32_t indices_output_1, - mops_tensor_1d_i32_t indices_output_2 + mops_tensor_1d_i32_t indices_output_2, + void* cuda_stream ); /// CUDA version of mops::sparse_accumulation_scatter_add_with_weights_vjp for @@ -177,7 +180,8 @@ int MOPS_EXPORT mops_cuda_sparse_accumulation_scatter_add_with_weights_vjp_f64( mops_tensor_1d_i32_t indices_W_1, mops_tensor_1d_i32_t indices_W_2, mops_tensor_1d_i32_t indices_output_1, - mops_tensor_1d_i32_t indices_output_2 + mops_tensor_1d_i32_t indices_output_2, + void* cuda_stream ); /// CUDA version of mops::sparse_accumulation_scatter_add_with_weights_vjp_vjp for @@ -199,7 +203,8 @@ int MOPS_EXPORT mops_cuda_sparse_accumulation_scatter_add_with_weights_vjp_vjp_f mops_tensor_1d_i32_t indices_W_1, mops_tensor_1d_i32_t indices_W_2, mops_tensor_1d_i32_t indices_output_1, - mops_tensor_1d_i32_t indices_output_2 + mops_tensor_1d_i32_t indices_output_2, + void* cuda_stream ); /// CUDA version of mops::sparse_accumulation_scatter_add_with_weights_vjp_vjp for @@ -221,7 +226,8 @@ int MOPS_EXPORT mops_cuda_sparse_accumulation_scatter_add_with_weights_vjp_vjp_f mops_tensor_1d_i32_t indices_W_1, mops_tensor_1d_i32_t indices_W_2, mops_tensor_1d_i32_t indices_output_1, - mops_tensor_1d_i32_t indices_output_2 + mops_tensor_1d_i32_t indices_output_2, + void* cuda_stream ); #ifdef __cplusplus diff --git a/mops/include/mops/sasaw.hpp b/mops/include/mops/sasaw.hpp index 075db69..8654eb9 100644 --- a/mops/include/mops/sasaw.hpp +++ b/mops/include/mops/sasaw.hpp @@ -253,7 +253,8 @@ void MOPS_EXPORT sparse_accumulation_scatter_add_with_weights( Tensor indices_W_1, Tensor indices_W_2, Tensor indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream = nullptr ); extern template void sparse_accumulation_scatter_add_with_weights( @@ -266,7 +267,8 @@ extern template void sparse_accumulation_scatter_add_with_weights( Tensor indices_W_1, Tensor indices_W_2, Tensor indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream ); extern template void sparse_accumulation_scatter_add_with_weights( @@ -279,7 +281,8 @@ extern template void sparse_accumulation_scatter_add_with_weights( Tensor indices_W_1, Tensor indices_W_2, Tensor 
indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream ); /// CUDA version of mops::sparse_accumulation_scatter_add_with_weights_vjp @@ -297,7 +300,8 @@ void MOPS_EXPORT sparse_accumulation_scatter_add_with_weights_vjp( Tensor indices_W_1, Tensor indices_W_2, Tensor indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream = nullptr ); extern template void sparse_accumulation_scatter_add_with_weights_vjp( @@ -313,7 +317,8 @@ extern template void sparse_accumulation_scatter_add_with_weights_vjp( Tensor indices_W_1, Tensor indices_W_2, Tensor indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream ); extern template void sparse_accumulation_scatter_add_with_weights_vjp( @@ -329,7 +334,8 @@ extern template void sparse_accumulation_scatter_add_with_weights_vjp( Tensor indices_W_1, Tensor indices_W_2, Tensor indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream ); /// TODO @@ -351,7 +357,8 @@ void MOPS_EXPORT sparse_accumulation_scatter_add_with_weights_vjp_vjp( Tensor indices_W_1, Tensor indices_W_2, Tensor indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream = nullptr ); // these templates will be precompiled and provided in the mops library @@ -372,7 +379,8 @@ extern template void sparse_accumulation_scatter_add_with_weights_vjp_vjp( Tensor indices_W_1, Tensor indices_W_2, Tensor indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream ); extern template void sparse_accumulation_scatter_add_with_weights_vjp_vjp( @@ -392,7 +400,8 @@ extern template void sparse_accumulation_scatter_add_with_weights_vjp_vjp( Tensor indices_W_1, Tensor indices_W_2, Tensor indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream ); } // namespace cuda diff --git a/mops/src/hpe/capi.cpp b/mops/src/hpe/capi.cpp index 7dd5e92..01e3056 100644 --- a/mops/src/hpe/capi.cpp +++ b/mops/src/hpe/capi.cpp @@ -132,14 +132,16 @@ extern "C" int mops_cuda_homogeneous_polynomial_evaluation_f32( mops_tensor_1d_f32_t output, mops_tensor_2d_f32_t A, mops_tensor_1d_f32_t C, - mops_tensor_2d_i32_t indices_A + mops_tensor_2d_i32_t indices_A, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::homogeneous_polynomial_evaluation( {output.data, {checked_cast(output.shape[0])}}, {A.data, {checked_cast(A.shape[0]), checked_cast(A.shape[1])}}, {C.data, {checked_cast(C.shape[0])}}, - {indices_A.data, {checked_cast(indices_A.shape[0]), checked_cast(indices_A.shape[1])}} + {indices_A.data, {checked_cast(indices_A.shape[0]), checked_cast(indices_A.shape[1])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -148,14 +150,16 @@ extern "C" int mops_cuda_homogeneous_polynomial_evaluation_f64( mops_tensor_1d_f64_t output, mops_tensor_2d_f64_t A, mops_tensor_1d_f64_t C, - mops_tensor_2d_i32_t indices_A + mops_tensor_2d_i32_t indices_A, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::homogeneous_polynomial_evaluation( {output.data, {checked_cast(output.shape[0])}}, {A.data, {checked_cast(A.shape[0]), checked_cast(A.shape[1])}}, {C.data, {checked_cast(C.shape[0])}}, - {indices_A.data, {checked_cast(indices_A.shape[0]), checked_cast(indices_A.shape[1])}} + {indices_A.data, {checked_cast(indices_A.shape[0]), checked_cast(indices_A.shape[1])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -165,7 +169,8 @@ extern "C" int mops_cuda_homogeneous_polynomial_evaluation_vjp_f32( 
mops_tensor_1d_f32_t grad_output, mops_tensor_2d_f32_t A, mops_tensor_1d_f32_t C, - mops_tensor_2d_i32_t indices_A + mops_tensor_2d_i32_t indices_A, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::homogeneous_polynomial_evaluation_vjp( @@ -173,7 +178,8 @@ extern "C" int mops_cuda_homogeneous_polynomial_evaluation_vjp_f32( {grad_output.data, {checked_cast(grad_output.shape[0])}}, {A.data, {checked_cast(A.shape[0]), checked_cast(A.shape[1])}}, {C.data, {checked_cast(C.shape[0])}}, - {indices_A.data, {checked_cast(indices_A.shape[0]), checked_cast(indices_A.shape[1])}} + {indices_A.data, {checked_cast(indices_A.shape[0]), checked_cast(indices_A.shape[1])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -183,7 +189,8 @@ extern "C" int mops_cuda_homogeneous_polynomial_evaluation_vjp_f64( mops_tensor_1d_f64_t grad_output, mops_tensor_2d_f64_t A, mops_tensor_1d_f64_t C, - mops_tensor_2d_i32_t indices_A + mops_tensor_2d_i32_t indices_A, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::homogeneous_polynomial_evaluation_vjp( @@ -191,7 +198,8 @@ extern "C" int mops_cuda_homogeneous_polynomial_evaluation_vjp_f64( {grad_output.data, {checked_cast(grad_output.shape[0])}}, {A.data, {checked_cast(A.shape[0]), checked_cast(A.shape[1])}}, {C.data, {checked_cast(C.shape[0])}}, - {indices_A.data, {checked_cast(indices_A.shape[0]), checked_cast(indices_A.shape[1])}} + {indices_A.data, {checked_cast(indices_A.shape[0]), checked_cast(indices_A.shape[1])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -203,7 +211,8 @@ extern "C" int mops_cuda_homogeneous_polynomial_evaluation_vjp_vjp_f32( mops_tensor_1d_f32_t grad_output, mops_tensor_2d_f32_t A, mops_tensor_1d_f32_t C, - mops_tensor_2d_i32_t indices_A + mops_tensor_2d_i32_t indices_A, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::homogeneous_polynomial_evaluation_vjp_vjp( @@ -214,7 +223,8 @@ extern "C" int mops_cuda_homogeneous_polynomial_evaluation_vjp_vjp_f32( {grad_output.data, {checked_cast(grad_output.shape[0])}}, {A.data, {checked_cast(A.shape[0]), checked_cast(A.shape[1])}}, {C.data, {checked_cast(C.shape[0])}}, - {indices_A.data, {checked_cast(indices_A.shape[0]), checked_cast(indices_A.shape[1])}} + {indices_A.data, {checked_cast(indices_A.shape[0]), checked_cast(indices_A.shape[1])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -226,7 +236,8 @@ extern "C" int mops_cuda_homogeneous_polynomial_evaluation_vjp_vjp_f64( mops_tensor_1d_f64_t grad_output, mops_tensor_2d_f64_t A, mops_tensor_1d_f64_t C, - mops_tensor_2d_i32_t indices_A + mops_tensor_2d_i32_t indices_A, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::homogeneous_polynomial_evaluation_vjp_vjp( @@ -237,7 +248,8 @@ extern "C" int mops_cuda_homogeneous_polynomial_evaluation_vjp_vjp_f64( {grad_output.data, {checked_cast(grad_output.shape[0])}}, {A.data, {checked_cast(A.shape[0]), checked_cast(A.shape[1])}}, {C.data, {checked_cast(C.shape[0])}}, - {indices_A.data, {checked_cast(indices_A.shape[0]), checked_cast(indices_A.shape[1])}} + {indices_A.data, {checked_cast(indices_A.shape[0]), checked_cast(indices_A.shape[1])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } diff --git a/mops/src/hpe/hpe.cpp b/mops/src/hpe/hpe.cpp index 1fcf0ba..52602ae 100644 --- a/mops/src/hpe/hpe.cpp +++ b/mops/src/hpe/hpe.cpp @@ -48,29 +48,37 @@ template void mops::homogeneous_polynomial_evaluation_vjp_vjp( #ifndef MOPS_CUDA_ENABLED template void mops::cuda:: - homogeneous_polynomial_evaluation(Tensor, Tensor, Tensor, Tensor) { + 
homogeneous_polynomial_evaluation(Tensor, Tensor, Tensor, Tensor, void*) { throw std::runtime_error("MOPS was not compiled with CUDA support"); } template void mops::cuda:: - homogeneous_polynomial_evaluation_vjp(Tensor, Tensor, Tensor, Tensor, Tensor) { + homogeneous_polynomial_evaluation_vjp(Tensor, Tensor, Tensor, Tensor, Tensor, void*) { throw std::runtime_error("MOPS was not compiled with CUDA support"); } template void mops::cuda:: - homogeneous_polynomial_evaluation_vjp_vjp(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) { + homogeneous_polynomial_evaluation_vjp_vjp(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, void*) { throw std::runtime_error("MOPS was not compiled with CUDA support"); } // explicit instantiations of CUDA templates template void mops::cuda::homogeneous_polynomial_evaluation( - Tensor output, Tensor A, Tensor C, Tensor indices_A + Tensor output, + Tensor A, + Tensor C, + Tensor indices_A, + void* stream ); template void mops::cuda::homogeneous_polynomial_evaluation( - Tensor output, Tensor A, Tensor C, Tensor indices_A + Tensor output, + Tensor A, + Tensor C, + Tensor indices_A, + void* stream ); template void mops::cuda::homogeneous_polynomial_evaluation_vjp( @@ -78,7 +86,8 @@ template void mops::cuda::homogeneous_polynomial_evaluation_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* stream ); template void mops::cuda::homogeneous_polynomial_evaluation_vjp( @@ -86,7 +95,8 @@ template void mops::cuda::homogeneous_polynomial_evaluation_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* stream ); template void mops::cuda::homogeneous_polynomial_evaluation_vjp_vjp( @@ -96,7 +106,8 @@ template void mops::cuda::homogeneous_polynomial_evaluation_vjp_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* stream ); template void mops::cuda::homogeneous_polynomial_evaluation_vjp_vjp( @@ -106,7 +117,8 @@ template void mops::cuda::homogeneous_polynomial_evaluation_vjp_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* stream ); #endif diff --git a/mops/src/hpe/hpe.cu b/mops/src/hpe/hpe.cu index 964c4b4..c0b8669 100644 --- a/mops/src/hpe/hpe.cu +++ b/mops/src/hpe/hpe.cu @@ -53,14 +53,22 @@ __global__ void homogeneous_polynomial_evaluation_kernel( __syncthreads(); - int32_t i_monomial = threadIdx.x % polynomial_order; - int32_t x = threadIdx.x / polynomial_order; - int32_t nx = blockDim.x / polynomial_order; - - for (int lbasis = x; lbasis < blockDim.x; lbasis += nx) { - if (i_monomial * blockDim.x + lbasis < polynomial_order * blockDim.x) { - buffer_indices_A[i_monomial * blockDim.x + lbasis] = - indices_A.data[(i + lbasis) * polynomial_order + i_monomial]; + __syncthreads(); + + int32_t i_monomial; + int32_t x; + int32_t nx; + + if (polynomial_order > 0) { + i_monomial = threadIdx.x % polynomial_order; + x = threadIdx.x / polynomial_order; + nx = blockDim.x / polynomial_order; + + for (int lbasis = x; lbasis < blockDim.x; lbasis += nx) { + if (i_monomial * blockDim.x + lbasis < polynomial_order * blockDim.x) { + buffer_indices_A[i_monomial * blockDim.x + lbasis] = + indices_A.data[(i + lbasis) * polynomial_order + i_monomial]; + } } } @@ -108,11 +116,25 @@ __global__ void homogeneous_polynomial_evaluation_kernel( template void mops::cuda::homogeneous_polynomial_evaluation( - Tensor output, Tensor A, Tensor C, Tensor indices_A + Tensor output, + Tensor A, + Tensor C, + Tensor indices_A, + void* 
cuda_stream ) { check_hpe(output, A, C, indices_A, "cuda_homogeneous_polynomial_evaluation"); + cudaPointerAttributes attributes; + CUDA_CHECK_ERROR(cudaPointerGetAttributes(&attributes, A.data)); + int current_device; + CUDA_CHECK_ERROR(cudaGetDevice(¤t_device)); + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(attributes.device)); + } + + cudaStream_t cstream = reinterpret_cast(cuda_stream); + int32_t nbatch = output.shape[0]; int32_t nnu1 = A.shape[1]; size_t polynomial_order = indices_A.shape[1]; @@ -132,47 +154,47 @@ void mops::cuda::homogeneous_polynomial_evaluation( switch (polynomial_order) { case 0: homogeneous_polynomial_evaluation_kernel - <<>>(output, A, C, indices_A); + <<>>(output, A, C, indices_A); break; case 1: homogeneous_polynomial_evaluation_kernel - <<>>(output, A, C, indices_A); + <<>>(output, A, C, indices_A); break; case 2: homogeneous_polynomial_evaluation_kernel - <<>>(output, A, C, indices_A); + <<>>(output, A, C, indices_A); break; case 3: homogeneous_polynomial_evaluation_kernel - <<>>(output, A, C, indices_A); + <<>>(output, A, C, indices_A); break; case 4: homogeneous_polynomial_evaluation_kernel - <<>>(output, A, C, indices_A); + <<>>(output, A, C, indices_A); break; case 5: homogeneous_polynomial_evaluation_kernel - <<>>(output, A, C, indices_A); + <<>>(output, A, C, indices_A); break; case 6: homogeneous_polynomial_evaluation_kernel - <<>>(output, A, C, indices_A); + <<>>(output, A, C, indices_A); break; case 7: homogeneous_polynomial_evaluation_kernel - <<>>(output, A, C, indices_A); + <<>>(output, A, C, indices_A); break; case 8: homogeneous_polynomial_evaluation_kernel - <<>>(output, A, C, indices_A); + <<>>(output, A, C, indices_A); break; case 9: homogeneous_polynomial_evaluation_kernel - <<>>(output, A, C, indices_A); + <<>>(output, A, C, indices_A); break; case 10: homogeneous_polynomial_evaluation_kernel - <<>>(output, A, C, indices_A); + <<>>(output, A, C, indices_A); break; default: break; @@ -180,17 +202,28 @@ void mops::cuda::homogeneous_polynomial_evaluation( } CUDA_CHECK_ERROR(cudaGetLastError()); + CUDA_CHECK_ERROR(cudaStreamSynchronize(cstream)); - CUDA_CHECK_ERROR(cudaDeviceSynchronize()); + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(current_device)); + } } // explicit instanciations of CUDA templates template void mops::cuda::homogeneous_polynomial_evaluation( - Tensor output, Tensor A, Tensor C, Tensor indices_A + Tensor output, + Tensor A, + Tensor C, + Tensor indices_A, + void* cuda_stream ); template void mops::cuda::homogeneous_polynomial_evaluation( - Tensor output, Tensor A, Tensor C, Tensor indices_A + Tensor output, + Tensor A, + Tensor C, + Tensor indices_A, + void* cuda_stream ); template @@ -231,51 +264,56 @@ __global__ void homogeneous_polynomial_evaluation_vjp_kernel( __syncthreads(); scalar_t gout = grad_output.data[batch_id]; + if (polynomial_order > 0) { + // indices_A : nbasis, polynomial_order + for (int32_t i = 0; i < nbasis; i += blockDim.x) { - // indices_A : nbasis, polynomial_order - for (int32_t i = 0; i < nbasis; i += blockDim.x) { + __syncthreads(); - __syncthreads(); + int32_t basis = i + threadIdx.x; + + int32_t i_monomial; + int32_t x; + int32_t nx; - int32_t i_monomial = threadIdx.x % polynomial_order; - int32_t x = threadIdx.x / polynomial_order; - int32_t nx = blockDim.x / polynomial_order; + i_monomial = threadIdx.x % polynomial_order; + x = threadIdx.x / polynomial_order; + nx = blockDim.x / polynomial_order; - for (int lbasis = x; lbasis < 
blockDim.x; lbasis += nx) { - if (i_monomial * blockDim.x + lbasis < polynomial_order * blockDim.x) { - buffer_indices_A[i_monomial * blockDim.x + lbasis] = - indices_A.data[(i + lbasis) * polynomial_order + i_monomial]; + for (int lbasis = x; lbasis < blockDim.x; lbasis += nx) { + if (i_monomial * blockDim.x + lbasis < polynomial_order * blockDim.x) { + buffer_indices_A[i_monomial * blockDim.x + lbasis] = + indices_A.data[(i + lbasis) * polynomial_order + i_monomial]; + } } - } - __syncthreads(); + __syncthreads(); - int32_t basis = i + threadIdx.x; + if (basis < nbasis) { - if (basis < nbasis) { + scalar_t c = C.data[basis] * gout; - scalar_t c = C.data[basis] * gout; + for (int32_t i_monomial = 0; i_monomial < polynomial_order; i_monomial++) { - for (int32_t i_monomial = 0; i_monomial < polynomial_order; i_monomial++) { + scalar_t tmp_i = c; + + for (int32_t j_monomial = 0; j_monomial < polynomial_order; j_monomial++) { - scalar_t tmp_i = c; + if (i_monomial == j_monomial) { + continue; + } - for (int32_t j_monomial = 0; j_monomial < polynomial_order; j_monomial++) { + int32_t idx_j = buffer_indices_A + [j_monomial * blockDim.x + threadIdx.x]; // indices_A.data[j_monomial + // * indices_A.shape[0] + basis]; - if (i_monomial == j_monomial) { - continue; + tmp_i *= buffer_nu1[idx_j]; } - int32_t idx_j = - buffer_indices_A[j_monomial * blockDim.x + threadIdx.x]; // indices_A.data[j_monomial - // * indices_A.shape[0] + basis]; + int32_t idx_i = buffer_indices_A[i_monomial * blockDim.x + threadIdx.x]; - tmp_i *= buffer_nu1[idx_j]; + ATOMIC_ADD(&buffer_gradA[idx_i], tmp_i); } - - int32_t idx_i = buffer_indices_A[i_monomial * blockDim.x + threadIdx.x]; - - atomicAdd(&buffer_gradA[idx_i], tmp_i); } } } @@ -283,7 +321,11 @@ __global__ void homogeneous_polynomial_evaluation_vjp_kernel( __syncthreads(); for (int32_t i = threadIdx.x; i < nnu1; i += blockDim.x) { - grad_A.data[batch_id * nnu1 + i] = buffer_gradA[i]; + if (polynomial_order > 0) { + grad_A.data[batch_id * nnu1 + i] = buffer_gradA[i]; + } else { + grad_A.data[batch_id * nnu1 + i] = 0.0; + } } } @@ -293,10 +335,21 @@ void mops::cuda::homogeneous_polynomial_evaluation_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* cuda_stream ) { check_hpe_vjp(grad_A, grad_output, A, C, indices_A, "cuda_homogeneous_polynomial_evaluation_vjp"); + cudaPointerAttributes attributes; + CUDA_CHECK_ERROR(cudaPointerGetAttributes(&attributes, A.data)); + int current_device; + CUDA_CHECK_ERROR(cudaGetDevice(¤t_device)); + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(attributes.device)); + } + + cudaStream_t cstream = reinterpret_cast(cuda_stream); + int32_t nbatch = grad_output.shape[0]; int32_t nnu1 = A.shape[1]; size_t polynomial_order = indices_A.shape[1]; @@ -315,47 +368,47 @@ void mops::cuda::homogeneous_polynomial_evaluation_vjp( switch (polynomial_order) { case 0: homogeneous_polynomial_evaluation_vjp_kernel - <<>>(grad_A, grad_output, A, C, indices_A); + <<>>(grad_A, grad_output, A, C, indices_A); break; case 1: homogeneous_polynomial_evaluation_vjp_kernel - <<>>(grad_A, grad_output, A, C, indices_A); + <<>>(grad_A, grad_output, A, C, indices_A); break; case 2: homogeneous_polynomial_evaluation_vjp_kernel - <<>>(grad_A, grad_output, A, C, indices_A); + <<>>(grad_A, grad_output, A, C, indices_A); break; case 3: homogeneous_polynomial_evaluation_vjp_kernel - <<>>(grad_A, grad_output, A, C, indices_A); + <<>>(grad_A, grad_output, A, C, indices_A); break; case 4: 
homogeneous_polynomial_evaluation_vjp_kernel - <<>>(grad_A, grad_output, A, C, indices_A); + <<>>(grad_A, grad_output, A, C, indices_A); break; case 5: homogeneous_polynomial_evaluation_vjp_kernel - <<>>(grad_A, grad_output, A, C, indices_A); + <<>>(grad_A, grad_output, A, C, indices_A); break; case 6: homogeneous_polynomial_evaluation_vjp_kernel - <<>>(grad_A, grad_output, A, C, indices_A); + <<>>(grad_A, grad_output, A, C, indices_A); break; case 7: homogeneous_polynomial_evaluation_vjp_kernel - <<>>(grad_A, grad_output, A, C, indices_A); + <<>>(grad_A, grad_output, A, C, indices_A); break; case 8: homogeneous_polynomial_evaluation_vjp_kernel - <<>>(grad_A, grad_output, A, C, indices_A); + <<>>(grad_A, grad_output, A, C, indices_A); break; case 9: homogeneous_polynomial_evaluation_vjp_kernel - <<>>(grad_A, grad_output, A, C, indices_A); + <<>>(grad_A, grad_output, A, C, indices_A); break; case 10: homogeneous_polynomial_evaluation_vjp_kernel - <<>>(grad_A, grad_output, A, C, indices_A); + <<>>(grad_A, grad_output, A, C, indices_A); break; default: break; @@ -363,8 +416,11 @@ void mops::cuda::homogeneous_polynomial_evaluation_vjp( } CUDA_CHECK_ERROR(cudaGetLastError()); + CUDA_CHECK_ERROR(cudaStreamSynchronize(cstream)); - CUDA_CHECK_ERROR(cudaDeviceSynchronize()); + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(current_device)); + } } // explicit instanciations of CUDA templates @@ -373,7 +429,8 @@ template void mops::cuda::homogeneous_polynomial_evaluation_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* cuda_stream ); template void mops::cuda::homogeneous_polynomial_evaluation_vjp( @@ -381,7 +438,8 @@ template void mops::cuda::homogeneous_polynomial_evaluation_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* cuda_stream ); template @@ -392,7 +450,8 @@ void mops::cuda::homogeneous_polynomial_evaluation_vjp_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* cuda_stream ) { throw std::runtime_error("Not implemented"); } @@ -405,7 +464,8 @@ template void mops::cuda::homogeneous_polynomial_evaluation_vjp_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* cuda_stream ); template void mops::cuda::homogeneous_polynomial_evaluation_vjp_vjp( @@ -415,5 +475,6 @@ template void mops::cuda::homogeneous_polynomial_evaluation_vjp_vjp( Tensor grad_output, Tensor A, Tensor C, - Tensor indices_A + Tensor indices_A, + void* cuda_stream ); diff --git a/mops/src/internal/cuda_utils.cu b/mops/src/internal/cuda_utils.cu index 3af533e..0c09ef4 100644 --- a/mops/src/internal/cuda_utils.cu +++ b/mops/src/internal/cuda_utils.cu @@ -11,6 +11,37 @@ __host__ __device__ int32_t find_integer_divisor(int32_t x, int32_t bdim) { return (x + bdim - 1) / bdim; } +__device__ double atomicAdd_presm60(double* address, double val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + + do { + assumed = old; + old = atomicCAS( + address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)) + ); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __longlong_as_double(old); +} + +template __device__ scalar_t ATOMIC_ADD(scalar_t* address, scalar_t val) { +#if __CUDA_ARCH__ < 600 + if constexpr (sizeof(scalar_t) == 4) { + return atomicAdd(address, val); + } else 
if constexpr (sizeof(scalar_t) == 8) { + return atomicAdd_presm60(address, val); + } +#else + return atomicAdd(address, val); +#endif +} + +template __device__ float ATOMIC_ADD(float* address, float val); +template __device__ double ATOMIC_ADD(double* address, double val); + template <typename T> __host__ __device__ T* shared_array(std::size_t n_elements, void*& ptr, std::size_t* space) noexcept { const std::uintptr_t inptr = reinterpret_cast<std::uintptr_t>(ptr);
diff --git a/mops/src/internal/cuda_utils.cuh b/mops/src/internal/cuda_utils.cuh index b8c25be..83645cf 100644 --- a/mops/src/internal/cuda_utils.cuh +++ b/mops/src/internal/cuda_utils.cuh @@ -23,6 +23,17 @@ using namespace std; } \ } while (0) +/* + * Pre-SM60 cards do not support atomicAdd(double*, double). This function emulates it with an + * atomicCAS loop that retries until the update lands. + */ +__device__ double atomicAdd_presm60(double* address, double val); + +/* + * Selects the right version of atomicAdd for the architecture being compiled for. + */ +template <typename scalar_t> __device__ scalar_t ATOMIC_ADD(scalar_t* address, scalar_t val); + __host__ __device__ int32_t find_integer_divisor(int32_t x, int32_t bdim); /*
diff --git a/mops/src/opsa/capi.cpp b/mops/src/opsa/capi.cpp index 2be0a2b..2098a6e 100644 --- a/mops/src/opsa/capi.cpp +++ b/mops/src/opsa/capi.cpp @@ -172,7 +172,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_f32( mops_tensor_3d_f32_t output, mops_tensor_2d_f32_t A, mops_tensor_2d_f32_t B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::outer_product_scatter_add<float>( @@ -182,7 +183,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_f32( checked_cast(output.shape[2])}}, {A.data, {checked_cast(A.shape[0]), checked_cast(A.shape[1])}}, {B.data, {checked_cast(B.shape[0]), checked_cast(B.shape[1])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -191,7 +193,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_f64( mops_tensor_3d_f64_t output, mops_tensor_2d_f64_t A, mops_tensor_2d_f64_t B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::outer_product_scatter_add<double>( @@ -201,7 +204,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_f64( checked_cast(output.shape[2])}}, {A.data, {checked_cast(A.shape[0]), checked_cast(A.shape[1])}}, {B.data, {checked_cast(B.shape[0]), checked_cast(B.shape[1])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -212,7 +216,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_vjp_f32( mops_tensor_3d_f32_t grad_output, mops_tensor_2d_f32_t A, mops_tensor_2d_f32_t B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::outer_product_scatter_add_vjp<float>( @@ -224,7 +229,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_vjp_f32( checked_cast(grad_output.shape[2])}}, {A.data, {checked_cast(A.shape[0]), checked_cast(A.shape[1])}}, {B.data, {checked_cast(B.shape[0]), checked_cast(B.shape[1])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END }
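`atomicAdd(double*, double)` only exists in hardware from sm_60 on, so `atomicAdd_presm60` above emulates it: it reinterprets the 64-bit word as an integer and retries an `atomicCAS` until its update wins, and comparing as integers guarantees the loop terminates even when the accumulator holds NaN (NaN != NaN would otherwise spin forever). A sketch of how device code is meant to use the dispatching wrapper (`sum_kernel` is illustrative, not part of this diff):

```cpp
#include <cuda_runtime.h>
#include "cuda_utils.cuh" // assumption: the internal header edited above

// ATOMIC_ADD is a drop-in replacement for atomicAdd that also works for
// double below sm_60, where the CAS loop fills in for the missing
// hardware instruction.
template <typename scalar_t>
__global__ void sum_kernel(const scalar_t* values, int32_t n, scalar_t* total) {
    int32_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        // Compiles to a plain atomicAdd on sm_60+; the double overload
        // falls back to atomicAdd_presm60 on older architectures.
        ATOMIC_ADD(total, values[i]);
    }
}
```

@@ -235,7 +241,8 @@ extern "C"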
int mops_cuda_outer_product_scatter_add_vjp_f64( mops_tensor_3d_f64_t grad_output, mops_tensor_2d_f64_t A, mops_tensor_2d_f64_t B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::outer_product_scatter_add_vjp( @@ -247,7 +254,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_vjp_f64( checked_cast(grad_output.shape[2])}}, {A.data, {checked_cast(A.shape[0]), checked_cast(A.shape[1])}}, {B.data, {checked_cast(B.shape[0]), checked_cast(B.shape[1])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -261,7 +269,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_vjp_vjp_f32( mops_tensor_3d_f32_t grad_output, mops_tensor_2d_f32_t A, mops_tensor_2d_f32_t B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::outer_product_scatter_add_vjp_vjp( @@ -281,7 +290,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_vjp_vjp_f32( checked_cast(grad_output.shape[2])}}, {A.data, {checked_cast(A.shape[0]), checked_cast(A.shape[1])}}, {B.data, {checked_cast(B.shape[0]), checked_cast(B.shape[1])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -295,7 +305,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_vjp_vjp_f64( mops_tensor_3d_f64_t grad_output, mops_tensor_2d_f64_t A, mops_tensor_2d_f64_t B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::outer_product_scatter_add_vjp_vjp( @@ -315,7 +326,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_vjp_vjp_f64( checked_cast(grad_output.shape[2])}}, {A.data, {checked_cast(A.shape[0]), checked_cast(A.shape[1])}}, {B.data, {checked_cast(B.shape[0]), checked_cast(B.shape[1])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } diff --git a/mops/src/opsa/cpu.tpp b/mops/src/opsa/cpu.tpp index 60b7db2..e628e2e 100644 --- a/mops/src/opsa/cpu.tpp +++ b/mops/src/opsa/cpu.tpp @@ -84,7 +84,7 @@ void mops::outer_product_scatter_add_vjp( scalar_t *grad_output_ptr = grad_output.data; scalar_t *a_ptr = A.data; scalar_t *b_ptr = B.data; - int32_t *indices_output_ptr = indices_output.data; + [[maybe_unused]] int32_t *indices_output_ptr = indices_output.data; #pragma omp parallel for for (size_t i = 0; i < size_ab; i++) { @@ -167,7 +167,7 @@ void mops::outer_product_scatter_add_vjp_vjp( scalar_t *grad_output_ptr = grad_output.data; scalar_t *a_ptr = A.data; scalar_t *b_ptr = B.data; - int32_t *indices_output_ptr = indices_output.data; + [[maybe_unused]] int32_t *indices_output_ptr = indices_output.data; scalar_t *grad_output_ptr_i = nullptr; scalar_t *a_ptr_i = nullptr; diff --git a/mops/src/opsa/opsa.cpp b/mops/src/opsa/opsa.cpp index 7a82088..7eab3ad 100644 --- a/mops/src/opsa/opsa.cpp +++ b/mops/src/opsa/opsa.cpp @@ -54,17 +54,25 @@ template void mops::outer_product_scatter_add_vjp_vjp( #ifndef MOPS_CUDA_ENABLED template void mops::cuda:: - outer_product_scatter_add(Tensor, Tensor, Tensor, Tensor) { + outer_product_scatter_add(Tensor, Tensor, Tensor, Tensor, void*) { throw std::runtime_error("MOPS was not 
compiled with CUDA support"); } // explicit instantiations of CUDA templates template void mops::cuda::outer_product_scatter_add( - Tensor output, Tensor A, Tensor B, Tensor indices_output + Tensor output, + Tensor A, + Tensor B, + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::outer_product_scatter_add( - Tensor output, Tensor A, Tensor B, Tensor indices_output + Tensor output, + Tensor A, + Tensor B, + Tensor indices_output, + void* cuda_stream ); template @@ -74,7 +82,8 @@ void mops::cuda::outer_product_scatter_add_vjp( Tensor /*grad_output*/, Tensor /*A*/, Tensor /*B*/, - Tensor /*indices_output*/ + Tensor /*indices_output*/, + void* /*cudaStream_t*/ ) { throw std::runtime_error("MOPS was not compiled with CUDA support"); } @@ -85,7 +94,8 @@ template void mops::cuda::outer_product_scatter_add_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::outer_product_scatter_add_vjp( @@ -94,7 +104,8 @@ template void mops::cuda::outer_product_scatter_add_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template @@ -107,7 +118,8 @@ void mops::cuda::outer_product_scatter_add_vjp_vjp( Tensor /*grad_output*/, Tensor /*A*/, Tensor /*B*/, - Tensor /*indices_output*/ + Tensor /*indices_output*/, + void* /*cudaStream_t*/ ) { throw std::runtime_error("MOPS was not compiled with CUDA support"); } @@ -121,7 +133,8 @@ template void mops::cuda::outer_product_scatter_add_vjp_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::outer_product_scatter_add_vjp_vjp( @@ -133,7 +146,8 @@ template void mops::cuda::outer_product_scatter_add_vjp_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); #endif diff --git a/mops/src/opsa/opsa.cu b/mops/src/opsa/opsa.cu index a831335..a8e6cb0 100644 --- a/mops/src/opsa/opsa.cu +++ b/mops/src/opsa/opsa.cu @@ -12,7 +12,7 @@ using namespace mops::cuda; #define FULL_MASK 0xffffffff template -__global__ __launch_bounds__(WARP_SIZE* NWARPS_PER_BLOCK) void outer_product_scatter_add_kernel( +__global__ void outer_product_scatter_add_kernel( Tensor A, Tensor B, Tensor first_occurences, @@ -24,7 +24,6 @@ __global__ __launch_bounds__(WARP_SIZE* NWARPS_PER_BLOCK) void outer_product_sca const int32_t threadCol = threadIdx.x % WARP_SIZE; const int32_t threadRow = threadIdx.x / WARP_SIZE; - const int32_t nThreadRow = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; int32_t* first_occurences_start = first_occurences.data; int32_t* first_occurences_end = first_occurences.data + output.shape[0]; @@ -36,7 +35,7 @@ __global__ __launch_bounds__(WARP_SIZE* NWARPS_PER_BLOCK) void outer_product_sca if (nsamples == 0) { // fill tensor with zeros instead - for (int i = threadRow; i < A.shape[1]; i += nThreadRow) { + for (int i = threadRow; i < A.shape[1]; i += NWARPS_PER_BLOCK) { for (int j = threadCol; j < B.shape[1]; j += WARP_SIZE) { output.data[blockIdx.x * A.shape[1] * B.shape[1] + i * B.shape[1] + j] = 0.0; } @@ -44,7 +43,7 @@ __global__ __launch_bounds__(WARP_SIZE* NWARPS_PER_BLOCK) void outer_product_sca return; } - for (int i = threadRow; i < A.shape[1]; i += nThreadRow) { + for (int i = threadRow; i < A.shape[1]; i += NWARPS_PER_BLOCK) { for (int j = threadCol; j < B.shape[1]; j += WARP_SIZE) { scalar_t reg_output = 0.0; @@ -66,10 +65,21 @@ void 
mops::cuda::outer_product_scatter_add( Tensor output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ) { check_opsa(output, A, B, indices_output, "cuda_outer_product_scatter_add"); + cudaPointerAttributes attributes; + CUDA_CHECK_ERROR(cudaPointerGetAttributes(&attributes, A.data)); + int current_device; + CUDA_CHECK_ERROR(cudaGetDevice(&current_device)); + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(attributes.device)); + } + + cudaStream_t cstream = reinterpret_cast(cuda_stream); + int32_t* first_occurences = calculate_first_occurences_cuda( indices_output.data, indices_output.shape[0], output.shape[0] ); @@ -78,26 +88,37 @@ void mops::cuda::outer_product_scatter_add( dim3 blockDim(WARP_SIZE * NWARPS_PER_BLOCK, 1, 1); - outer_product_scatter_add_kernel<<>>( + outer_product_scatter_add_kernel<<>>( A, B, mops::Tensor{first_occurences, {output.shape[0] * 2}}, indices_output, output ); CUDA_CHECK_ERROR(cudaGetLastError()); + CUDA_CHECK_ERROR(cudaStreamSynchronize(cstream)); - CUDA_CHECK_ERROR(cudaDeviceSynchronize()); + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(current_device)); + } } // explicit instantiations of CUDA templates template void mops::cuda::outer_product_scatter_add( - Tensor output, Tensor A, Tensor B, Tensor indices_output + Tensor output, + Tensor A, + Tensor B, + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::outer_product_scatter_add( - Tensor output, Tensor A, Tensor B, Tensor indices_output + Tensor output, + Tensor A, + Tensor B, + Tensor indices_output, + void* cuda_stream ); template -__global__ void __launch_bounds__(NWARPS_PER_BLOCK* WARP_SIZE) outer_product_scatter_add_vjp_kernel( +__global__ void outer_product_scatter_add_vjp_kernel( Tensor A, Tensor B, Tensor first_occurences, @@ -111,25 +132,24 @@ __global__ void __launch_bounds__(NWARPS_PER_BLOCK* WARP_SIZE) outer_product_sca const int32_t threadCol = threadIdx.x % WARP_SIZE; const int32_t threadRow = threadIdx.x / WARP_SIZE; - const int32_t nThreadRow = blockDim.x / WARP_SIZE; void* sptr = buffer; size_t space = 0; scalar_t* buffer_grad_in = shared_array(A.shape[1] * B.shape[1], sptr, &space); - scalar_t* buffer_A = shared_array(nThreadRow * A.shape[1], sptr, &space); - scalar_t* buffer_B = shared_array(nThreadRow * B.shape[1], sptr, &space); + scalar_t* buffer_A = shared_array(NWARPS_PER_BLOCK * A.shape[1], sptr, &space); + scalar_t* buffer_B = shared_array(NWARPS_PER_BLOCK * B.shape[1], sptr, &space); scalar_t* buffer_grad_A; scalar_t* buffer_grad_B; if (grad_A.data != nullptr) { - buffer_grad_A = shared_array(nThreadRow * A.shape[1], sptr, &space); + buffer_grad_A = shared_array(NWARPS_PER_BLOCK * A.shape[1], sptr, &space); } if (grad_B.data != nullptr) { - buffer_grad_B = shared_array(nThreadRow * B.shape[1], sptr, &space); + buffer_grad_B = shared_array(NWARPS_PER_BLOCK * B.shape[1], sptr, &space); } int32_t* first_occurences_start = first_occurences.data; @@ -154,7 +174,7 @@ __global__ void __launch_bounds__(NWARPS_PER_BLOCK* WARP_SIZE) outer_product_sca __syncthreads(); - for (int32_t sample_idx = threadRow; sample_idx < nsamples; sample_idx += nThreadRow) { + for (int32_t sample_idx = threadRow; sample_idx < nsamples; sample_idx += NWARPS_PER_BLOCK) { __syncwarp(); @@ -210,7 +230,7 @@ __global__ void __launch_bounds__(NWARPS_PER_BLOCK* WARP_SIZE) outer_product_sca // thread 0 contains the gradient for this subset of features_A.
if (threadCol == 0) { - buffer_grad_A[i * nThreadRow + threadRow] = dsumA; + buffer_grad_A[i * NWARPS_PER_BLOCK + threadRow] = dsumA; } } } @@ -227,7 +247,8 @@ __global__ void __launch_bounds__(NWARPS_PER_BLOCK* WARP_SIZE) outer_product_sca if (grad_A.data != nullptr) { // write gradA for (int i = threadCol; i < A.shape[1]; i += WARP_SIZE) { - grad_A.data[sample * A.shape[1] + i] = buffer_grad_A[i * nThreadRow + threadRow]; + grad_A.data[sample * A.shape[1] + i] = + buffer_grad_A[i * NWARPS_PER_BLOCK + threadRow]; } } } @@ -240,12 +261,23 @@ void mops::cuda::outer_product_scatter_add_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ) { check_opsa_vjp( grad_A, grad_B, grad_output, A, B, indices_output, "cuda_outer_product_scatter_add_vjp" ); + cudaPointerAttributes attributes; + CUDA_CHECK_ERROR(cudaPointerGetAttributes(&attributes, A.data)); + int current_device; + CUDA_CHECK_ERROR(cudaGetDevice(&current_device)); + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(attributes.device)); + } + + cudaStream_t cstream = reinterpret_cast(cuda_stream); + int32_t* first_occurences = calculate_first_occurences_cuda( indices_output.data, indices_output.shape[0], grad_output.shape[0] ); @@ -268,7 +300,7 @@ void mops::cuda::outer_product_scatter_add_vjp( shared_array(NWARPS_PER_BLOCK * B.shape[1], sptr, &space); } - outer_product_scatter_add_vjp_kernel<<>>( + outer_product_scatter_add_vjp_kernel<<>>( A, B, mops::Tensor{first_occurences, {grad_output.shape[0]}}, @@ -279,8 +311,11 @@ void mops::cuda::outer_product_scatter_add_vjp( ); CUDA_CHECK_ERROR(cudaGetLastError()); + CUDA_CHECK_ERROR(cudaStreamSynchronize(cstream)); - CUDA_CHECK_ERROR(cudaDeviceSynchronize()); + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(current_device)); + } } // these templates will be precompiled and provided in the mops library @@ -290,7 +325,8 @@ template void mops::cuda::outer_product_scatter_add_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::outer_product_scatter_add_vjp( @@ -299,22 +335,312 @@ template void mops::cuda::outer_product_scatter_add_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); +template +__global__ void outer_product_scatter_add_vjp_vjp_kernel( + Tensor grad_grad_output, + Tensor grad_A_2, + Tensor grad_B_2, + Tensor grad_grad_A, + Tensor grad_grad_B, + Tensor grad_output, + Tensor A, + Tensor B, + Tensor first_occurences, + Tensor indices_output +) { + extern __shared__ char buffer[]; + + const int32_t threadCol = threadIdx.x % WARP_SIZE; + const int32_t threadRow = threadIdx.x / WARP_SIZE; + + void* sptr = buffer; + size_t space = 0; + + scalar_t* buffer_grad_out; + + scalar_t* buffer_A; + scalar_t* buffer_B; + + scalar_t* buffer_grad_grad_out; + scalar_t* buffer_grad_A_2; + scalar_t* buffer_grad_B_2; + + scalar_t* buffer_grad_grad_A; + scalar_t* buffer_grad_grad_B; + + bool compute_grad_A_2 = grad_A_2.data != nullptr; + bool compute_grad_B_2 = grad_B_2.data != nullptr; + bool compute_grad_grad_output = (grad_grad_output.data != nullptr); + + buffer_grad_out = shared_array(A.shape[1] * B.shape[1], sptr, &space); + buffer_A = shared_array(NWARPS_PER_BLOCK * A.shape[1], sptr, &space); + buffer_B = shared_array(NWARPS_PER_BLOCK * B.shape[1], sptr, &space); + + buffer_grad_grad_A = shared_array(NWARPS_PER_BLOCK * A.shape[1], sptr,
&space); + buffer_grad_grad_B = shared_array(NWARPS_PER_BLOCK * B.shape[1], sptr, &space); + + if (compute_grad_grad_output) { + buffer_grad_grad_out = shared_array(A.shape[1] * B.shape[1], sptr, &space); + } + + if (compute_grad_A_2) { + buffer_grad_A_2 = shared_array(NWARPS_PER_BLOCK * A.shape[1], sptr, &space); + } + + if (compute_grad_B_2) { + buffer_grad_B_2 = shared_array(NWARPS_PER_BLOCK * B.shape[1], sptr, &space); + } + + int32_t* first_occurences_start = first_occurences.data; + int32_t* first_occurences_end = first_occurences.data + grad_output.shape[0]; + + int32_t sample_start = first_occurences_start[blockIdx.x]; + int32_t sample_end = first_occurences_end[blockIdx.x]; + + int32_t nsamples = sample_end - sample_start; + + if (nsamples == 0) { + return; + } + + /* + * initialise buffer_grad_out for this sub block + */ + + for (int tid = threadIdx.x; tid < A.shape[1] * B.shape[1]; tid += blockDim.x) { + buffer_grad_out[tid] = grad_output.data[blockIdx.x * A.shape[1] * B.shape[1] + tid]; + } + + if (compute_grad_grad_output) { + for (int tid = threadIdx.x; tid < A.shape[1] * B.shape[1]; tid += blockDim.x) { + buffer_grad_grad_out[tid] = 0.0; + } + } + + __syncthreads(); + + for (int32_t sample_idx = threadRow; sample_idx < nsamples; sample_idx += NWARPS_PER_BLOCK) { + + __syncwarp(); + + int32_t sample = sample_idx + sample_start; + + /* + * zero temporary buffers and load A, B into shared memory + */ + + for (int tid = threadCol; tid < A.shape[1]; tid += WARP_SIZE) { + + if (compute_grad_A_2) { + buffer_grad_A_2[threadRow * A.shape[1] + tid] = 0.0; + } + + buffer_A[threadRow * A.shape[1] + tid] = A.data[sample * A.shape[1] + tid]; + + if (grad_grad_A.data != nullptr) { + buffer_grad_grad_A[threadRow * A.shape[1] + tid] = + grad_grad_A.data[sample * A.shape[1] + tid]; + } + } + + for (int tid = threadCol; tid < B.shape[1]; tid += WARP_SIZE) { + + if (compute_grad_B_2) { + buffer_grad_B_2[threadRow * B.shape[1] + tid] = 0.0; + } + + buffer_B[threadRow * B.shape[1] + tid] = B.data[sample * B.shape[1] + tid]; + + if (grad_grad_B.data != nullptr) { + buffer_grad_grad_B[threadRow * B.shape[1] + tid] = + grad_grad_B.data[sample * B.shape[1] + tid]; + } + } + + __syncwarp(); + + /* + * perform the reduction + */ + for (int i = 0; i < A.shape[1]; i++) { + + scalar_t grad_A2_tmp = 0.0; + + for (int j = threadCol; j < B.shape[1]; j += WARP_SIZE) { + + if (compute_grad_A_2 && grad_grad_B.data != nullptr) { + grad_A2_tmp += buffer_grad_grad_B[threadRow * B.shape[1] + j] * + buffer_grad_out[i * B.shape[1] + j]; + } + + if (compute_grad_B_2 && grad_grad_A.data != nullptr) { + buffer_grad_B_2[threadRow * B.shape[1] + j] += + buffer_grad_grad_A[threadRow * A.shape[1] + i] * + buffer_grad_out[i * B.shape[1] + j]; + } + + if (compute_grad_grad_output && grad_grad_B.data != nullptr) { + ATOMIC_ADD( + &buffer_grad_grad_out[i * B.shape[1] + j], + buffer_A[threadRow * A.shape[1] + i] * + buffer_grad_grad_B[threadRow * B.shape[1] + j] + ); + } + + if (compute_grad_grad_output && grad_grad_A.data != nullptr) { + ATOMIC_ADD( + &buffer_grad_grad_out[i * B.shape[1] + j], + buffer_B[threadRow * B.shape[1] + j] * + buffer_grad_grad_A[threadRow * A.shape[1] + i] + ); + } + } + + // reduce across B dimension + if (compute_grad_A_2) { + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + grad_A2_tmp += __shfl_down_sync(FULL_MASK, grad_A2_tmp, offset, WARP_SIZE); + } + if (threadCol == 0) { + buffer_grad_A_2[i * NWARPS_PER_BLOCK + threadRow] = grad_A2_tmp; + } + } + } + + __syncwarp(); + + if
(compute_grad_B_2) { + // write gradB + for (int j = threadCol; j < B.shape[1]; j += WARP_SIZE) { + grad_B_2.data[sample * B.shape[1] + j] = buffer_grad_B_2[threadRow * B.shape[1] + j]; + } + } + + if (compute_grad_A_2) { + // write gradA + for (int i = threadCol; i < A.shape[1]; i += WARP_SIZE) { + grad_A_2.data[sample * A.shape[1] + i] = + buffer_grad_A_2[i * NWARPS_PER_BLOCK + threadRow]; + } + } + } + + __syncthreads(); + + if (compute_grad_grad_output) { + for (int tid = threadIdx.x; tid < A.shape[1] * B.shape[1]; tid += blockDim.x) { + grad_grad_output.data[blockIdx.x * A.shape[1] * B.shape[1] + tid] = + buffer_grad_grad_out[tid]; + } + } +} + template void mops::cuda::outer_product_scatter_add_vjp_vjp( - Tensor /*grad_grad_output*/, - Tensor /*grad_A_2*/, - Tensor /*grad_B_2*/, - Tensor /*grad_grad_A*/, - Tensor /*grad_grad_B*/, - Tensor /*grad_output*/, - Tensor /*A*/, - Tensor /*B*/, - Tensor /*indices_output*/ + Tensor grad_grad_output, + Tensor grad_A_2, + Tensor grad_B_2, + Tensor grad_grad_A, + Tensor grad_grad_B, + Tensor grad_output, + Tensor A, + Tensor B, + Tensor indices_output, + void* cuda_stream ) { - throw std::runtime_error("Not implemented"); + + check_opsa_vjp_vjp( + grad_grad_output, + grad_A_2, + grad_B_2, + grad_grad_A, + grad_grad_B, + grad_output, + A, + B, + indices_output, + "cuda_outer_product_scatter_add_vjp_vjp" + ); + + cudaPointerAttributes attributes; + CUDA_CHECK_ERROR(cudaPointerGetAttributes(&attributes, A.data)); + int current_device; + CUDA_CHECK_ERROR(cudaGetDevice(&current_device)); + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(attributes.device)); + } + + cudaStream_t cstream = reinterpret_cast(cuda_stream); + + int32_t* first_occurences = calculate_first_occurences_cuda( + indices_output.data, indices_output.shape[0], grad_output.shape[0] + ); + + dim3 gridDim(grad_output.shape[0], 1, 1); + + dim3 blockDim(NWARPS_PER_BLOCK * WARP_SIZE, 1, 1); + + void* sptr = 0; + size_t space = 0; + + scalar_t* buffer_grad_out; + + scalar_t* buffer_A; + scalar_t* buffer_B; + + scalar_t* buffer_grad_grad_out; + scalar_t* buffer_grad_A_2; + scalar_t* buffer_grad_B_2; + + scalar_t* buffer_grad_grad_A; + scalar_t* buffer_grad_grad_B; + + bool compute_grad_A_2 = grad_A_2.data != nullptr; + bool compute_grad_B_2 = grad_B_2.data != nullptr; + bool compute_grad_grad_output = (grad_grad_output.data != nullptr); + + buffer_grad_out = shared_array(A.shape[1] * B.shape[1], sptr, &space); + buffer_A = shared_array(NWARPS_PER_BLOCK * A.shape[1], sptr, &space); + buffer_B = shared_array(NWARPS_PER_BLOCK * B.shape[1], sptr, &space); + + buffer_grad_grad_A = shared_array(NWARPS_PER_BLOCK * A.shape[1], sptr, &space); + buffer_grad_grad_B = shared_array(NWARPS_PER_BLOCK * B.shape[1], sptr, &space); + + if (compute_grad_grad_output) { + buffer_grad_grad_out = shared_array(A.shape[1] * B.shape[1], sptr, &space); + } + + if (compute_grad_A_2) { + buffer_grad_A_2 = shared_array(NWARPS_PER_BLOCK * A.shape[1], sptr, &space); + } + + if (compute_grad_B_2) { + buffer_grad_B_2 = shared_array(NWARPS_PER_BLOCK * B.shape[1], sptr, &space); + } + + outer_product_scatter_add_vjp_vjp_kernel<<>>( + grad_grad_output, + grad_A_2, + grad_B_2, + grad_grad_A, + grad_grad_B, + grad_output, + A, + B, + mops::Tensor{first_occurences, {grad_output.shape[0]}}, + indices_output + ); + + CUDA_CHECK_ERROR(cudaGetLastError()); + CUDA_CHECK_ERROR(cudaStreamSynchronize(cstream)); + + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(current_device)); + } }
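Every launcher in this diff repeats the same preamble and epilogue: query which device owns the input allocation with cudaPointerGetAttributes, switch to that device if it is not current, reinterpret the opaque void* handle as a cudaStream_t, launch and synchronize on that stream, and finally restore the original device. Below is a minimal RAII sketch of that sequence; ScopedDeviceGuard is a hypothetical helper shown only for illustration (error checking omitted), not part of the MOPS API:

    #include <cuda_runtime.h>

    // Hypothetical helper: switch to the device owning `ptr` on construction,
    // restore the previously current device on destruction.
    class ScopedDeviceGuard {
      public:
        explicit ScopedDeviceGuard(const void* ptr) {
            cudaPointerAttributes attributes;
            cudaPointerGetAttributes(&attributes, ptr);
            cudaGetDevice(&previous_device_);
            if (previous_device_ != attributes.device) {
                cudaSetDevice(attributes.device);
            }
        }
        ~ScopedDeviceGuard() {
            int device = 0;
            cudaGetDevice(&device);
            if (device != previous_device_) {
                cudaSetDevice(previous_device_);
            }
        }

      private:
        int previous_device_ = 0;
    };

    // Usage inside a launcher, mirroring the hand-written sequence above:
    //     ScopedDeviceGuard guard(A.data);
    //     cudaStream_t cstream = reinterpret_cast<cudaStream_t>(cuda_stream);
    //     some_kernel<<<gridDim, blockDim, space, cstream>>>(...);
    //     CUDA_CHECK_ERROR(cudaGetLastError());
    //     CUDA_CHECK_ERROR(cudaStreamSynchronize(cstream));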
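The ATOMIC_ADD calls in the kernels above dispatch to the plain atomicAdd intrinsic, except for double on pre-SM60 architectures, where they fall back to the atomicAdd_presm60 declared in cuda_utils.cuh. Its definition is not shown in this diff; the canonical compare-and-swap pattern from the CUDA C Programming Guide, which such a fallback follows, looks like this (a sketch, not necessarily the exact MOPS implementation):

    __device__ double atomicAdd_presm60(double* address, double val) {
        // Reinterpret the double as a 64-bit integer so atomicCAS can operate on it.
        unsigned long long int* address_as_ull = (unsigned long long int*)address;
        unsigned long long int old = *address_as_ull;
        unsigned long long int assumed;
        do {
            assumed = old;
            // Attempt to store (val + current value); retry if another thread
            // modified the location between the read and the CAS.
            old = atomicCAS(
                address_as_ull,
                assumed,
                __double_as_longlong(val + __longlong_as_double(assumed))
            );
        } while (assumed != old);
        return __longlong_as_double(old);
    }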
// explicit instantiations of CUDA templates @@ -327,7 +653,8 @@ template void mops::cuda::outer_product_scatter_add_vjp_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::outer_product_scatter_add_vjp_vjp( @@ -339,5 +666,6 @@ template void mops::cuda::outer_product_scatter_add_vjp_vjp( Tensor grad_output, Tensor A, Tensor B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); diff --git a/mops/src/opsaw/capi.cpp b/mops/src/opsaw/capi.cpp index 04e0bc4..3585d70 100644 --- a/mops/src/opsaw/capi.cpp +++ b/mops/src/opsaw/capi.cpp @@ -210,7 +210,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_with_weights_f32( mops_tensor_2d_f32_t B, mops_tensor_2d_f32_t W, mops_tensor_1d_i32_t indices_W, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::outer_product_scatter_add_with_weights( @@ -222,7 +223,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_with_weights_f32( {B.data, {checked_cast(B.shape[0]), checked_cast(B.shape[1])}}, {W.data, {checked_cast(W.shape[0]), checked_cast(W.shape[1])}}, {indices_W.data, {checked_cast(indices_W.shape[0])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -233,7 +235,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_with_weights_f64( mops_tensor_2d_f64_t B, mops_tensor_2d_f64_t W, mops_tensor_1d_i32_t indices_W, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::outer_product_scatter_add_with_weights( @@ -245,7 +248,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_with_weights_f64( {B.data, {checked_cast(B.shape[0]), checked_cast(B.shape[1])}}, {W.data, {checked_cast(W.shape[0]), checked_cast(W.shape[1])}}, {indices_W.data, {checked_cast(indices_W.shape[0])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -259,7 +263,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_with_weights_vjp_f32( mops_tensor_2d_f32_t B, mops_tensor_2d_f32_t W, mops_tensor_1d_i32_t indices_W, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::outer_product_scatter_add_with_weights_vjp( @@ -274,7 +279,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_with_weights_vjp_f32( {B.data, {checked_cast(B.shape[0]), checked_cast(B.shape[1])}}, {W.data, {checked_cast(W.shape[0]), checked_cast(W.shape[1])}}, {indices_W.data, {checked_cast(indices_W.shape[0])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -288,7 +294,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_with_weights_vjp_f64( mops_tensor_2d_f64_t B, mops_tensor_2d_f64_t W, mops_tensor_1d_i32_t indices_W, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::outer_product_scatter_add_with_weights_vjp( @@ -303,7 +310,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_with_weights_vjp_f64( {B.data, {checked_cast(B.shape[0]), 
checked_cast(B.shape[1])}}, {W.data, {checked_cast(W.shape[0]), checked_cast(W.shape[1])}}, {indices_W.data, {checked_cast(indices_W.shape[0])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -321,7 +329,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_with_weights_vjp_vjp_f32( mops_tensor_2d_f32_t B, mops_tensor_2d_f32_t W, mops_tensor_1d_i32_t indices_W, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::outer_product_scatter_add_with_weights_vjp_vjp( @@ -346,7 +355,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_with_weights_vjp_vjp_f32( {B.data, {checked_cast(B.shape[0]), checked_cast(B.shape[1])}}, {W.data, {checked_cast(W.shape[0]), checked_cast(W.shape[1])}}, {indices_W.data, {checked_cast(indices_W.shape[0])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -364,7 +374,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_with_weights_vjp_vjp_f64( mops_tensor_2d_f64_t B, mops_tensor_2d_f64_t W, mops_tensor_1d_i32_t indices_W, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::outer_product_scatter_add_with_weights_vjp_vjp( @@ -389,7 +400,8 @@ extern "C" int mops_cuda_outer_product_scatter_add_with_weights_vjp_vjp_f64( {B.data, {checked_cast(B.shape[0]), checked_cast(B.shape[1])}}, {W.data, {checked_cast(W.shape[0]), checked_cast(W.shape[1])}}, {indices_W.data, {checked_cast(indices_W.shape[0])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } diff --git a/mops/src/opsaw/cuda.tpp b/mops/src/opsaw/cuda.tpp deleted file mode 100644 index 204e7db..0000000 --- a/mops/src/opsaw/cuda.tpp +++ /dev/null @@ -1,49 +0,0 @@ -#include - -#include "mops/opsaw.hpp" - -template -void mops::cuda::outer_product_scatter_add_with_weights( - Tensor, - Tensor, - Tensor, - Tensor, - Tensor, - Tensor -) { - throw std::runtime_error("CUDA implementation does not exist yet"); -} - -template -void mops::cuda::outer_product_scatter_add_with_weights_vjp( - Tensor, - Tensor, - Tensor, - Tensor, - Tensor, - Tensor, - Tensor, - Tensor, - Tensor -) { - throw std::runtime_error("CUDA implementation does not exist yet"); -} - -template -void mops::cuda::outer_product_scatter_add_with_weights_vjp_vjp( - Tensor /*grad_grad_output*/, - Tensor /*grad_A_2*/, - Tensor /*grad_B_2*/, - Tensor /*grad_W_2*/, - Tensor /*grad_grad_A*/, - Tensor /*grad_grad_B*/, - Tensor /*grad_grad_W*/, - Tensor /*grad_output*/, - Tensor /*A*/, - Tensor /*B*/, - Tensor /*W*/, - Tensor /*indices_W*/, - Tensor /*indices_output*/ -) { - throw std::runtime_error("CUDA implementation does not exist yet"); -} diff --git a/mops/src/opsaw/opsaw.cpp b/mops/src/opsaw/opsaw.cpp index b3b30b8..6e30c41 100644 --- a/mops/src/opsaw/opsaw.cpp +++ b/mops/src/opsaw/opsaw.cpp @@ -75,42 +75,25 @@ template void mops::outer_product_scatter_add_with_weights_vjp_vjp( Tensor indices_output ); -#ifdef MOPS_CUDA_ENABLED -#include "cuda.tpp" -#else +#ifndef MOPS_CUDA_ENABLED template void mops::cuda:: - outer_product_scatter_add_with_weights(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) { + 
outer_product_scatter_add_with_weights(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, void*) { throw std::runtime_error("MOPS was not compiled with CUDA support"); } template void mops::cuda:: - outer_product_scatter_add_with_weights_vjp(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) { + outer_product_scatter_add_with_weights_vjp(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, void*) { throw std::runtime_error("MOPS was not compiled with CUDA support"); } template -void mops::cuda::outer_product_scatter_add_with_weights_vjp_vjp( - Tensor /*grad_grad_output*/, - Tensor /*grad_A_2*/, - Tensor /*grad_B_2*/, - Tensor /*grad_W_2*/, - Tensor /*grad_grad_A*/, - Tensor /*grad_grad_B*/, - Tensor /*grad_grad_W*/, - Tensor /*grad_output*/, - Tensor /*A*/, - Tensor /*B*/, - Tensor /*W*/, - Tensor /*indices_W*/, - Tensor /*indices_output*/ -) { +void mops::cuda:: + outer_product_scatter_add_with_weights_vjp_vjp(Tensor /*grad_grad_output*/, Tensor /*grad_A_2*/, Tensor /*grad_B_2*/, Tensor /*grad_W_2*/, Tensor /*grad_grad_A*/, Tensor /*grad_grad_B*/, Tensor /*grad_grad_W*/, Tensor /*grad_output*/, Tensor /*A*/, Tensor /*B*/, Tensor /*W*/, Tensor /*indices_W*/, Tensor /*indices_output*/, void*) { throw std::runtime_error("MOPS was not compiled with CUDA support"); } -#endif - // explicit instantiations of CUDA templates template void mops::cuda::outer_product_scatter_add_with_weights( Tensor output, @@ -118,7 +101,8 @@ template void mops::cuda::outer_product_scatter_add_with_weights( Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::outer_product_scatter_add_with_weights( @@ -127,7 +111,8 @@ template void mops::cuda::outer_product_scatter_add_with_weights( Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::outer_product_scatter_add_with_weights_vjp( @@ -139,7 +124,8 @@ template void mops::cuda::outer_product_scatter_add_with_weights_vjp( Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::outer_product_scatter_add_with_weights_vjp( @@ -151,7 +137,8 @@ template void mops::cuda::outer_product_scatter_add_with_weights_vjp( Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::outer_product_scatter_add_with_weights_vjp_vjp( @@ -167,7 +154,8 @@ template void mops::cuda::outer_product_scatter_add_with_weights_vjp_vjp( Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::outer_product_scatter_add_with_weights_vjp_vjp( @@ -183,5 +171,8 @@ template void mops::cuda::outer_product_scatter_add_with_weights_vjp_vjp Tensor B, Tensor W, Tensor indices_W, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); + +#endif \ No newline at end of file diff --git a/mops/src/opsaw/opsaw.cu b/mops/src/opsaw/opsaw.cu index 8421bac..28c48fe 100644 --- a/mops/src/opsaw/opsaw.cu +++ b/mops/src/opsaw/opsaw.cu @@ -1 +1,106 @@ -// todo: cuda device code +#include + +#include "mops/opsaw.hpp" + +using namespace mops; +using namespace mops::cuda; +using namespace std; + +template +void mops::cuda:: + outer_product_scatter_add_with_weights(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, void*) { + throw std::runtime_error("CUDA implementation does not exist 
yet"); +} + +template +void mops::cuda:: + outer_product_scatter_add_with_weights_vjp(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, void*) { + throw std::runtime_error("CUDA implementation does not exist yet"); +} + +template +void mops::cuda:: + outer_product_scatter_add_with_weights_vjp_vjp(Tensor /*grad_grad_output*/, Tensor /*grad_A_2*/, Tensor /*grad_B_2*/, Tensor /*grad_W_2*/, Tensor /*grad_grad_A*/, Tensor /*grad_grad_B*/, Tensor /*grad_grad_W*/, Tensor /*grad_output*/, Tensor /*A*/, Tensor /*B*/, Tensor /*W*/, Tensor /*indices_W*/, Tensor /*indices_output*/, void*) { + throw std::runtime_error("CUDA implementation does not exist yet"); +} + +// explicit instantiations of CUDA templates +template void mops::cuda::outer_product_scatter_add_with_weights( + Tensor output, + Tensor A, + Tensor B, + Tensor W, + Tensor indices_W, + Tensor indices_output, + void* cuda_stream +); + +template void mops::cuda::outer_product_scatter_add_with_weights( + Tensor output, + Tensor A, + Tensor B, + Tensor W, + Tensor indices_W, + Tensor indices_output, + void* cuda_stream +); + +template void mops::cuda::outer_product_scatter_add_with_weights_vjp( + Tensor grad_A, + Tensor grad_B, + Tensor grad_W, + Tensor grad_output, + Tensor A, + Tensor B, + Tensor W, + Tensor indices_W, + Tensor indices_output, + void* cuda_stream +); + +template void mops::cuda::outer_product_scatter_add_with_weights_vjp( + Tensor grad_A, + Tensor grad_B, + Tensor grad_W, + Tensor grad_output, + Tensor A, + Tensor B, + Tensor W, + Tensor indices_W, + Tensor indices_output, + void* cuda_stream +); + +template void mops::cuda::outer_product_scatter_add_with_weights_vjp_vjp( + Tensor grad_grad_output, + Tensor grad_A_2, + Tensor grad_B_2, + Tensor grad_W_2, + Tensor grad_grad_A, + Tensor grad_grad_B, + Tensor grad_grad_W, + Tensor grad_output, + Tensor A, + Tensor B, + Tensor W, + Tensor indices_W, + Tensor indices_output, + void* cuda_stream +); + +template void mops::cuda::outer_product_scatter_add_with_weights_vjp_vjp( + Tensor grad_grad_output, + Tensor grad_A_2, + Tensor grad_B_2, + Tensor grad_W_2, + Tensor grad_grad_A, + Tensor grad_grad_B, + Tensor grad_grad_W, + Tensor grad_output, + Tensor A, + Tensor B, + Tensor W, + Tensor indices_W, + Tensor indices_output, + void* cuda_stream +); \ No newline at end of file diff --git a/mops/src/sap/capi.cpp b/mops/src/sap/capi.cpp index e5c79d8..e79a256 100644 --- a/mops/src/sap/capi.cpp +++ b/mops/src/sap/capi.cpp @@ -191,7 +191,8 @@ extern "C" int mops_cuda_sparse_accumulation_of_products_f32( mops_tensor_1d_f32_t C, mops_tensor_1d_i32_t indices_A, mops_tensor_1d_i32_t indices_B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::sparse_accumulation_of_products( @@ -201,7 +202,8 @@ extern "C" int mops_cuda_sparse_accumulation_of_products_f32( {C.data, {checked_cast(C.shape[0])}}, {indices_A.data, {checked_cast(indices_A.shape[0])}}, {indices_B.data, {checked_cast(indices_B.shape[0])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -213,7 +215,8 @@ extern "C" int mops_cuda_sparse_accumulation_of_products_f64( mops_tensor_1d_f64_t C, mops_tensor_1d_i32_t indices_A, mops_tensor_1d_i32_t indices_B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN 
mops::cuda::sparse_accumulation_of_products( @@ -223,7 +226,8 @@ extern "C" int mops_cuda_sparse_accumulation_of_products_f64( {C.data, {checked_cast(C.shape[0])}}, {indices_A.data, {checked_cast(indices_A.shape[0])}}, {indices_B.data, {checked_cast(indices_B.shape[0])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -237,7 +241,8 @@ extern "C" int mops_cuda_sparse_accumulation_of_products_vjp_f32( mops_tensor_1d_f32_t C, mops_tensor_1d_i32_t indices_A, mops_tensor_1d_i32_t indices_B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::sparse_accumulation_of_products_vjp( @@ -250,7 +255,8 @@ extern "C" int mops_cuda_sparse_accumulation_of_products_vjp_f32( {C.data, {checked_cast(C.shape[0])}}, {indices_A.data, {checked_cast(indices_A.shape[0])}}, {indices_B.data, {checked_cast(indices_B.shape[0])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -264,7 +270,8 @@ extern "C" int mops_cuda_sparse_accumulation_of_products_vjp_f64( mops_tensor_1d_f64_t C, mops_tensor_1d_i32_t indices_A, mops_tensor_1d_i32_t indices_B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::sparse_accumulation_of_products_vjp( @@ -277,7 +284,8 @@ extern "C" int mops_cuda_sparse_accumulation_of_products_vjp_f64( {C.data, {checked_cast(C.shape[0])}}, {indices_A.data, {checked_cast(indices_A.shape[0])}}, {indices_B.data, {checked_cast(indices_B.shape[0])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -294,7 +302,8 @@ extern "C" int mops_cuda_sparse_accumulation_of_products_vjp_vjp_f32( mops_tensor_1d_f32_t C, mops_tensor_1d_i32_t indices_A, mops_tensor_1d_i32_t indices_B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::sparse_accumulation_of_products_vjp_vjp( @@ -313,7 +322,8 @@ extern "C" int mops_cuda_sparse_accumulation_of_products_vjp_vjp_f32( {C.data, {checked_cast(C.shape[0])}}, {indices_A.data, {checked_cast(indices_A.shape[0])}}, {indices_B.data, {checked_cast(indices_B.shape[0])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -330,7 +340,8 @@ extern "C" int mops_cuda_sparse_accumulation_of_products_vjp_vjp_f64( mops_tensor_1d_f64_t C, mops_tensor_1d_i32_t indices_A, mops_tensor_1d_i32_t indices_B, - mops_tensor_1d_i32_t indices_output + mops_tensor_1d_i32_t indices_output, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::sparse_accumulation_of_products_vjp_vjp( @@ -349,7 +360,8 @@ extern "C" int mops_cuda_sparse_accumulation_of_products_vjp_vjp_f64( {C.data, {checked_cast(C.shape[0])}}, {indices_A.data, {checked_cast(indices_A.shape[0])}}, {indices_B.data, {checked_cast(indices_B.shape[0])}}, - {indices_output.data, {checked_cast(indices_output.shape[0])}} + {indices_output.data, {checked_cast(indices_output.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } diff --git a/mops/src/sap/sap.cpp 
b/mops/src/sap/sap.cpp index 545c429..b91b41a 100644 --- a/mops/src/sap/sap.cpp +++ b/mops/src/sap/sap.cpp @@ -78,13 +78,13 @@ template void mops::sparse_accumulation_of_products_vjp_vjp( #ifndef MOPS_CUDA_ENABLED template void mops::cuda:: - sparse_accumulation_of_products(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) { + sparse_accumulation_of_products(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, void*) { throw std::runtime_error("MOPS was not compiled with CUDA support"); } template void mops::cuda:: - sparse_accumulation_of_products_vjp(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) { + sparse_accumulation_of_products_vjp(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, void*) { throw std::runtime_error("MOPS was not compiled with CUDA support"); } @@ -101,7 +101,8 @@ void mops::cuda::sparse_accumulation_of_products_vjp_vjp( Tensor /*C*/, Tensor /*indices_A*/, Tensor /*indices_B*/, - Tensor /*indices_output*/ + Tensor /*indices_output*/, + void* /*cuda_stream*/ ) { throw std::runtime_error("MOPS was not compiled with CUDA support"); } @@ -114,7 +115,8 @@ template void mops::cuda::sparse_accumulation_of_products( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::sparse_accumulation_of_products( @@ -124,7 +126,8 @@ template void mops::cuda::sparse_accumulation_of_products( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::sparse_accumulation_of_products_vjp( @@ -136,7 +139,8 @@ template void mops::cuda::sparse_accumulation_of_products_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::sparse_accumulation_of_products_vjp( @@ -148,7 +152,8 @@ template void mops::cuda::sparse_accumulation_of_products_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::sparse_accumulation_of_products_vjp_vjp( @@ -163,7 +168,8 @@ template void mops::cuda::sparse_accumulation_of_products_vjp_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::sparse_accumulation_of_products_vjp_vjp( @@ -178,7 +184,8 @@ template void mops::cuda::sparse_accumulation_of_products_vjp_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); #endif diff --git a/mops/src/sap/sap.cu b/mops/src/sap/sap.cu index be72bc8..55d7716 100644 --- a/mops/src/sap/sap.cu +++ b/mops/src/sap/sap.cu @@ -70,8 +70,8 @@ __global__ void sparse_accumulation_of_products_kernel( int b_idx = (packed_indices[k] >> 8) & 0xFF; int a_idx = (packed_indices[k] >> 16) & 0xFF; - atomicAdd( - buffer_out + out_idx * WARP_SIZE + laneID, + ATOMIC_ADD( + &buffer_out[out_idx * WARP_SIZE + laneID], C.data[k] * buffer_A[a_idx * WARP_SIZE + laneID] * buffer_B[b_idx * WARP_SIZE + laneID] ); } @@ -95,12 +95,23 @@ void mops::cuda::sparse_accumulation_of_products( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ) { check_sap( output, A, B, C, indices_A, indices_B, indices_output, "cuda_sparse_accumulation_of_products" ); + cudaPointerAttributes attributes; + 
CUDA_CHECK_ERROR(cudaPointerGetAttributes(&attributes, A.data)); + int current_device; + CUDA_CHECK_ERROR(cudaGetDevice(&current_device)); + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(attributes.device)); + } + + cudaStream_t cstream = reinterpret_cast(cuda_stream); + dim3 block_dim(find_integer_divisor(A.shape[0], WARP_SIZE)); dim3 thread_block(WARP_SIZE * NWARPS_PER_BLOCK, 1, 1); @@ -113,12 +124,16 @@ void mops::cuda::sparse_accumulation_of_products( shared_array(WARP_SIZE * B.shape[1], sptr, &space); shared_array(indices_A.shape[0], sptr, &space); - sparse_accumulation_of_products_kernel - <<>>(output, A, B, C, indices_A, indices_B, indices_output); + sparse_accumulation_of_products_kernel<<>>( + output, A, B, C, indices_A, indices_B, indices_output + ); CUDA_CHECK_ERROR(cudaGetLastError()); + CUDA_CHECK_ERROR(cudaStreamSynchronize(cstream)); - CUDA_CHECK_ERROR(cudaDeviceSynchronize()); + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(current_device)); + } } // explicit instantiations of CUDA templates @@ -129,7 +144,8 @@ template void mops::cuda::sparse_accumulation_of_products( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::sparse_accumulation_of_products( @@ -139,7 +155,8 @@ template void mops::cuda::sparse_accumulation_of_products( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template @@ -235,15 +252,15 @@ __global__ void sparse_accumulation_of_products_vjp_kernel( int a_idx = (packed_indices[k] >> 16) & 0xFF; if (grad_A.data != nullptr) { - atomicAdd( - buffer_gradA + a_idx * WARP_SIZE + laneID, + ATOMIC_ADD( + &buffer_gradA[a_idx * WARP_SIZE + laneID], C.data[k] * buffer_B[b_idx * WARP_SIZE + laneID] * buffer_gradout[out_idx * WARP_SIZE + laneID] ); } if (grad_B.data != nullptr) { - atomicAdd( - buffer_gradB + b_idx * WARP_SIZE + laneID, + ATOMIC_ADD( + &buffer_gradB[b_idx * WARP_SIZE + laneID], C.data[k] * buffer_A[a_idx * WARP_SIZE + laneID] * buffer_gradout[out_idx * WARP_SIZE + laneID] ); @@ -282,7 +299,8 @@ void mops::cuda::sparse_accumulation_of_products_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ) { check_sap_vjp( grad_A, @@ -297,6 +315,16 @@ void mops::cuda::sparse_accumulation_of_products_vjp( "cuda_sparse_accumulation_of_products_vjp" ); + cudaPointerAttributes attributes; + CUDA_CHECK_ERROR(cudaPointerGetAttributes(&attributes, A.data)); + int current_device; + CUDA_CHECK_ERROR(cudaGetDevice(&current_device)); + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(attributes.device)); + } + + cudaStream_t cstream = reinterpret_cast(cuda_stream); + dim3 block_dim(find_integer_divisor(grad_A.shape[0], WARP_SIZE)); dim3 thread_block(WARP_SIZE * NWARPS_PER_BLOCK, 1, 1); @@ -317,13 +345,17 @@ void mops::cuda::sparse_accumulation_of_products_vjp( shared_array(WARP_SIZE * grad_A.shape[1], sptr, &space); } - sparse_accumulation_of_products_vjp_kernel<<>>( - grad_A, grad_B, grad_output, A, B, C, indices_A, indices_B, indices_output - ); + sparse_accumulation_of_products_vjp_kernel + <<>>( + grad_A, grad_B, grad_output, A, B, C, indices_A, indices_B, indices_output + ); CUDA_CHECK_ERROR(cudaGetLastError()); + CUDA_CHECK_ERROR(cudaStreamSynchronize(cstream)); - CUDA_CHECK_ERROR(cudaDeviceSynchronize()); + if (current_device != attributes.device) { +
CUDA_CHECK_ERROR(cudaSetDevice(current_device)); + } } template void mops::cuda::sparse_accumulation_of_products_vjp( @@ -335,7 +367,8 @@ template void mops::cuda::sparse_accumulation_of_products_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::sparse_accumulation_of_products_vjp( @@ -347,7 +380,8 @@ template void mops::cuda::sparse_accumulation_of_products_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template @@ -507,15 +541,15 @@ __global__ void sparse_accumulation_of_products_vjp_vjp_kernel( scalar_t grad_grad_A_k = buffer_grad_grad_A[a_idx * WARP_SIZE + laneID]; if (grad_grad_output.data != nullptr) { - atomicAdd( - buffer_grad_grad_output + out_idx * WARP_SIZE + laneID, + ATOMIC_ADD( + &buffer_grad_grad_output[out_idx * WARP_SIZE + laneID], grad_grad_A_k * buffer_B[b_idx * WARP_SIZE + laneID] * c ); } if (grad_B_2.data != nullptr) { - atomicAdd( - buffer_grad_B2 + b_idx * WARP_SIZE + laneID, + ATOMIC_ADD( + &buffer_grad_B2[b_idx * WARP_SIZE + laneID], grad_grad_A_k * buffer_grad_output[out_idx * WARP_SIZE + laneID] * c ); } @@ -525,15 +559,15 @@ __global__ void sparse_accumulation_of_products_vjp_vjp_kernel( scalar_t grad_grad_B_k = buffer_grad_grad_B[b_idx * WARP_SIZE + laneID]; if (grad_grad_output.data != nullptr) { - atomicAdd( - buffer_grad_grad_output + out_idx * WARP_SIZE + laneID, + ATOMIC_ADD( + &buffer_grad_grad_output[out_idx * WARP_SIZE + laneID], grad_grad_B_k * buffer_A[a_idx * WARP_SIZE + laneID] * c ); } if (grad_A_2.data != nullptr) { - atomicAdd( - buffer_grad_A2 + a_idx * WARP_SIZE + laneID, + ATOMIC_ADD( + &buffer_grad_A2[a_idx * WARP_SIZE + laneID], grad_grad_B_k * buffer_grad_output[out_idx * WARP_SIZE + laneID] * c ); } @@ -586,7 +620,8 @@ void mops::cuda::sparse_accumulation_of_products_vjp_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ) { check_sap_vjp_vjp( grad_grad_output, @@ -604,6 +639,16 @@ void mops::cuda::sparse_accumulation_of_products_vjp_vjp( "cuda_sparse_accumulation_of_products_vjp_vjp" ); + cudaPointerAttributes attributes; + CUDA_CHECK_ERROR(cudaPointerGetAttributes(&attributes, A.data)); + int current_device; + CUDA_CHECK_ERROR(cudaGetDevice(&current_device)); + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(attributes.device)); + } + + cudaStream_t cstream = reinterpret_cast(cuda_stream); + dim3 block_dim(find_integer_divisor(grad_A_2.shape[0], WARP_SIZE)); dim3 thread_block(WARP_SIZE * NWARPS_PER_BLOCK, 1, 1); @@ -658,24 +703,28 @@ void mops::cuda::sparse_accumulation_of_products_vjp_vjp( int32_t* packed_indices = shared_array(indices_A.shape[0], sptr, &space); - sparse_accumulation_of_products_vjp_vjp_kernel<<>>( - grad_grad_output, - grad_A_2, - grad_B_2, - grad_grad_A, - grad_grad_B, - grad_output, - A, - B, - C, - indices_A, - indices_B, - indices_output - ); + sparse_accumulation_of_products_vjp_vjp_kernel + <<>>( + grad_grad_output, + grad_A_2, + grad_B_2, + grad_grad_A, + grad_grad_B, + grad_output, + A, + B, + C, + indices_A, + indices_B, + indices_output + ); CUDA_CHECK_ERROR(cudaGetLastError()); + CUDA_CHECK_ERROR(cudaStreamSynchronize(cstream)); - CUDA_CHECK_ERROR(cudaDeviceSynchronize()); + if (current_device != attributes.device) { + CUDA_CHECK_ERROR(cudaSetDevice(current_device)); + } } // explicit instantiations of CUDA templates @@ -691,7 +740,8 @@ template
void mops::cuda::sparse_accumulation_of_products_vjp_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); template void mops::cuda::sparse_accumulation_of_products_vjp_vjp( @@ -706,5 +756,6 @@ template void mops::cuda::sparse_accumulation_of_products_vjp_vjp( Tensor C, Tensor indices_A, Tensor indices_B, - Tensor indices_output + Tensor indices_output, + void* cuda_stream ); diff --git a/mops/src/sasaw/capi.cpp b/mops/src/sasaw/capi.cpp index 5dfb933..57f5620 100644 --- a/mops/src/sasaw/capi.cpp +++ b/mops/src/sasaw/capi.cpp @@ -276,7 +276,8 @@ extern "C" int mops_cuda_sparse_accumulation_scatter_add_with_weights_f32( mops_tensor_1d_i32_t indices_W_1, mops_tensor_1d_i32_t indices_W_2, mops_tensor_1d_i32_t indices_output_1, - mops_tensor_1d_i32_t indices_output_2 + mops_tensor_1d_i32_t indices_output_2, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::sparse_accumulation_scatter_add_with_weights( @@ -292,7 +293,8 @@ extern "C" int mops_cuda_sparse_accumulation_scatter_add_with_weights_f32( {indices_W_1.data, {checked_cast(indices_W_1.shape[0])}}, {indices_W_2.data, {checked_cast(indices_W_2.shape[0])}}, {indices_output_1.data, {checked_cast(indices_output_1.shape[0])}}, - {indices_output_2.data, {checked_cast(indices_output_2.shape[0])}} + {indices_output_2.data, {checked_cast(indices_output_2.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -307,7 +309,8 @@ extern "C" int mops_cuda_sparse_accumulation_scatter_add_with_weights_f64( mops_tensor_1d_i32_t indices_W_1, mops_tensor_1d_i32_t indices_W_2, mops_tensor_1d_i32_t indices_output_1, - mops_tensor_1d_i32_t indices_output_2 + mops_tensor_1d_i32_t indices_output_2, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::sparse_accumulation_scatter_add_with_weights( @@ -323,7 +326,8 @@ extern "C" int mops_cuda_sparse_accumulation_scatter_add_with_weights_f64( {indices_W_1.data, {checked_cast(indices_W_1.shape[0])}}, {indices_W_2.data, {checked_cast(indices_W_2.shape[0])}}, {indices_output_1.data, {checked_cast(indices_output_1.shape[0])}}, - {indices_output_2.data, {checked_cast(indices_output_2.shape[0])}} + {indices_output_2.data, {checked_cast(indices_output_2.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -341,7 +345,8 @@ extern "C" int mops_cuda_sparse_accumulation_scatter_add_with_weights_vjp_f32( mops_tensor_1d_i32_t indices_W_1, mops_tensor_1d_i32_t indices_W_2, mops_tensor_1d_i32_t indices_output_1, - mops_tensor_1d_i32_t indices_output_2 + mops_tensor_1d_i32_t indices_output_2, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp( @@ -362,7 +367,8 @@ extern "C" int mops_cuda_sparse_accumulation_scatter_add_with_weights_vjp_f32( {indices_W_1.data, {checked_cast(indices_W_1.shape[0])}}, {indices_W_2.data, {checked_cast(indices_W_2.shape[0])}}, {indices_output_1.data, {checked_cast(indices_output_1.shape[0])}}, - {indices_output_2.data, {checked_cast(indices_output_2.shape[0])}} + {indices_output_2.data, {checked_cast(indices_output_2.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -380,7 +386,8 @@ extern "C" int mops_cuda_sparse_accumulation_scatter_add_with_weights_vjp_f64( mops_tensor_1d_i32_t indices_W_1, mops_tensor_1d_i32_t indices_W_2, mops_tensor_1d_i32_t indices_output_1, - mops_tensor_1d_i32_t indices_output_2 + mops_tensor_1d_i32_t indices_output_2, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN 
mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp( @@ -401,7 +408,8 @@ extern "C" int mops_cuda_sparse_accumulation_scatter_add_with_weights_vjp_f64( {indices_W_1.data, {checked_cast(indices_W_1.shape[0])}}, {indices_W_2.data, {checked_cast(indices_W_2.shape[0])}}, {indices_output_1.data, {checked_cast(indices_output_1.shape[0])}}, - {indices_output_2.data, {checked_cast(indices_output_2.shape[0])}} + {indices_output_2.data, {checked_cast(indices_output_2.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -423,7 +431,8 @@ extern "C" int mops_cuda_sparse_accumulation_scatter_add_with_weights_vjp_vjp_f3 mops_tensor_1d_i32_t indices_W_1, mops_tensor_1d_i32_t indices_W_2, mops_tensor_1d_i32_t indices_output_1, - mops_tensor_1d_i32_t indices_output_2 + mops_tensor_1d_i32_t indices_output_2, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp_vjp( @@ -457,7 +466,8 @@ extern "C" int mops_cuda_sparse_accumulation_scatter_add_with_weights_vjp_vjp_f3 {indices_W_1.data, {checked_cast(indices_W_1.shape[0])}}, {indices_W_2.data, {checked_cast(indices_W_2.shape[0])}}, {indices_output_1.data, {checked_cast(indices_output_1.shape[0])}}, - {indices_output_2.data, {checked_cast(indices_output_2.shape[0])}} + {indices_output_2.data, {checked_cast(indices_output_2.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } @@ -479,7 +489,8 @@ extern "C" int mops_cuda_sparse_accumulation_scatter_add_with_weights_vjp_vjp_f6 mops_tensor_1d_i32_t indices_W_1, mops_tensor_1d_i32_t indices_W_2, mops_tensor_1d_i32_t indices_output_1, - mops_tensor_1d_i32_t indices_output_2 + mops_tensor_1d_i32_t indices_output_2, + void* cuda_stream ) { MOPS_CATCH_EXCEPTIONS_BEGIN mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp_vjp( @@ -513,7 +524,8 @@ extern "C" int mops_cuda_sparse_accumulation_scatter_add_with_weights_vjp_vjp_f6 {indices_W_1.data, {checked_cast(indices_W_1.shape[0])}}, {indices_W_2.data, {checked_cast(indices_W_2.shape[0])}}, {indices_output_1.data, {checked_cast(indices_output_1.shape[0])}}, - {indices_output_2.data, {checked_cast(indices_output_2.shape[0])}} + {indices_output_2.data, {checked_cast(indices_output_2.shape[0])}}, + cuda_stream ); MOPS_CATCH_EXCEPTIONS_END } diff --git a/mops/src/sasaw/cuda.tpp b/mops/src/sasaw/cuda.tpp deleted file mode 100644 index c2f8558..0000000 --- a/mops/src/sasaw/cuda.tpp +++ /dev/null @@ -1,61 +0,0 @@ -#include - -#include "mops/sasaw.hpp" - -template -void mops::cuda::sparse_accumulation_scatter_add_with_weights( - Tensor, - Tensor, - Tensor, - Tensor, - Tensor, - Tensor, - Tensor, - Tensor, - Tensor, - Tensor -) { - throw std::runtime_error("CUDA implementation does not exist yet"); -} - -template -void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp( - Tensor, - Tensor , - Tensor, - Tensor, - Tensor , - Tensor , - Tensor , - Tensor , - Tensor, - Tensor, - Tensor, - Tensor, - Tensor -) { - throw std::runtime_error("CUDA implementation does not exist yet"); -} - -template -void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp_vjp( - Tensor /*grad_grad_output*/, - Tensor /*grad_A_2*/, - Tensor /*grad_B_2*/, - Tensor /*grad_W_2*/, - Tensor /*grad_grad_A*/, - Tensor /*grad_grad_B*/, - Tensor /*grad_grad_W*/, - Tensor /*grad_output*/, - Tensor /*A*/, - Tensor /*B*/, - Tensor /*C*/, - Tensor /*W*/, - Tensor /*indices_A*/, - Tensor /*indices_W_1*/, - Tensor /*indices_W_2*/, - Tensor /*indices_output_1*/, - Tensor /*indices_output_2*/ -) { - throw 
std::runtime_error("CUDA implementation does not exist yet"); -} diff --git a/mops/src/sasaw/sasaw.cpp b/mops/src/sasaw/sasaw.cpp index 9bc4a68..50525aa 100644 --- a/mops/src/sasaw/sasaw.cpp +++ b/mops/src/sasaw/sasaw.cpp @@ -99,18 +99,16 @@ template void mops::sparse_accumulation_scatter_add_with_weights_vjp_vjp Tensor indices_output_2 ); -#ifdef MOPS_CUDA_ENABLED -#include "cuda.tpp" -#else +#ifndef MOPS_CUDA_ENABLED template void mops::cuda:: - sparse_accumulation_scatter_add_with_weights(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) { + sparse_accumulation_scatter_add_with_weights(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, void*) { throw std::runtime_error("MOPS was not compiled with CUDA support"); } template void mops::cuda:: - sparse_accumulation_scatter_add_with_weights_vjp(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) { + sparse_accumulation_scatter_add_with_weights_vjp(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, void*) { throw std::runtime_error("MOPS was not compiled with CUDA support"); } @@ -132,13 +130,12 @@ void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp_vjp( Tensor /*indices_W_1*/, Tensor /*indices_W_2*/, Tensor /*indices_output_1*/, - Tensor /*indices_output_2*/ + Tensor /*indices_output_2*/, + void* /*cuda_stream*/ ) { throw std::runtime_error("MOPS was not compiled with CUDA support"); } -#endif - // explicit instantiations of CUDA templates template void mops::cuda::sparse_accumulation_scatter_add_with_weights( Tensor output, @@ -150,7 +147,8 @@ template void mops::cuda::sparse_accumulation_scatter_add_with_weights( Tensor indices_W_1, Tensor indices_W_2, Tensor indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream ); template void mops::cuda::sparse_accumulation_scatter_add_with_weights( @@ -163,7 +161,8 @@ template void mops::cuda::sparse_accumulation_scatter_add_with_weights( Tensor indices_W_1, Tensor indices_W_2, Tensor indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream ); template void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp( @@ -179,7 +178,8 @@ template void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp indices_W_1, Tensor indices_W_2, Tensor indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream ); template void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp( @@ -195,7 +195,8 @@ template void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp indices_W_1, Tensor indices_W_2, Tensor indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream ); template void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp_vjp( @@ -215,7 +216,8 @@ template void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp_vjp indices_W_1, Tensor indices_W_2, Tensor indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream ); template void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp_vjp( @@ -235,5 +237,8 @@ template void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp_vjp indices_W_1, Tensor indices_W_2, Tensor indices_output_1, - Tensor indices_output_2 + Tensor indices_output_2, + void* cuda_stream ); + +#endif \ No newline at end of file diff --git a/mops/src/sasaw/sasaw.cu b/mops/src/sasaw/sasaw.cu 
index 8421bac..06b94b1 100644 --- a/mops/src/sasaw/sasaw.cu +++ b/mops/src/sasaw/sasaw.cu @@ -1 +1,148 @@ -// todo: cuda device code +#include + +#include "mops/sasaw.hpp" + +using namespace mops; +using namespace mops::cuda; +using namespace std; + +template +void mops::cuda:: + sparse_accumulation_scatter_add_with_weights(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, void*) { + throw std::runtime_error("CUDA implementation does not exist yet"); +} + +template +void mops::cuda:: + sparse_accumulation_scatter_add_with_weights_vjp(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, void*) { + throw std::runtime_error("CUDA implementation does not exist yet"); +} + +template +void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp_vjp( + Tensor /*grad_grad_output*/, + Tensor /*grad_A_2*/, + Tensor /*grad_B_2*/, + Tensor /*grad_W_2*/, + Tensor /*grad_grad_A*/, + Tensor /*grad_grad_B*/, + Tensor /*grad_grad_W*/, + Tensor /*grad_output*/, + Tensor /*A*/, + Tensor /*B*/, + Tensor /*C*/, + Tensor /*W*/, + Tensor /*indices_A*/, + Tensor /*indices_W_1*/, + Tensor /*indices_W_2*/, + Tensor /*indices_output_1*/, + Tensor /*indices_output_2*/, + void* /*cuda_stream*/ +) { + throw std::runtime_error("CUDA implementation does not exist yet"); +} + +// explicit instantiations of CUDA templates +template void mops::cuda::sparse_accumulation_scatter_add_with_weights( + Tensor output, + Tensor A, + Tensor B, + Tensor C, + Tensor W, + Tensor indices_A, + Tensor indices_W_1, + Tensor indices_W_2, + Tensor indices_output_1, + Tensor indices_output_2, + void* cuda_stream +); + +template void mops::cuda::sparse_accumulation_scatter_add_with_weights( + Tensor output, + Tensor A, + Tensor B, + Tensor C, + Tensor W, + Tensor indices_A, + Tensor indices_W_1, + Tensor indices_W_2, + Tensor indices_output_1, + Tensor indices_output_2, + void* cuda_stream +); + +template void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp( + Tensor grad_A, + Tensor grad_B, + Tensor grad_W, + Tensor grad_output, + Tensor A, + Tensor B, + Tensor C, + Tensor W, + Tensor indices_A, + Tensor indices_W_1, + Tensor indices_W_2, + Tensor indices_output_1, + Tensor indices_output_2, + void* cuda_stream +); + +template void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp( + Tensor grad_A, + Tensor grad_B, + Tensor grad_W, + Tensor grad_output, + Tensor A, + Tensor B, + Tensor C, + Tensor W, + Tensor indices_A, + Tensor indices_W_1, + Tensor indices_W_2, + Tensor indices_output_1, + Tensor indices_output_2, + void* cuda_stream +); + +template void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp_vjp( + Tensor grad_grad_output, + Tensor grad_A_2, + Tensor grad_B_2, + Tensor grad_W_2, + Tensor grad_grad_A, + Tensor grad_grad_B, + Tensor grad_grad_W, + Tensor grad_output, + Tensor A, + Tensor B, + Tensor C, + Tensor W, + Tensor indices_A, + Tensor indices_W_1, + Tensor indices_W_2, + Tensor indices_output_1, + Tensor indices_output_2, + void* cuda_stream +); + +template void mops::cuda::sparse_accumulation_scatter_add_with_weights_vjp_vjp( + Tensor grad_grad_output, + Tensor grad_A_2, + Tensor grad_B_2, + Tensor grad_W_2, + Tensor grad_grad_A, + Tensor grad_grad_B, + Tensor grad_grad_W, + Tensor grad_output, + Tensor A, + Tensor B, + Tensor C, + Tensor W, + Tensor indices_A, + Tensor indices_W_1, + Tensor indices_W_2, + Tensor indices_output_1, + Tensor indices_output_2, + void* cuda_stream +); \ No newline at end
of file diff --git a/python/mops-torch/tests/opsa.py b/python/mops-torch/tests/opsa.py index c95b8bb..ee0ff1c 100644 --- a/python/mops-torch/tests/opsa.py +++ b/python/mops-torch/tests/opsa.py @@ -58,15 +58,12 @@ def test_opsa_grads(device): ).values assert torch.autograd.gradcheck( - mops.torch.outer_product_scatter_add, - (A, B, indices, output_size), + mops.torch.outer_product_scatter_add, (A, B, indices, output_size) ) - if device != "cuda": # not yet implemented - assert torch.autograd.gradgradcheck( - mops.torch.outer_product_scatter_add, - (A, B, indices, output_size), - ) + assert torch.autograd.gradgradcheck( + mops.torch.outer_product_scatter_add, (A, B, indices, output_size) + ) def test_opsa_ref():