diff --git a/mops-torch/CMakeLists.txt b/mops-torch/CMakeLists.txt
new file mode 100644
index 0000000..100e87c
--- /dev/null
+++ b/mops-torch/CMakeLists.txt
@@ -0,0 +1,115 @@
+cmake_minimum_required(VERSION 3.16)
+
+if (POLICY CMP0077)
+    # use variables to set OPTIONS
+    cmake_policy(SET CMP0077 NEW)
+endif()
+
+
+file(READ ${CMAKE_CURRENT_SOURCE_DIR}/VERSION MOPS_TORCH_VERSION)
+string(STRIP ${MOPS_TORCH_VERSION} MOPS_TORCH_VERSION)
+string(REGEX REPLACE "^([0-9]+)\\..*" "\\1" MOPS_TORCH_VERSION_MAJOR "${MOPS_TORCH_VERSION}")
+string(REGEX REPLACE "^[0-9]+\\.([0-9]+).*" "\\1" MOPS_TORCH_VERSION_MINOR "${MOPS_TORCH_VERSION}")
+string(REGEX REPLACE "^[0-9]+\\.[0-9]+\\.([0-9]+).*" "\\1" MOPS_TORCH_VERSION_PATCH "${MOPS_TORCH_VERSION}")
+
+project(mops_torch VERSION ${MOPS_TORCH_VERSION} LANGUAGES CXX)
+
+include(CheckLanguage)
+check_language(CUDA)
+
+if(CMAKE_CUDA_COMPILER)
+    enable_language(CUDA)
+    set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE BOOL "" FORCE)
+else()
+    message(STATUS "Could not find a CUDA compiler")
+endif()
+
+set(LIB_INSTALL_DIR "lib" CACHE PATH "Path relative to CMAKE_INSTALL_PREFIX where to install libraries")
+set(BIN_INSTALL_DIR "bin" CACHE PATH "Path relative to CMAKE_INSTALL_PREFIX where to install DLL/binaries")
+set(INCLUDE_INSTALL_DIR "include" CACHE PATH "Path relative to CMAKE_INSTALL_PREFIX where to install headers")
+
+if (${CMAKE_CURRENT_SOURCE_DIR} STREQUAL ${CMAKE_SOURCE_DIR})
+    set(MOPS_TORCH_MAIN_PROJECT ON)
+else()
+    set(MOPS_TORCH_MAIN_PROJECT OFF)
+endif()
+
+# Set a default build type if none was specified
+if (${MOPS_TORCH_MAIN_PROJECT})
+    if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "")
+        message(STATUS "Setting build type to 'relwithdebinfo' as none was specified.")
+        set(
+            CMAKE_BUILD_TYPE "relwithdebinfo"
+            CACHE STRING
+            "Choose the type of build, options are: none(CMAKE_CXX_FLAGS or CMAKE_C_FLAGS used) debug release relwithdebinfo minsizerel."
+            FORCE
+        )
+        set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug relwithdebinfo minsizerel none)
+    endif()
+endif()
+
+find_package(Torch 1.11 REQUIRED)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../mops mops)
+
+add_library(mops_torch SHARED
+    "src/register.cpp"
+    "src/opsa.cpp"
+
+    "include/mops/torch.hpp"
+    "include/mops/torch/opsa.hpp"
+)
+
+if(CMAKE_CUDA_COMPILER)
+    target_compile_definitions(mops_torch PUBLIC MOPS_CUDA_ENABLED)
+endif()
+
+target_compile_features(mops_torch PUBLIC cxx_std_17)
+target_link_libraries(mops_torch PRIVATE mops)
+target_link_libraries(mops_torch PUBLIC torch)
+
+# Create a header defining MOPS_TORCH_EXPORT for exported classes/functions
+set_target_properties(mops_torch PROPERTIES
+    # hide non-exported symbols by default
+    CXX_VISIBILITY_PRESET hidden
+)
+
+target_include_directories(mops_torch PUBLIC
+    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
+    $<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/include>
+    $<INSTALL_INTERFACE:include>
+)
+
+#------------------------------------------------------------------------------#
+# Installation configuration
+#------------------------------------------------------------------------------#
+include(CMakePackageConfigHelpers)
+write_basic_package_version_file(
+    mops_torch-config-version.cmake
+    VERSION ${MOPS_TORCH_VERSION}
+    COMPATIBILITY SameMinorVersion
+)
+
+install(TARGETS mops_torch
+    EXPORT mops_torch-targets
+    LIBRARY DESTINATION ${LIB_INSTALL_DIR}
+    ARCHIVE DESTINATION ${LIB_INSTALL_DIR}
+    RUNTIME DESTINATION ${BIN_INSTALL_DIR}
+)
+
+install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/ DESTINATION ${INCLUDE_INSTALL_DIR})
+install(DIRECTORY ${PROJECT_BINARY_DIR}/include/ DESTINATION ${INCLUDE_INSTALL_DIR})
+
+# Install files to find mops in CMake projects
+configure_file(
+    ${CMAKE_CURRENT_SOURCE_DIR}/cmake/mops_torch-config.in.cmake
+    ${CMAKE_CURRENT_BINARY_DIR}/mops_torch-config.cmake
+    @ONLY
+)
+install(EXPORT mops_torch-targets
+    DESTINATION ${LIB_INSTALL_DIR}/cmake/mops_torch
+)
+install(FILES
+    ${PROJECT_BINARY_DIR}/mops_torch-config-version.cmake
+    ${PROJECT_BINARY_DIR}/mops_torch-config.cmake
+    DESTINATION ${LIB_INSTALL_DIR}/cmake/mops_torch
+)
diff --git a/mops-torch/VERSION b/mops-torch/VERSION
new file mode 100644
index 0000000..6e8bf73
--- /dev/null
+++ b/mops-torch/VERSION
@@ -0,0 +1 @@
+0.1.0
diff --git a/mops-torch/cmake/mops_torch-config.in.cmake b/mops-torch/cmake/mops_torch-config.in.cmake
new file mode 100644
index 0000000..d01eac5
--- /dev/null
+++ b/mops-torch/cmake/mops_torch-config.in.cmake
@@ -0,0 +1 @@
+include(${CMAKE_CURRENT_LIST_DIR}/mops_torch-targets.cmake)
diff --git a/mops-torch/include/mops/torch.hpp b/mops-torch/include/mops/torch.hpp
new file mode 100644
index 0000000..0525b7c
--- /dev/null
+++ b/mops-torch/include/mops/torch.hpp
@@ -0,0 +1,6 @@
+#ifndef MOPS_TORCH_H
+#define MOPS_TORCH_H
+
+#include "torch/opsa.hpp"  // IWYU pragma: export
+
+#endif
diff --git a/mops-torch/include/mops/torch/opsa.hpp b/mops-torch/include/mops/torch/opsa.hpp
new file mode 100644
index 0000000..fb1ab93
--- /dev/null
+++ b/mops-torch/include/mops/torch/opsa.hpp
@@ -0,0 +1,36 @@
+#ifndef MOPS_TORCH_OPSA_H
+#define MOPS_TORCH_OPSA_H
+
+#include <torch/script.h>
+
+#include <mops.hpp>
+
+namespace mops_torch {
+
+/// TODO
+torch::Tensor outer_product_scatter_add(
+    torch::Tensor A,
+    torch::Tensor B,
+    torch::Tensor indices_output,
+    int64_t output_size
+);
+
+class OuterProductScatterAdd: public torch::autograd::Function<OuterProductScatterAdd> {
+public:
+    static std::vector<torch::Tensor> forward(
+        torch::autograd::AutogradContext *ctx,
+        torch::Tensor A,
+        torch::Tensor B,
+        torch::Tensor indices_output,
+        int64_t output_size
+    );
+
+    static std::vector<torch::Tensor> backward(
+        torch::autograd::AutogradContext *ctx,
+        std::vector<torch::Tensor> grad_outputs
+    );
+};
+
+}
+
+#endif
diff --git a/mops-torch/src/opsa.cpp b/mops-torch/src/opsa.cpp
new file mode 100644
index 0000000..817c270
--- /dev/null
+++ b/mops-torch/src/opsa.cpp
@@ -0,0 +1,142 @@
+#include "mops/torch/opsa.hpp"
+
+using namespace mops_torch;
+
+
+torch::Tensor mops_torch::outer_product_scatter_add(
+    torch::Tensor A,
+    torch::Tensor B,
+    torch::Tensor indices_output,
+    int64_t output_size
+) {
+    return OuterProductScatterAdd::apply(A, B, indices_output, output_size)[0];
+}
+
+template <typename scalar_t>
+static mops::Tensor<scalar_t, 1> torch_to_mops_1d(torch::Tensor tensor) {
+    assert(tensor.sizes().size() == 1);
+    return {
+        tensor.data_ptr<scalar_t>(),
+        {static_cast<size_t>(tensor.size(0))},
+    };
+}
+
+template <typename scalar_t>
+static mops::Tensor<scalar_t, 2> torch_to_mops_2d(torch::Tensor tensor) {
+    assert(tensor.sizes().size() == 2);
+    return {
+        tensor.data_ptr<scalar_t>(),
+        {static_cast<size_t>(tensor.size(0)), static_cast<size_t>(tensor.size(1))},
+    };
+}
+
+std::vector<torch::Tensor> OuterProductScatterAdd::forward(
+    torch::autograd::AutogradContext *ctx,
+    torch::Tensor A,
+    torch::Tensor B,
+    torch::Tensor indices_output,
+    int64_t output_size
+) {
+    if (A.sizes().size() != 2 || B.sizes().size() != 2) {
+        C10_THROW_ERROR(ValueError, "`A` and `B` must be 2-D tensors");
+    }
+
+    if (indices_output.sizes().size() != 1) {
+        C10_THROW_ERROR(ValueError, "`indices_output` must be a 1-D tensor");
+    }
+
+    if (indices_output.scalar_type() != torch::kInt32) {
+        C10_THROW_ERROR(ValueError, "`indices_output` must be a tensor of 32-bit integers");
+    }
+
+    if (A.device() != B.device() || A.device() != indices_output.device()) {
+        C10_THROW_ERROR(ValueError,
+            "all tensors must be on the same device, got " + A.device().str() +
+            ", " + B.device().str() + ", and " + indices_output.device().str()
+        );
+    }
+
+    if (A.scalar_type() != B.scalar_type()) {
+        C10_THROW_ERROR(ValueError,
+            std::string("`A` and `B` must have the same dtype, got ") +
+            torch::toString(A.scalar_type()) + " and " + torch::toString(B.scalar_type())
+        );
+    }
+
+    torch::Tensor output;
+    if (A.device().is_cpu()) {
+        output = torch::zeros({output_size, A.size(1), B.size(1)},
+            torch::TensorOptions().dtype(A.scalar_type()).device(A.device())
+        );
+
+        assert(output.is_contiguous());
+
+        AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "outer_product_scatter_add", [&](){
+            mops::outer_product_scatter_add(
+                torch_to_mops_2d<scalar_t>(output.reshape({-1, output.size(1) * output.size(2)})),
+                torch_to_mops_2d<scalar_t>(A),
+                torch_to_mops_2d<scalar_t>(B),
+                torch_to_mops_1d<int32_t>(indices_output)
+            );
+        });
+    } else {
+        C10_THROW_ERROR(ValueError,
+            "outer_product_scatter_add is not implemented for device " + A.device().str()
+        );
+    }
+
+    if (A.requires_grad() || B.requires_grad()) {
+        ctx->save_for_backward({A, B, indices_output});
+    }
+
+    return {output};
+}
+
+std::vector<torch::Tensor> OuterProductScatterAdd::backward(
+    torch::autograd::AutogradContext *ctx,
+    std::vector<torch::Tensor> grad_outputs
+) {
+    auto saved_variables = ctx->get_saved_variables();
+    auto A = saved_variables[0];
+    auto B = saved_variables[1];
+    auto indices_output = saved_variables[2];
+
+    auto grad_output = grad_outputs[0];
+    if (!grad_output.is_contiguous()) {
+        throw std::runtime_error("expected contiguous grad_output");
+    }
+
+    auto grad_A = torch::Tensor();
+    auto grad_B = torch::Tensor();
+
+    if (A.device().is_cpu()) {
+        AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "outer_product_scatter_add_vjp", [&](){
+            auto mops_grad_A = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}};
+            if (A.requires_grad()) {
+                grad_A = torch::zeros_like(A);
+                mops_grad_A = torch_to_mops_2d<scalar_t>(grad_A);
+            }
+
+            auto mops_grad_B = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}};
+            if (B.requires_grad()) {
+                grad_B = torch::zeros_like(B);
+                mops_grad_B = torch_to_mops_2d<scalar_t>(grad_B);
+            }
+
+            mops::outer_product_scatter_add_vjp(
+                mops_grad_A,
+                mops_grad_B,
+                torch_to_mops_2d<scalar_t>(grad_output.reshape({-1, grad_output.size(1) * grad_output.size(2)})),
+                torch_to_mops_2d<scalar_t>(A),
+                torch_to_mops_2d<scalar_t>(B),
+                torch_to_mops_1d<int32_t>(indices_output)
+            );
+        });
+    } else {
+        C10_THROW_ERROR(ValueError,
+            "outer_product_scatter_add is not implemented for device " + A.device().str()
+        );
+    }
+
+    return {grad_A, grad_B, torch::Tensor(), torch::Tensor()};
+}
diff --git a/mops-torch/src/register.cpp b/mops-torch/src/register.cpp
new file mode 100644
index 0000000..5d262be
--- /dev/null
+++ b/mops-torch/src/register.cpp
@@ -0,0 +1,7 @@
+#include <torch/script.h>
+
+#include "mops/torch/opsa.hpp"
+
+TORCH_LIBRARY(mops, m) {
+    m.def("outer_product_scatter_add", mops_torch::outer_product_scatter_add);
+}
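For context, once `libmops_torch` is built, the operator registered by `TORCH_LIBRARY(mops, m)` above should be reachable from Python through `torch.ops`. The sketch below is illustrative only: the shared-library path is a placeholder (it depends on where the library is built or installed), and the shapes are arbitrary.

```python
import torch

# placeholder path: point this at wherever libmops_torch ends up after building
torch.ops.load_library("path/to/libmops_torch.so")

A = torch.rand(100, 13, dtype=torch.float64, requires_grad=True)
B = torch.rand(100, 7, dtype=torch.float64, requires_grad=True)
indices = torch.randint(20, (100,), dtype=torch.int32)

# output has shape (output_size, A.shape[1], B.shape[1])
output = torch.ops.mops.outer_product_scatter_add(A, B, indices, 20)

# gradients flow through OuterProductScatterAdd::backward
output.backward(torch.ones_like(output))
print(A.grad.shape, B.grad.shape)
```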
diff --git a/mops/include/mops/opsa.h b/mops/include/mops/opsa.h
index 4eda5dd..91bc165 100644
--- a/mops/include/mops/opsa.h
+++ b/mops/include/mops/opsa.h
@@ -26,6 +26,27 @@ int MOPS_EXPORT mops_outer_product_scatter_add_f64(
     mops_tensor_1d_i32_t indices_output
 );
 
+/// CPU version of mops::outer_product_scatter_add_vjp for 32-bit floats
+int MOPS_EXPORT mops_outer_product_scatter_add_vjp_f32(
+    mops_tensor_2d_f32_t grad_A,
+    mops_tensor_2d_f32_t grad_B,
+    mops_tensor_2d_f32_t grad_output,
+    mops_tensor_2d_f32_t A,
+    mops_tensor_2d_f32_t B,
+    mops_tensor_1d_i32_t indices_output
+);
+
+
+/// CPU version of mops::outer_product_scatter_add_vjp for 64-bit floats
+int MOPS_EXPORT mops_outer_product_scatter_add_vjp_f64(
+    mops_tensor_2d_f64_t grad_A,
+    mops_tensor_2d_f64_t grad_B,
+    mops_tensor_2d_f64_t grad_output,
+    mops_tensor_2d_f64_t A,
+    mops_tensor_2d_f64_t B,
+    mops_tensor_1d_i32_t indices_output
+);
+
 
 /// CUDA version of mops::outer_product_scatter_add for 32-bit floats
 int MOPS_EXPORT mops_cuda_outer_product_scatter_add_f32(
diff --git a/mops/include/mops/opsa.hpp b/mops/include/mops/opsa.hpp
index d17f6a0..46e6664 100644
--- a/mops/include/mops/opsa.hpp
+++ b/mops/include/mops/opsa.hpp
@@ -32,6 +32,44 @@ namespace mops {
         Tensor<int32_t, 1> indices_output
     );
 
+    /// Vector-Jacobian product for `outer_product_scatter_add` (i.e. backward
+    /// propagation of gradients)
+    ///
+    /// `grad_A` and `grad_B` are the outputs of this function, and should have
+    /// the same shape as `A` and `B`. If you don't need one of these gradients,
+    /// set the corresponding `.data` pointer to `NULL`.
+    ///
+    /// `grad_output` should have the same shape as `output` in
+    /// `outer_product_scatter_add`.
+    template <typename scalar_t>
+    void MOPS_EXPORT outer_product_scatter_add_vjp(
+        Tensor<scalar_t, 2> grad_A,
+        Tensor<scalar_t, 2> grad_B,
+        Tensor<scalar_t, 2> grad_output,
+        Tensor<scalar_t, 2> A,
+        Tensor<scalar_t, 2> B,
+        Tensor<int32_t, 1> indices_output
+    );
+
+    // these templates will be precompiled and provided in the mops library
+    extern template void outer_product_scatter_add_vjp<float>(
+        Tensor<float, 2> grad_A,
+        Tensor<float, 2> grad_B,
+        Tensor<float, 2> grad_output,
+        Tensor<float, 2> A,
+        Tensor<float, 2> B,
+        Tensor<int32_t, 1> indexes
+    );
+
+    extern template void outer_product_scatter_add_vjp<double>(
+        Tensor<double, 2> grad_A,
+        Tensor<double, 2> grad_B,
+        Tensor<double, 2> grad_output,
+        Tensor<double, 2> A,
+        Tensor<double, 2> B,
+        Tensor<int32_t, 1> indexes
+    );
+
     namespace cuda {
         /// CUDA version of mops::outer_product_scatter_add
         template <typename scalar_t>
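As a reading aid for this declaration (not part of the header itself): writing the forward operation as `output[indices_output[i], :, :] += A[i, :] ⊗ B[i, :]` and `L` for a scalar loss, the gradients accumulated by `outer_product_scatter_add_vjp` are

```latex
\frac{\partial L}{\partial A_{ik}} = \sum_{l} \frac{\partial L}{\partial O_{P_i k l}} \, B_{il},
\qquad
\frac{\partial L}{\partial B_{il}} = \sum_{k} \frac{\partial L}{\partial O_{P_i k l}} \, A_{ik},
```

where \(O\) is the forward output and \(P_i = \texttt{indices\_output}[i]\).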
diff --git a/mops/src/opsa/capi.cpp b/mops/src/opsa/capi.cpp
index 5b611fc..4b83d8f 100644
--- a/mops/src/opsa/capi.cpp
+++ b/mops/src/opsa/capi.cpp
@@ -46,6 +46,47 @@ extern "C" int mops_outer_product_scatter_add_f64(
     );
 }
 
+extern "C" int mops_outer_product_scatter_add_vjp_f32(
+    mops_tensor_2d_f32_t grad_A,
+    mops_tensor_2d_f32_t grad_B,
+    mops_tensor_2d_f32_t grad_output,
+    mops_tensor_2d_f32_t A,
+    mops_tensor_2d_f32_t B,
+    mops_tensor_1d_i32_t indices_output
+) {
+    MOPS_CATCH_EXCEPTIONS(
+        mops::outer_product_scatter_add_vjp<float>(
+            {grad_A.data, {checked_cast(grad_A.shape[0]), checked_cast(grad_A.shape[1])}},
+            {grad_B.data, {checked_cast(grad_B.shape[0]), checked_cast(grad_B.shape[1])}},
+            {grad_output.data, {checked_cast(grad_output.shape[0]), checked_cast(grad_output.shape[1])}},
+            {A.data, {checked_cast(A.shape[0]), checked_cast(A.shape[1])}},
+            {B.data, {checked_cast(B.shape[0]), checked_cast(B.shape[1])}},
+            {indices_output.data, {checked_cast(indices_output.shape[0])}}
+        );
+    );
+}
+
+
+extern "C" int mops_outer_product_scatter_add_vjp_f64(
+    mops_tensor_2d_f64_t grad_A,
+    mops_tensor_2d_f64_t grad_B,
+    mops_tensor_2d_f64_t grad_output,
+    mops_tensor_2d_f64_t A,
+    mops_tensor_2d_f64_t B,
+    mops_tensor_1d_i32_t indices_output
+) {
+    MOPS_CATCH_EXCEPTIONS(
+        mops::outer_product_scatter_add_vjp<double>(
+            {grad_A.data, {checked_cast(grad_A.shape[0]), checked_cast(grad_A.shape[1])}},
+            {grad_B.data, {checked_cast(grad_B.shape[0]), checked_cast(grad_B.shape[1])}},
+            {grad_output.data, {checked_cast(grad_output.shape[0]), checked_cast(grad_output.shape[1])}},
+            {A.data, {checked_cast(A.shape[0]), checked_cast(A.shape[1])}},
+            {B.data, {checked_cast(B.shape[0]), checked_cast(B.shape[1])}},
+            {indices_output.data, {checked_cast(indices_output.shape[0])}}
+        );
+    );
+}
+
 
 extern "C" int mops_cuda_outer_product_scatter_add_f32(
     mops_tensor_2d_f32_t output,
diff --git a/mops/src/opsa/cpu.tpp b/mops/src/opsa/cpu.tpp
index 780c0e6..6af19e2 100644
--- a/mops/src/opsa/cpu.tpp
+++ b/mops/src/opsa/cpu.tpp
@@ -5,9 +5,10 @@
 
 #include "mops/opsa.hpp"
 
+using namespace mops;
+
 template <typename scalar_t>
-void mops::outer_product_scatter_add(
-    Tensor<scalar_t, 2> output,
+static void check_inputs_shape(
     Tensor<scalar_t, 2> A,
     Tensor<scalar_t, 2> B,
     Tensor<int32_t, 1> indices_output
@@ -27,6 +28,16 @@ void mops::outer_product_scatter_add(
             " and " + std::to_string(A.shape[0])
         );
     }
+}
+
+template <typename scalar_t>
+void mops::outer_product_scatter_add(
+    Tensor<scalar_t, 2> output,
+    Tensor<scalar_t, 2> A,
+    Tensor<scalar_t, 2> B,
+    Tensor<int32_t, 1> indices_output
+) {
+    check_inputs_shape(A, B, indices_output);
 
     if (A.shape[1] * B.shape[1] != output.shape[1]) {
         throw std::runtime_error(
@@ -52,44 +63,66 @@
     }
 }
 
+template <typename scalar_t>
+void mops::outer_product_scatter_add_vjp(
+    Tensor<scalar_t, 2> grad_A,
+    Tensor<scalar_t, 2> grad_B,
+    Tensor<scalar_t, 2> grad_output,
+    Tensor<scalar_t, 2> A,
+    Tensor<scalar_t, 2> B,
+    Tensor<int32_t, 1> indices_output
+) {
+    check_inputs_shape(A, B, indices_output);
+
+    if (A.shape[1] * B.shape[1] != grad_output.shape[1]) {
+        throw std::runtime_error(
+            "`grad_output` tensor must have " + std::to_string(A.shape[1] * B.shape[1]) +
+            " elements in the second dimension, got " + std::to_string(grad_output.shape[1])
+        );
+    }
+
+    if (grad_A.data != nullptr) {
+        if (A.shape[0] != grad_A.shape[0] || A.shape[1] != grad_A.shape[1]) {
+            throw std::runtime_error(
+                "A and grad_A tensors must have the same shape"
+            );
+        }
 
-// if (A.requires_grad()) {
-//     grad_a = torch::zeros_like(A);
-//     scalar_t* grad_a_ptr = grad_a.data_ptr<scalar_t>();
-
-//     #pragma omp parallel for
-//     for (long idx_out = 0; idx_out < out_dim; idx_out++) {
-//         long idx_in = first_occurrences_ptr[idx_out];
-//         if (idx_in < 0) continue;
-//         while (scatter_indices_ptr[idx_in] == idx_out) {
-//             for (long idx_a = 0; idx_a < size_a; idx_a++) {
-//                 for (long idx_b = 0; idx_b < size_b; idx_b++) {
-//                     grad_a_ptr[size_a*idx_in+idx_a] += grad_output_ptr[size_a*size_b*idx_out+size_b*idx_a+idx_b] * tensor_b_ptr[size_b*idx_in+idx_b];
-//                 }
-//             }
-//             idx_in++;
-//             if (idx_in == size_scatter) break;
-//         }
-//     }
-// }
+        for (size_t i=0; i<A.shape[0]; i++) {
+            auto i_output = indices_output.data[i];
+            for (size_t a_i=0; a_i<A.shape[1]; a_i++) {
+                for (size_t b_i=0; b_i<B.shape[1]; b_i++) {
+                    grad_A.data[A.shape[1] * i + a_i] +=
+                        grad_output.data[grad_output.shape[1] * i_output + B.shape[1] * a_i + b_i]
+                        * B.data[B.shape[1] * i + b_i];
+                }
+            }
+        }
+    }
 
-// if (B.requires_grad()) {
-//     grad_b = torch::zeros_like(B);
-//     scalar_t* grad_b_ptr = grad_b.data_ptr<scalar_t>();
+
+    if (grad_B.data != nullptr) {
+        if (B.shape[0] != grad_B.shape[0] || B.shape[1] != grad_B.shape[1]) {
+            throw std::runtime_error(
+                "B and grad_B tensors must have the same shape"
+            );
+        }
 
-//     #pragma omp parallel for
-//     for (long idx_out = 0; idx_out < out_dim; idx_out++) {
-//         long idx_in = first_occurrences_ptr[idx_out];
-//         if (idx_in < 0) continue;
-//         while (scatter_indices_ptr[idx_in] == idx_out) {
-//             for (long idx_a = 0; idx_a < size_a; idx_a++) {
-//                 for (long idx_b = 0; idx_b < size_b; idx_b++) {
-//                     grad_b_ptr[size_b*idx_in+idx_b] += grad_output_ptr[size_a*size_b*idx_out+size_b*idx_a+idx_b] * tensor_a_ptr[size_a*idx_in+idx_a];
-//                 }
-//             }
-//             idx_in++;
-//             if (idx_in == size_scatter) break;
-//         }
-//     }
-// }
+        for (size_t i=0; i<B.shape[0]; i++) {
+            auto i_output = indices_output.data[i];
+            for (size_t a_i=0; a_i<A.shape[1]; a_i++) {
+                for (size_t b_i=0; b_i<B.shape[1]; b_i++) {
+                    grad_B.data[B.shape[1] * i + b_i] +=
+                        grad_output.data[grad_output.shape[1] * i_output + B.shape[1] * a_i + b_i]
+                        * A.data[A.shape[1] * i + a_i];
+                }
+            }
+        }
+    }
+}
 
diff --git a/mops/src/opsa/opsa.cpp b/mops/src/opsa/opsa.cpp
--- a/mops/src/opsa/opsa.cpp
+++ b/mops/src/opsa/opsa.cpp
@@ ... @@ template void mops::outer_product_scatter_add<double>(
     Tensor<int32_t, 1> indices_output
 );
 
+template void mops::outer_product_scatter_add_vjp<float>(
+    Tensor<float, 2> grad_tensor_a,
+    Tensor<float, 2> grad_tensor_b,
+    Tensor<float, 2> grad_output,
+    Tensor<float, 2> tensor_a,
+    Tensor<float, 2> tensor_b,
+    Tensor<int32_t, 1> indexes
+);
+
+template void mops::outer_product_scatter_add_vjp<double>(
+    Tensor<double, 2> grad_tensor_a,
+    Tensor<double, 2> grad_tensor_b,
+    Tensor<double, 2> grad_output,
+    Tensor<double, 2> tensor_a,
+    Tensor<double, 2> tensor_b,
+    Tensor<int32_t, 1> indexes
+);
+
 #ifdef MOPS_CUDA_ENABLED
     #include "cuda.tpp"
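A plain NumPy reference for the same quantities is convenient when checking these loops; the helper below is a sketch for testing purposes only (its name is illustrative, not part of the library):

```python
import numpy as np


def reference_opsa_vjp(grad_output, A, B, indices):
    # grad_output: (output_size, A.shape[1], B.shape[1]); indices: (n_samples,)
    gathered = grad_output[indices]  # (n_samples, A.shape[1], B.shape[1])
    grad_A = np.einsum("ikl,il->ik", gathered, B)
    grad_B = np.einsum("ikl,ik->il", gathered, A)
    return grad_A, grad_B
```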
diff --git a/python/mops/src/mops/__init__.py b/python/mops/src/mops/__init__.py
index e48f08a..9ed6ac9 100644
--- a/python/mops/src/mops/__init__.py
+++ b/python/mops/src/mops/__init__.py
@@ -1,2 +1,2 @@
-from .opsa import outer_product_scatter_add  # noqa
+from .opsa import outer_product_scatter_add, outer_product_scatter_add_vjp  # noqa
 from .sap import sparse_accumulation_of_products  # noqa
diff --git a/python/mops/src/mops/_c_api.py b/python/mops/src/mops/_c_api.py
index c791ca3..536ff21 100644
--- a/python/mops/src/mops/_c_api.py
+++ b/python/mops/src/mops/_c_api.py
@@ -59,6 +59,26 @@ def setup_functions(lib):
     ]
     lib.mops_outer_product_scatter_add_f64.restype = _check_status
 
+    lib.mops_outer_product_scatter_add_vjp_f32.argtypes = [
+        mops_tensor_2d_f32_t,
+        mops_tensor_2d_f32_t,
+        mops_tensor_2d_f32_t,
+        mops_tensor_2d_f32_t,
+        mops_tensor_2d_f32_t,
+        mops_tensor_1d_i32_t,
+    ]
+    lib.mops_outer_product_scatter_add_vjp_f32.restype = _check_status
+
+    lib.mops_outer_product_scatter_add_vjp_f64.argtypes = [
+        mops_tensor_2d_f64_t,
+        mops_tensor_2d_f64_t,
+        mops_tensor_2d_f64_t,
+        mops_tensor_2d_f64_t,
+        mops_tensor_2d_f64_t,
+        mops_tensor_1d_i32_t,
+    ]
+    lib.mops_outer_product_scatter_add_vjp_f64.restype = _check_status
+
     lib.mops_cuda_outer_product_scatter_add_f32.argtypes = [
         mops_tensor_2d_f32_t,
         mops_tensor_2d_f32_t,
diff --git a/python/mops/src/mops/opsa.py b/python/mops/src/mops/opsa.py
index 33b41c8..35b5dfc 100644
--- a/python/mops/src/mops/opsa.py
+++ b/python/mops/src/mops/opsa.py
@@ -2,7 +2,7 @@
 
 from ._c_lib import _get_library
 from .checks import _check_opsa
-from .utils import numpy_to_mops_tensor
+from .utils import null_mops_tensor_like, numpy_to_mops_tensor
 
 
 def outer_product_scatter_add(A, B, indices_output, output_size):
@@ -49,3 +49,78 @@ def outer_product_scatter_add(A, B, indices_output, output_size):
     )
 
     return output.reshape(-1, A.shape[1], B.shape[1])
+
+
+def outer_product_scatter_add_vjp(
+    grad_output,
+    A,
+    B,
+    P,
+    compute_grad_A=False,
+    compute_grad_B=False,
+):
+    grad_output = np.ascontiguousarray(grad_output)
+    A = np.ascontiguousarray(A)
+    B = np.ascontiguousarray(B)
+    indices = np.ascontiguousarray(P)
+
+    if A.dtype != B.dtype or A.dtype != grad_output.dtype:
+        raise TypeError("A, B and grad_output must have the same dtype")
+
+    if len(A.shape) != 2 or len(B.shape) != 2:
+        raise TypeError("A and B must be 2-dimensional arrays")
+
+    if len(grad_output.shape) != 3:
+        raise TypeError("grad_output must be a 3-dimensional array")
+
+    if not np.can_cast(indices, np.int32, "same_kind"):
+        raise TypeError("indices must be an array of integers")
+
+    indices = indices.astype(np.int32)
+
+    if len(indices.shape) != 1:
+        raise TypeError("indices must be a 1-dimensional array")
+
+    if A.shape[0] != B.shape[0] or A.shape[0] != indices.shape[0]:
+        raise TypeError(
+            "A, B and indices must have the same number of elements on the "
+            "first dimension"
+        )
+
+    if compute_grad_A:
+        grad_A = np.zeros_like(A)
+        mops_grad_A = numpy_to_mops_tensor(grad_A)
+    else:
+        grad_A = None
+        mops_grad_A = null_mops_tensor_like(A)
+
+    if compute_grad_B:
+        grad_B = np.zeros_like(B)
+        mops_grad_B = numpy_to_mops_tensor(grad_B)
+    else:
+        grad_B = None
+        mops_grad_B = null_mops_tensor_like(B)
+
+    lib = _get_library()
+
+    if A.dtype == np.float32:
+        function = lib.mops_outer_product_scatter_add_vjp_f32
+    elif A.dtype == np.float64:
+        function = lib.mops_outer_product_scatter_add_vjp_f64
+    else:
+        raise TypeError(
+            "Unsupported dtype detected. Only float32 and float64 are supported"
+        )
+
+    function(
+        mops_grad_A,
+        mops_grad_B,
+        numpy_to_mops_tensor(
+            grad_output.reshape(-1, grad_output.shape[1] * grad_output.shape[2])
+        ),
+        numpy_to_mops_tensor(A),
+        numpy_to_mops_tensor(B),
+        numpy_to_mops_tensor(indices),
+    )
+
+    return grad_A, grad_B
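A minimal usage sketch of this new entry point (shapes and values are arbitrary, and it assumes the package is installed with these changes):

```python
import numpy as np

import mops

A = np.random.rand(100, 13)
B = np.random.rand(100, 7)
indices = np.sort(np.random.randint(10, size=100)).astype(np.int32)

output = mops.outer_product_scatter_add(A, B, indices, 10)  # (10, 13, 7)

grad_A, grad_B = mops.outer_product_scatter_add_vjp(
    np.ones_like(output),
    A,
    B,
    indices,
    compute_grad_A=True,
    compute_grad_B=True,
)

# with an all-ones grad_output, every column of grad_A is the row-sum of B
assert np.allclose(grad_A[:, 0], B.sum(axis=1))
```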
diff --git a/python/mops/src/mops/utils.py b/python/mops/src/mops/utils.py
index f536211..4924db8 100644
--- a/python/mops/src/mops/utils.py
+++ b/python/mops/src/mops/utils.py
@@ -52,3 +52,34 @@ def numpy_to_mops_tensor(array):
             raise TypeError("we can only convert 1D arrays of int32")
     else:
         raise TypeError("we can only arrays of int32, float32 or float64")
+
+
+def null_mops_tensor_like(array):
+    if array.dtype == np.float32:
+        if len(array.shape) == 2:
+            tensor = mops_tensor_2d_f32_t()
+            tensor.data = None
+            tensor.shape[0] = 0
+            tensor.shape[1] = 0
+            return tensor
+        else:
+            raise TypeError("we can only convert 2D arrays of float32")
+    elif array.dtype == np.float64:
+        if len(array.shape) == 2:
+            tensor = mops_tensor_2d_f64_t()
+            tensor.data = None
+            tensor.shape[0] = 0
+            tensor.shape[1] = 0
+            return tensor
+        else:
+            raise TypeError("we can only convert 2D arrays of float64")
+    elif array.dtype == np.int32:
+        if len(array.shape) == 1:
+            tensor = mops_tensor_1d_i32_t()
+            tensor.data = array.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
+            tensor.shape[0] = array.shape[0]
+            return tensor
+        else:
+            raise TypeError("we can only convert 1D arrays of int32")
+    else:
+        raise TypeError("we can only convert arrays of int32, float32 or float64")
diff --git a/python/mops/tests/opsa.py b/python/mops/tests/opsa.py
index 58afac8..77ae942 100644
--- a/python/mops/tests/opsa.py
+++ b/python/mops/tests/opsa.py
@@ -30,6 +30,29 @@ def test_opsa_no_neighbors():
     actual = mops.outer_product_scatter_add(A, B, indices, np.max(indices) + 1)
     assert np.allclose(reference, actual)
 
+    # just checking that the vjp runs
+    grad_A, grad_B = mops.outer_product_scatter_add_vjp(
+        np.ones_like(actual),
+        A,
+        B,
+        indices,
+        compute_grad_A=True,
+    )
+
+    assert grad_A.shape == A.shape
+    assert grad_B is None
+
+    grad_A, grad_B = mops.outer_product_scatter_add_vjp(
+        np.ones_like(actual),
+        A,
+        B,
+        indices,
+        compute_grad_B=True,
+    )
+
+    assert grad_A is None
+    assert grad_B.shape == B.shape
+
 
 def test_opsa_wrong_type():
     message = (
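The test above only checks that the VJP runs. If numerical verification is wanted later, a central finite-difference comparison along these lines could be used (the helper below is hypothetical and not part of the test suite):

```python
import numpy as np

import mops


def finite_difference_grad_A(A, B, indices, output_size, grad_output, eps=1e-6):
    # d/dA of sum(grad_output * outer_product_scatter_add(A, B, ...)), entry by entry
    grad = np.zeros_like(A)
    for i in range(A.shape[0]):
        for k in range(A.shape[1]):
            A_plus, A_minus = A.copy(), A.copy()
            A_plus[i, k] += eps
            A_minus[i, k] -= eps
            f_plus = np.sum(grad_output * mops.outer_product_scatter_add(A_plus, B, indices, output_size))
            f_minus = np.sum(grad_output * mops.outer_product_scatter_add(A_minus, B, indices, output_size))
            grad[i, k] = (f_plus - f_minus) / (2 * eps)
    return grad


A = np.random.rand(20, 5)
B = np.random.rand(20, 3)
indices = np.sort(np.random.randint(4, size=20)).astype(np.int32)
grad_output = np.random.rand(4, 5, 3)

grad_A, _ = mops.outer_product_scatter_add_vjp(grad_output, A, B, indices, compute_grad_A=True)
assert np.allclose(grad_A, finite_difference_grad_A(A, B, indices, 4, grad_output), atol=1e-6)
```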
diff --git a/tox.ini b/tox.ini
index bee3f18..faac931 100644
--- a/tox.ini
+++ b/tox.ini
@@ -5,7 +5,7 @@ min_version = 4.0
 envlist =
     python-tests
     cxx-tests
-    ; torch-cxx-tests
+    torch-cxx-tests
 
 [testenv:python-tests]
 passenv = *
@@ -30,6 +30,23 @@ commands =
     cmake --build {envdir}/build --config Debug
     ctest --test-dir {envdir}/build --build-config Debug --output-on-failure
 
+
+[testenv:torch-cxx-tests]
+package = skip
+passenv = *
+deps =
+    cmake >= 3.20
+    torch
+
+cmake-options =
+    -DCMAKE_BUILD_TYPE=Debug \
+    -DCMAKE_PREFIX_PATH={env_site_packages_dir}/torch/
+
+commands =
+    cmake {toxinidir}/mops-torch -B {envdir}/build {[testenv:torch-cxx-tests]cmake-options}
+    cmake --build {envdir}/build --config Debug
+
+
 [testenv:build-python]
 # this environement makes sure one can build sdist and wheels for Python
 deps =
@@ -47,19 +64,3 @@ commands =
     python -m build . --outdir dist
     twine check dist/*.tar.gz
     twine check dist/*.whl
-
-
-; [testenv:torch-cxx-tests]
-; package = skip
-; passenv = *
-; deps =
-;     cmake >= 3.20
-;     torch
-
-; cmake-options =
-;     -DCMAKE_BUILD_TYPE=Debug \
-;     -DCMAKE_PREFIX_PATH={env_site_packages_dir}/torch/
-
-; commands =
-;     cmake {toxinidir}/mops-torch -B {envdir}/build {[testenv:torch-cxx-tests]cmake-options}
-;     cmake --build {envdir}/build --config Debug