Skip to content

Commit

Permalink
Add opsa vjp
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Nov 21, 2023
1 parent 8ff2473 commit 2046740
Show file tree
Hide file tree
Showing 18 changed files with 667 additions and 58 deletions.
115 changes: 115 additions & 0 deletions mops-torch/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
cmake_minimum_required(VERSION 3.16)

if (POLICY CMP0077)
    # use variables to set OPTIONS
    cmake_policy(SET CMP0077 NEW)
endif()


# The version number is stored in the VERSION file next to this
# CMakeLists.txt. It is read and split into its components before `project()`
# so that it can be passed to `project(... VERSION ...)`.
#
# All expansions are quoted: the source directory can contain spaces, and an
# empty VERSION file must not make the arguments disappear.
file(READ "${CMAKE_CURRENT_SOURCE_DIR}/VERSION" MOPS_TORCH_VERSION)
string(STRIP "${MOPS_TORCH_VERSION}" MOPS_TORCH_VERSION)
string(REGEX REPLACE "^([0-9]+)\\..*" "\\1" MOPS_TORCH_VERSION_MAJOR "${MOPS_TORCH_VERSION}")
string(REGEX REPLACE "^[0-9]+\\.([0-9]+).*" "\\1" MOPS_TORCH_VERSION_MINOR "${MOPS_TORCH_VERSION}")
string(REGEX REPLACE "^[0-9]+\\.[0-9]+\\.([0-9]+).*" "\\1" MOPS_TORCH_VERSION_PATCH "${MOPS_TORCH_VERSION}")

# Name the project after the library built here. `mops` is the name of the core
# library pulled in with add_subdirectory(); if that directory declares its own
# `mops` project, sharing the name would make both projects overwrite each
# other's `mops_SOURCE_DIR`/`mops_BINARY_DIR`/`mops_VERSION` variables.
project(mops_torch VERSION "${MOPS_TORCH_VERSION}" LANGUAGES CXX)

# CUDA is optional: only enable the language if a working compiler is found
include(CheckLanguage)
check_language(CUDA)

if(CMAKE_CUDA_COMPILER)
    enable_language(CUDA)
    set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE BOOL "" FORCE)
else()
    message(STATUS "Could not find a CUDA compiler")
endif()

set(LIB_INSTALL_DIR "lib" CACHE PATH "Path relative to CMAKE_INSTALL_PREFIX where to install libraries")
set(BIN_INSTALL_DIR "bin" CACHE PATH "Path relative to CMAKE_INSTALL_PREFIX where to install DLL/binaries")
set(INCLUDE_INSTALL_DIR "include" CACHE PATH "Path relative to CMAKE_INSTALL_PREFIX where to install headers")

# Are we the top-level project, or included from another one with
# add_subdirectory()?
if ("${CMAKE_CURRENT_SOURCE_DIR}" STREQUAL "${CMAKE_SOURCE_DIR}")
    set(MOPS_TORCH_MAIN_PROJECT ON)
else()
    set(MOPS_TORCH_MAIN_PROJECT OFF)
endif()

# Set a default build type if none was specified. This only applies to
# single-configuration generators, and only when we are the top-level project.
if (MOPS_TORCH_MAIN_PROJECT)
    if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "")
        message(STATUS "Setting build type to 'relwithdebinfo' as none was specified.")
        set(
            CMAKE_BUILD_TYPE "relwithdebinfo"
            CACHE STRING
            "Choose the type of build, options are: none(CMAKE_CXX_FLAGS or CMAKE_C_FLAGS used) debug release relwithdebinfo minsizerel."
            FORCE
        )
        set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug relwithdebinfo minsizerel none)
    endif()
endif()

# Dependencies: libtorch (1.11 or later) and the core mops library, which is
# built from the sibling `mops` directory as part of this build.
find_package(Torch 1.11 REQUIRED)
add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/../mops" mops)

# The TorchScript extension itself. Headers are listed so that they show up in
# IDE project files.
add_library(mops_torch SHARED
    "src/register.cpp"
    "src/opsa.cpp"

    "include/mops/torch.hpp"
    "include/mops/torch/opsa.hpp"
)

target_compile_features(mops_torch PUBLIC cxx_std_17)

# Let the code (and its users) know when a CUDA compiler was found
if(CMAKE_CUDA_COMPILER)
    target_compile_definitions(mops_torch PUBLIC MOPS_CUDA_ENABLED)
endif()

target_link_libraries(mops_torch
    PRIVATE mops
    PUBLIC torch
)

# Hide non-exported symbols by default.
#
# NOTE(review): this property is set on the `mops` target, not on
# `mops_torch`, and no header defining MOPS_TORCH_EXPORT is generated in this
# file. Confirm which target was intended before changing it: hiding the
# symbols of `mops_torch` requires export annotations on its public API first.
set_property(TARGET mops PROPERTY CXX_VISIBILITY_PRESET hidden)

target_include_directories(mops_torch PUBLIC
    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
    $<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/include>
    $<INSTALL_INTERFACE:include>
)

#------------------------------------------------------------------------------#
# Installation configuration
#------------------------------------------------------------------------------#
include(CMakePackageConfigHelpers)

# Generate the files used by `find_package(mops_torch)`: the version file
# (versions with the same major.minor are considered compatible) and the
# package configuration file itself.
write_basic_package_version_file(
    mops_torch-config-version.cmake
    VERSION ${MOPS_TORCH_VERSION}
    COMPATIBILITY SameMinorVersion
)

configure_file(
    "${CMAKE_CURRENT_SOURCE_DIR}/cmake/mops_torch-config.in.cmake"
    "${CMAKE_CURRENT_BINARY_DIR}/mops_torch-config.cmake"
    @ONLY
)

# Install the library and record it in the `mops_torch-targets` export set
install(TARGETS mops_torch
    EXPORT mops_torch-targets
    LIBRARY DESTINATION ${LIB_INSTALL_DIR}
    ARCHIVE DESTINATION ${LIB_INSTALL_DIR}
    RUNTIME DESTINATION ${BIN_INSTALL_DIR}
)

# Install the headers, both from the source tree and from the build tree
install(DIRECTORY "${PROJECT_SOURCE_DIR}/include/" DESTINATION ${INCLUDE_INSTALL_DIR})
install(DIRECTORY "${PROJECT_BINARY_DIR}/include/" DESTINATION ${INCLUDE_INSTALL_DIR})

# Install files to find mops_torch in CMake projects
install(EXPORT mops_torch-targets
    DESTINATION ${LIB_INSTALL_DIR}/cmake/mops_torch
)
install(
    FILES
        "${PROJECT_BINARY_DIR}/mops_torch-config-version.cmake"
        "${PROJECT_BINARY_DIR}/mops_torch-config.cmake"
    DESTINATION ${LIB_INSTALL_DIR}/cmake/mops_torch
)
1 change: 1 addition & 0 deletions mops-torch/VERSION
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0.1.0
1 change: 1 addition & 0 deletions mops-torch/cmake/mops_torch-config.in.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Package configuration file for mops_torch, loaded by
# `find_package(mops_torch)`. This template is processed by configure_file()
# with @ONLY, so the `${...}` references below are kept as-is and evaluated in
# the consuming project.
include(CMakeFindDependencyMacro)

# The exported `mops_torch` target links publicly to the `torch` target, which
# only exists once Torch has been found. Find it here so that consumers do not
# have to call `find_package(Torch)` themselves before using mops_torch.
find_dependency(Torch 1.11)

include("${CMAKE_CURRENT_LIST_DIR}/mops_torch-targets.cmake")
6 changes: 6 additions & 0 deletions mops-torch/include/mops/torch.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#ifndef MOPS_TORCH_H
#define MOPS_TORCH_H

#include "torch/opsa.hpp" // IWYU pragma: export

#endif
36 changes: 36 additions & 0 deletions mops-torch/include/mops/torch/opsa.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#ifndef MOPS_TORCH_OPSA_H
#define MOPS_TORCH_OPSA_H

#include <torch/script.h>

#include <mops.hpp>

namespace mops_torch {

/// Outer product scatter-add for torch tensors, with support for automatic
/// differentiation with respect to `A` and `B`. This is a thin wrapper around
/// `OuterProductScatterAdd::apply`.
///
/// Only CPU tensors with a floating point dtype are currently supported.
///
/// NOTE(review): the operation is presumably
/// `output[indices_output[i], :, :] += outer(A[i, :], B[i, :])`; confirm
/// against the documentation of `mops::outer_product_scatter_add`.
///
/// @param A 2-D tensor
/// @param B 2-D tensor, with the same dtype and device as `A`
/// @param indices_output 1-D tensor of 32-bit integers, on the same device as
///        `A`
/// @param output_size size of the first dimension of the output
/// @returns a new tensor of shape `(output_size, A.size(1), B.size(1))`, with
///          the same dtype and device as `A`
torch::Tensor outer_product_scatter_add(
torch::Tensor A,
torch::Tensor B,
torch::Tensor indices_output,
int64_t output_size
);

/// Custom autograd function implementing `outer_product_scatter_add` and its
/// backward pass.
///
/// This class should be used through `OuterProductScatterAdd::apply(A, B,
/// indices_output, output_size)` or the `outer_product_scatter_add` free
/// function, not by calling `forward` and `backward` directly.
class OuterProductScatterAdd: public torch::autograd::Function<mops_torch::OuterProductScatterAdd> {
public:
/// Forward pass: validate the inputs and compute the output.
///
/// `A` and `B` must be 2-D tensors with the same dtype, and `indices_output`
/// a 1-D tensor of 32-bit integers; all three must be on the same device.
///
/// Returns a vector containing a single tensor of shape
/// `(output_size, A.size(1), B.size(1))`.
static std::vector<torch::Tensor> forward(
torch::autograd::AutogradContext *ctx,
torch::Tensor A,
torch::Tensor B,
torch::Tensor indices_output,
int64_t output_size
);

/// Backward pass: `grad_outputs[0]` is the gradient with respect to the
/// output of `forward`.
///
/// Returns one entry per argument of `forward`: the gradients with respect
/// to `A` and `B` (undefined tensors when the corresponding input does not
/// require gradients), followed by two undefined tensors for
/// `indices_output` and `output_size`.
static std::vector<torch::Tensor> backward(
torch::autograd::AutogradContext *ctx,
std::vector<torch::Tensor> grad_outputs
);
};

}

#endif
142 changes: 142 additions & 0 deletions mops-torch/src/opsa.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#include "mops/torch/opsa.hpp"

using namespace mops_torch;


/// Entry point used for the TorchScript operator: run the autograd-aware
/// implementation and unwrap its result.
torch::Tensor mops_torch::outer_product_scatter_add(
    torch::Tensor A,
    torch::Tensor B,
    torch::Tensor indices_output,
    int64_t output_size
) {
    // `apply` gives back everything returned by `forward`, i.e. a vector
    // containing a single tensor
    auto outputs = OuterProductScatterAdd::apply(A, B, indices_output, output_size);
    return outputs[0];
}

/// Describe a 1-D torch tensor as a `mops::Tensor`, i.e. a data pointer and a
/// shape. No data is copied: the result refers to the memory owned by `tensor`.
///
/// Only the size is forwarded, not the stride, so the caller is responsible
/// for passing a contiguous tensor.
template <typename scalar_t>
static mops::Tensor<scalar_t, 1> torch_to_mops_1d(torch::Tensor tensor) {
    assert(tensor.dim() == 1);

    auto* data = tensor.data_ptr<scalar_t>();
    auto length = static_cast<size_t>(tensor.size(0));

    return {data, {length}};
}

/// Describe a 2-D torch tensor as a `mops::Tensor`, i.e. a data pointer and a
/// shape. No data is copied: the result refers to the memory owned by `tensor`.
///
/// Only the sizes are forwarded, not the strides, so the caller is responsible
/// for passing a contiguous tensor.
template <typename scalar_t>
static mops::Tensor<scalar_t, 2> torch_to_mops_2d(torch::Tensor tensor) {
    assert(tensor.dim() == 2);

    auto* data = tensor.data_ptr<scalar_t>();
    auto n_rows = static_cast<size_t>(tensor.size(0));
    auto n_cols = static_cast<size_t>(tensor.size(1));

    return {data, {n_rows, n_cols}};
}

/// Forward pass of `outer_product_scatter_add`.
///
/// Checks that `A` and `B` are 2-D tensors with the same dtype, that
/// `indices_output` is a 1-D tensor of 32-bit integers, and that all of them
/// live on the same device; then computes the output with the mops CPU
/// implementation. Inputs are saved for the backward pass only when `A` or `B`
/// requires gradients.
///
/// Returns a vector containing a single tensor of shape
/// `(output_size, A.size(1), B.size(1))`. Throws a `c10::ValueError` for
/// invalid inputs or unsupported devices.
std::vector<torch::Tensor> OuterProductScatterAdd::forward(
    torch::autograd::AutogradContext *ctx,
    torch::Tensor A,
    torch::Tensor B,
    torch::Tensor indices_output,
    int64_t output_size
) {
    if (A.sizes().size() != 2 || B.sizes().size() != 2) {
        C10_THROW_ERROR(ValueError, "`A` and `B` must be 2-D tensor");
    }

    if (indices_output.sizes().size() != 1) {
        C10_THROW_ERROR(ValueError, "`indices_output` must be a 1-D tensor");
    }

    if (indices_output.scalar_type() != torch::kInt32) {
        C10_THROW_ERROR(ValueError, "`indices_output` must be a tensor of 32-bit integers");
    }

    if (A.device() != B.device() || A.device() != indices_output.device()) {
        C10_THROW_ERROR(ValueError,
            "all tensors must be on the same device, got " + A.device().str() +
            ", " + B.device().str() + ", and " + indices_output.device().str()
        );
    }

    if (A.scalar_type() != B.scalar_type()) {
        C10_THROW_ERROR(ValueError,
            std::string("`A` and `B` must have the same dtype, got ") +
            torch::toString(A.scalar_type()) + " and " + torch::toString(B.scalar_type())
        );
    }

    torch::Tensor output;
    if (A.device().is_cpu()) {
        output = torch::zeros({output_size, A.size(1), B.size(1)},
            torch::TensorOptions().dtype(A.scalar_type()).device(A.device())
        );

        assert(output.is_contiguous());

        // mops only receives a data pointer and a shape, so the data must be
        // laid out contiguously. `contiguous()` returns the tensor itself when
        // this is already the case, and a contiguous copy otherwise (e.g. for
        // transposed or sliced inputs). The copies are only used to compute
        // the output, the original tensors are the ones saved for backward.
        auto A_contiguous = A.contiguous();
        auto B_contiguous = B.contiguous();
        auto indices_contiguous = indices_output.contiguous();

        // Flatten the last two dimensions. The first dimension is given
        // explicitly since `-1` is ambiguous (and rejected by torch) when
        // `A.size(1) * B.size(1)` is zero. `view` guarantees that the result
        // shares its memory with `output`.
        auto output_2d = output.view({output.size(0), output.size(1) * output.size(2)});

        AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "outer_product_scatter_add", [&](){
            mops::outer_product_scatter_add<scalar_t>(
                torch_to_mops_2d<scalar_t>(output_2d),
                torch_to_mops_2d<scalar_t>(A_contiguous),
                torch_to_mops_2d<scalar_t>(B_contiguous),
                torch_to_mops_1d<int32_t>(indices_contiguous)
            );
        });
    } else {
        C10_THROW_ERROR(ValueError,
            "outer_product_scatter_add is not implemented for device " + A.device().str()
        );
    }

    // the inputs are only needed if someone will call backward
    if (A.requires_grad() || B.requires_grad()) {
        ctx->save_for_backward({A, B, indices_output});
    }

    return {output};
}

/// Backward pass of `outer_product_scatter_add`.
///
/// `grad_outputs[0]` is the gradient with respect to the output of `forward`.
/// Gradients are only computed for the inputs that require them; the other
/// entries of the returned vector are undefined tensors.
///
/// Returns one entry per argument of `forward`, in the same order: gradients
/// for `A` and `B`, then undefined tensors for `indices_output` and
/// `output_size`. Throws a `c10::ValueError` for unsupported devices.
std::vector<torch::Tensor> OuterProductScatterAdd::backward(
    torch::autograd::AutogradContext *ctx,
    std::vector<torch::Tensor> grad_outputs
) {
    auto saved_variables = ctx->get_saved_variables();
    auto A = saved_variables[0];
    auto B = saved_variables[1];
    auto indices_output = saved_variables[2];

    // Autograd regularly produces non-contiguous gradients (for example
    // `output.sum().backward()` gives an expanded tensor with zero strides).
    // Make a contiguous copy when required instead of rejecting them.
    auto grad_output = grad_outputs[0].contiguous();

    auto grad_A = torch::Tensor();
    auto grad_B = torch::Tensor();

    if (A.device().is_cpu()) {
        // mops only receives data pointers and shapes, so everything passed
        // to it must be contiguous. `contiguous()` is a no-op for tensors
        // which already are.
        auto A_contiguous = A.contiguous();
        auto B_contiguous = B.contiguous();
        auto indices_contiguous = indices_output.contiguous();

        // Flatten the last two dimensions. The first dimension is given
        // explicitly since `-1` is ambiguous (and rejected by torch) when the
        // product of the last two dimensions is zero.
        auto grad_output_2d = grad_output.view(
            {grad_output.size(0), grad_output.size(1) * grad_output.size(2)}
        );

        AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "outer_product_scatter_add_vjp", [&](){
            // a null data pointer tells mops to skip the corresponding gradient
            auto mops_grad_A = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}};
            if (A.requires_grad()) {
                // explicitly request a contiguous layout: by default
                // `zeros_like` would reproduce the strides of a
                // non-contiguous `A`
                grad_A = torch::zeros_like(A, at::MemoryFormat::Contiguous);
                mops_grad_A = torch_to_mops_2d<scalar_t>(grad_A);
            }

            auto mops_grad_B = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}};
            if (B.requires_grad()) {
                grad_B = torch::zeros_like(B, at::MemoryFormat::Contiguous);
                mops_grad_B = torch_to_mops_2d<scalar_t>(grad_B);
            }

            mops::outer_product_scatter_add_vjp<scalar_t>(
                mops_grad_A,
                mops_grad_B,
                torch_to_mops_2d<scalar_t>(grad_output_2d),
                torch_to_mops_2d<scalar_t>(A_contiguous),
                torch_to_mops_2d<scalar_t>(B_contiguous),
                torch_to_mops_1d<int32_t>(indices_contiguous)
            );
        });
    } else {
        C10_THROW_ERROR(ValueError,
            "outer_product_scatter_add is not implemented for device " + A.device().str()
        );
    }

    // one entry per argument of `forward`; `indices_output` and `output_size`
    // are not differentiable
    return {grad_A, grad_B, torch::Tensor(), torch::Tensor()};
}
7 changes: 7 additions & 0 deletions mops-torch/src/register.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#include <torch/script.h>

#include "mops/torch/opsa.hpp"

// Register the operators of this library with PyTorch, in the `mops` operator
// namespace. The operator schema is not written explicitly here: `m.def` is
// given the C++ function directly.
TORCH_LIBRARY(mops, m) {
m.def("outer_product_scatter_add", mops_torch::outer_product_scatter_add);
}
21 changes: 21 additions & 0 deletions mops/include/mops/opsa.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,27 @@ int MOPS_EXPORT mops_outer_product_scatter_add_f64(
mops_tensor_1d_i32_t indices_output
);

/// CPU version of mops::outer_product_scatter_add_vjp for 32-bit floats
///
/// `grad_A` and `grad_B` are the outputs of this function; `grad_output`, `A`,
/// `B` and `indices_output` are its inputs.
///
/// NOTE(review): the returned `int` is presumably a status code, like for the
/// other functions of this C API; confirm which value indicates success.
int MOPS_EXPORT mops_outer_product_scatter_add_vjp_f32(
mops_tensor_2d_f32_t grad_A,
mops_tensor_2d_f32_t grad_B,
mops_tensor_2d_f32_t grad_output,
mops_tensor_2d_f32_t A,
mops_tensor_2d_f32_t B,
mops_tensor_1d_i32_t indices_output
);


/// CPU version of mops::outer_product_scatter_add_vjp for 64-bit floats
///
/// `grad_A` and `grad_B` are the outputs of this function; `grad_output`, `A`,
/// `B` and `indices_output` are its inputs.
///
/// NOTE(review): the returned `int` is presumably a status code, like for the
/// other functions of this C API; confirm which value indicates success.
int MOPS_EXPORT mops_outer_product_scatter_add_vjp_f64(
mops_tensor_2d_f64_t grad_A,
mops_tensor_2d_f64_t grad_B,
mops_tensor_2d_f64_t grad_output,
mops_tensor_2d_f64_t A,
mops_tensor_2d_f64_t B,
mops_tensor_1d_i32_t indices_output
);


/// CUDA version of mops::outer_product_scatter_add for 32-bit floats
int MOPS_EXPORT mops_cuda_outer_product_scatter_add_f32(
Expand Down
38 changes: 38 additions & 0 deletions mops/include/mops/opsa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,44 @@ namespace mops {
Tensor<int32_t, 1> indices_output
);

/// Vector-Jacobian product for `outer_product_scatter_add` (i.e. backward
/// propagation of gradients).
///
/// `grad_A` and `grad_B` are the outputs of this function, and should have
/// the same shape as `A` and `B` respectively. If you don't need one of
/// these gradients, set the corresponding `.data` pointer to `nullptr`.
///
/// `grad_output` is the gradient with respect to the output of
/// `outer_product_scatter_add`, and should have the same shape as `output`
/// in that function.
///
/// `A`, `B` and `indices_output` should be the same tensors that were
/// passed to `outer_product_scatter_add`.
template<typename scalar_t>
void MOPS_EXPORT outer_product_scatter_add_vjp(
Tensor<scalar_t, 2> grad_A,
Tensor<scalar_t, 2> grad_B,
Tensor<scalar_t, 2> grad_output,
Tensor<scalar_t, 2> A,
Tensor<scalar_t, 2> B,
Tensor<int32_t, 1> indices_output
);

// These instantiations (32-bit and 64-bit floating point) are pre-compiled and
// provided in the mops library. Declaring them `extern` prevents every user of
// this header from instantiating the template again. Parameter names match
// the ones used in the declaration of the primary template.
extern template void outer_product_scatter_add_vjp(
    Tensor<float, 2> grad_A,
    Tensor<float, 2> grad_B,
    Tensor<float, 2> grad_output,
    Tensor<float, 2> A,
    Tensor<float, 2> B,
    Tensor<int32_t, 1> indices_output
);

extern template void outer_product_scatter_add_vjp(
    Tensor<double, 2> grad_A,
    Tensor<double, 2> grad_B,
    Tensor<double, 2> grad_output,
    Tensor<double, 2> A,
    Tensor<double, 2> B,
    Tensor<int32_t, 1> indices_output
);

namespace cuda {
/// CUDA version of mops::outer_product_scatter_add
template<typename scalar_t>
Expand Down
Loading

0 comments on commit 2046740

Please sign in to comment.