-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into port-polyeval-cpu
- Loading branch information
Showing
43 changed files
with
1,489 additions
and
519 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
cmake_minimum_required(VERSION 3.16)

# Explicitly request the NEW behavior of CMP0077 (option() and normal
# variables) where the running CMake knows about this policy.
if(POLICY CMP0077)
    cmake_policy(SET CMP0077 NEW)
endif()
|
||
|
||
# The version lives in the VERSION file next to this CMakeLists.txt. Read it,
# drop surrounding whitespace/newlines, and split it into its components.
# All expansions are quoted so that an empty or unusual file content reaches
# the commands as a single argument instead of vanishing or being split.
file(READ "${CMAKE_CURRENT_SOURCE_DIR}/VERSION" MOPS_TORCH_VERSION)
string(STRIP "${MOPS_TORCH_VERSION}" MOPS_TORCH_VERSION)
string(REGEX REPLACE "^([0-9]+)\\..*" "\\1" MOPS_TORCH_VERSION_MAJOR "${MOPS_TORCH_VERSION}")
string(REGEX REPLACE "^[0-9]+\\.([0-9]+).*" "\\1" MOPS_TORCH_VERSION_MINOR "${MOPS_TORCH_VERSION}")
string(REGEX REPLACE "^[0-9]+\\.[0-9]+\\.([0-9]+).*" "\\1" MOPS_TORCH_VERSION_PATCH "${MOPS_TORCH_VERSION}")

project(mops VERSION ${MOPS_TORCH_VERSION} LANGUAGES CXX)
|
||
# Probe for a CUDA compiler; CUDA is optional, so a missing compiler is only
# reported and does not abort the configuration.
include(CheckLanguage)
check_language(CUDA)

if(NOT CMAKE_CUDA_COMPILER)
    message(STATUS "Could not find a CUDA compiler")
else()
    enable_language(CUDA)
    # NOTE(review): presumably read by the CUDA find-module used by Torch to
    # select the shared CUDA runtime -- confirm.
    set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE BOOL "" FORCE)
endif()
|
||
# User-configurable install locations, all relative to CMAKE_INSTALL_PREFIX.
set(LIB_INSTALL_DIR "lib"
    CACHE PATH "Path relative to CMAKE_INSTALL_PREFIX where to install libraries"
)
set(BIN_INSTALL_DIR "bin"
    CACHE PATH "Path relative to CMAKE_INSTALL_PREFIX where to install DLL/binaries"
)
set(INCLUDE_INSTALL_DIR "include"
    CACHE PATH "Path relative to CMAKE_INSTALL_PREFIX where to install headers"
)
|
||
# Detect whether this project is the top-level one or is being built as a
# sub-project (add_subdirectory) of something else. Both paths are quoted so
# that if() compares them as plain strings: unquoted, a path could be split on
# `;` or be dereferenced again as a variable name.
if("${CMAKE_CURRENT_SOURCE_DIR}" STREQUAL "${CMAKE_SOURCE_DIR}")
    set(MOPS_TORCH_MAIN_PROJECT ON)
else()
    set(MOPS_TORCH_MAIN_PROJECT OFF)
endif()

# Set a default build type if none was specified. This is only done for the
# top-level project and for single-configuration generators (i.e. when both
# CMAKE_BUILD_TYPE and CMAKE_CONFIGURATION_TYPES are empty).
if(MOPS_TORCH_MAIN_PROJECT)
    if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "")
        message(STATUS "Setting build type to 'relwithdebinfo' as none was specified.")
        set(
            CMAKE_BUILD_TYPE "relwithdebinfo"
            CACHE STRING
            "Choose the type of build, options are: none(CMAKE_CXX_FLAGS or CMAKE_C_FLAGS used) debug release relwithdebinfo minsizerel."
            FORCE
        )
        # list of values offered by cmake-gui/ccmake for this cache entry
        set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug relwithdebinfo minsizerel none)
    endif()
endif()
|
||
find_package(Torch 1.11 REQUIRED)

# Build the bundled `mops` sub-project with BUILD_SHARED_LIBS disabled and as
# position independent code, since it is linked into the `mops_torch` shared
# library below. EXCLUDE_FROM_ALL keeps its targets out of the default build
# unless something depends on them.
set(BUILD_SHARED_LIBS OFF)
add_subdirectory(mops EXCLUDE_FROM_ALL)
set_target_properties(mops PROPERTIES POSITION_INDEPENDENT_CODE ON)

# The TorchScript extension itself. Headers are listed alongside the sources.
add_library(mops_torch SHARED
    "src/register.cpp"
    "src/opsa.cpp"

    "include/mops/torch.hpp"
    "include/mops/torch/opsa.hpp"
)

# Let the code (and its consumers) know when a CUDA compiler was found
if(CMAKE_CUDA_COMPILER)
    target_compile_definitions(mops_torch PUBLIC MOPS_CUDA_ENABLED)
endif()

target_compile_features(mops_torch PUBLIC cxx_std_17)
target_link_libraries(mops_torch
    PRIVATE mops
    PUBLIC torch
)
|
||
# Hide non-exported symbols of the `mops` target by default.
# NOTE(review): an earlier comment here announced the creation of a header
# defining MOPS_TORCH_EXPORT, but no such header is generated in this file, and
# the property is set on `mops`, not on `mops_torch` -- confirm this is intended.
set_target_properties(mops PROPERTIES
    CXX_VISIBILITY_PRESET hidden
)

# Headers are found in the source tree (and in the build tree, for generated
# files) when building, and in the install tree once installed. The installed
# path must follow INCLUDE_INSTALL_DIR, since this is where the headers are
# installed; hard-coding `include` breaks consumers as soon as a user
# overrides INCLUDE_INSTALL_DIR.
target_include_directories(mops_torch PUBLIC
    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
    $<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/include>
    $<INSTALL_INTERFACE:${INCLUDE_INSTALL_DIR}>
)
|
||
#------------------------------------------------------------------------------#
# Installation configuration
#------------------------------------------------------------------------------#

# Generate the two files used by find_package(mops_torch) in the build tree:
# the version compatibility file and the package configuration file itself.
include(CMakePackageConfigHelpers)
write_basic_package_version_file(
    mops_torch-config-version.cmake
    VERSION ${MOPS_TORCH_VERSION}
    COMPATIBILITY SameMinorVersion
)
configure_file(
    ${CMAKE_CURRENT_SOURCE_DIR}/cmake/mops_torch-config.in.cmake
    ${CMAKE_CURRENT_BINARY_DIR}/mops_torch-config.cmake
    @ONLY
)

# Install the library and record it in the `mops_torch-targets` export set
install(
    TARGETS mops_torch
    EXPORT mops_torch-targets
    LIBRARY DESTINATION ${LIB_INSTALL_DIR}
    ARCHIVE DESTINATION ${LIB_INSTALL_DIR}
    RUNTIME DESTINATION ${BIN_INSTALL_DIR}
)

# Install the public headers
install(
    DIRECTORY ${PROJECT_SOURCE_DIR}/include/
    DESTINATION ${INCLUDE_INSTALL_DIR}
)

# Install files to find mops_torch in CMake projects
install(
    EXPORT mops_torch-targets
    DESTINATION ${LIB_INSTALL_DIR}/cmake/mops_torch
)
install(
    FILES
        ${PROJECT_BINARY_DIR}/mops_torch-config-version.cmake
        ${PROJECT_BINARY_DIR}/mops_torch-config.cmake
    DESTINATION ${LIB_INSTALL_DIR}/cmake/mops_torch
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
0.1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Package configuration file for mops_torch, used by find_package(mops_torch).
#
# The exported `mops_torch` target links publicly to the `torch` target, so
# Torch has to be located before the exported targets are loaded. Otherwise
# `torch` is not a known target in the consuming project.
include(CMakeFindDependencyMacro)
find_dependency(Torch 1.11)

include(${CMAKE_CURRENT_LIST_DIR}/mops_torch-targets.cmake)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#ifndef MOPS_TORCH_H | ||
#define MOPS_TORCH_H | ||
|
||
#include "torch/opsa.hpp" // IWYU pragma: export | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#ifndef MOPS_TORCH_OPSA_H | ||
#define MOPS_TORCH_OPSA_H | ||
|
||
#include <torch/script.h> | ||
|
||
#include <mops.hpp> | ||
|
||
namespace mops_torch { | ||
|
||
/// TODO | ||
torch::Tensor outer_product_scatter_add( | ||
torch::Tensor A, | ||
torch::Tensor B, | ||
torch::Tensor indices_output, | ||
int64_t output_size | ||
); | ||
|
||
class OuterProductScatterAdd: public torch::autograd::Function<mops_torch::OuterProductScatterAdd> { | ||
public: | ||
static std::vector<torch::Tensor> forward( | ||
torch::autograd::AutogradContext *ctx, | ||
torch::Tensor A, | ||
torch::Tensor B, | ||
torch::Tensor indices_output, | ||
int64_t output_size | ||
); | ||
|
||
static std::vector<torch::Tensor> backward( | ||
torch::autograd::AutogradContext *ctx, | ||
std::vector<torch::Tensor> grad_outputs | ||
); | ||
}; | ||
|
||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../mops |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
#include "mops/torch/opsa.hpp" | ||
|
||
using namespace mops_torch; | ||
|
||
|
||
/// Public entry point: run the custom autograd function and unwrap its only
/// output from the returned vector.
torch::Tensor mops_torch::outer_product_scatter_add(
    torch::Tensor A,
    torch::Tensor B,
    torch::Tensor indices_output,
    int64_t output_size
) {
    auto outputs = OuterProductScatterAdd::apply(A, B, indices_output, output_size);
    return outputs[0];
}
|
||
template <typename scalar_t> | ||
static mops::Tensor<scalar_t, 1> torch_to_mops_1d(torch::Tensor tensor) { | ||
assert(tensor.sizes().size() == 1); | ||
return { | ||
tensor.data_ptr<scalar_t>(), | ||
{static_cast<size_t>(tensor.size(0))}, | ||
}; | ||
} | ||
|
||
template <typename scalar_t> | ||
static mops::Tensor<scalar_t, 2> torch_to_mops_2d(torch::Tensor tensor) { | ||
assert(tensor.sizes().size() == 2); | ||
return { | ||
tensor.data_ptr<scalar_t>(), | ||
{static_cast<size_t>(tensor.size(0)), static_cast<size_t>(tensor.size(1))}, | ||
}; | ||
} | ||
|
||
/// Forward pass of `outer_product_scatter_add`.
///
/// Validates the inputs (`A` and `B` must be 2-D with the same dtype,
/// `indices_output` must be a 1-D tensor of 32-bit integers, and everything
/// must be on the same device), then fills a zero-initialized output of shape
/// `(output_size, A.size(1), B.size(1))` using the mops CPU implementation.
///
/// Throws a `ValueError` for invalid inputs or for non-CPU devices. The
/// original inputs are saved in `ctx` for `backward` if `A` or `B` requires
/// gradients.
std::vector<torch::Tensor> OuterProductScatterAdd::forward(
    torch::autograd::AutogradContext *ctx,
    torch::Tensor A,
    torch::Tensor B,
    torch::Tensor indices_output,
    int64_t output_size
) {
    if (A.sizes().size() != 2 || B.sizes().size() != 2) {
        C10_THROW_ERROR(ValueError, "`A` and `B` must be 2-D tensor");
    }

    if (indices_output.sizes().size() != 1) {
        C10_THROW_ERROR(ValueError, "`indices_output` must be a 1-D tensor");
    }

    if (indices_output.scalar_type() != torch::kInt32) {
        C10_THROW_ERROR(ValueError, "`indices_output` must be a tensor of 32-bit integers");
    }

    if (A.device() != B.device() || A.device() != indices_output.device()) {
        C10_THROW_ERROR(ValueError,
            "all tensors must be on the same device, got " + A.device().str() +
            ", " + B.device().str() + ", and " + indices_output.device().str()
        );
    }

    if (A.scalar_type() != B.scalar_type()) {
        C10_THROW_ERROR(ValueError,
            std::string("`A` and `B` must have the same dtype, got ") +
            torch::toString(A.scalar_type()) + " and " + torch::toString(B.scalar_type())
        );
    }

    torch::Tensor output;
    if (A.device().is_cpu()) {
        // mops only receives a raw pointer and a shape, so the data must be
        // densely packed in row-major order. `contiguous()` returns the same
        // tensor when this is already the case, and a packed copy otherwise
        // (e.g. for transposed or strided inputs).
        auto A_data = A.contiguous();
        auto B_data = B.contiguous();
        auto indices_data = indices_output.contiguous();

        output = torch::zeros({output_size, A.size(1), B.size(1)},
            torch::TensorOptions().dtype(A.scalar_type()).device(A.device())
        );

        assert(output.is_contiguous());

        // 2-D view sharing memory with `output`. The first dimension is given
        // explicitly since `-1` can not be inferred when the other one is 0.
        auto output_2d = output.reshape({output.size(0), output.size(1) * output.size(2)});

        AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "outer_product_scatter_add", [&](){
            mops::outer_product_scatter_add<scalar_t>(
                torch_to_mops_2d<scalar_t>(output_2d),
                torch_to_mops_2d<scalar_t>(A_data),
                torch_to_mops_2d<scalar_t>(B_data),
                torch_to_mops_1d<int32_t>(indices_data)
            );
        });
    } else {
        C10_THROW_ERROR(ValueError,
            "outer_product_scatter_add is not implemented for device " + A.device().str()
        );
    }

    if (A.requires_grad() || B.requires_grad()) {
        // save the original inputs (not the contiguous copies), so that
        // `requires_grad()` is still meaningful on them in `backward`
        ctx->save_for_backward({A, B, indices_output});
    }

    return {output};
}
|
||
/// Backward pass of `outer_product_scatter_add`.
///
/// Computes the gradients with respect to `A` and `B` from the gradient with
/// respect to the output, using the mops CPU implementation. Gradients are
/// only computed for the inputs which require them; the corresponding entry
/// is an undefined tensor otherwise. The last two entries of the returned
/// vector (for `indices_output` and `output_size`) are always undefined.
///
/// Throws a `ValueError` for non-CPU devices.
std::vector<torch::Tensor> OuterProductScatterAdd::backward(
    torch::autograd::AutogradContext *ctx,
    std::vector<torch::Tensor> grad_outputs
) {
    auto saved_variables = ctx->get_saved_variables();
    auto A = saved_variables[0];
    auto B = saved_variables[1];
    auto indices_output = saved_variables[2];

    // mops only receives raw pointers and shapes, so everything handed to it
    // must be densely packed in row-major order. Incoming gradients are
    // frequently not contiguous (for example the expanded gradient created by
    // `output.sum().backward()`), so they are packed here instead of being
    // rejected; `contiguous()` is a no-op for tensors already packed.
    auto grad_output = grad_outputs[0].contiguous();
    auto A_data = A.contiguous();
    auto B_data = B.contiguous();
    auto indices_data = indices_output.contiguous();

    auto grad_A = torch::Tensor();
    auto grad_B = torch::Tensor();

    if (A.device().is_cpu()) {
        // 2-D view of the output gradient. The first dimension is given
        // explicitly since `-1` can not be inferred when the other one is 0.
        auto grad_output_2d = grad_output.reshape(
            {grad_output.size(0), grad_output.size(1) * grad_output.size(2)}
        );

        AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "outer_product_scatter_add_vjp", [&](){
            // a null data pointer is passed for the gradients which are not
            // required
            auto mops_grad_A = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}};
            if (A.requires_grad()) {
                // allocate from the contiguous version of A, so the gradient
                // buffer is contiguous as well
                grad_A = torch::zeros_like(A_data);
                mops_grad_A = torch_to_mops_2d<scalar_t>(grad_A);
            }

            auto mops_grad_B = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}};
            if (B.requires_grad()) {
                grad_B = torch::zeros_like(B_data);
                mops_grad_B = torch_to_mops_2d<scalar_t>(grad_B);
            }

            mops::outer_product_scatter_add_vjp<scalar_t>(
                mops_grad_A,
                mops_grad_B,
                torch_to_mops_2d<scalar_t>(grad_output_2d),
                torch_to_mops_2d<scalar_t>(A_data),
                torch_to_mops_2d<scalar_t>(B_data),
                torch_to_mops_1d<int32_t>(indices_data)
            );
        });
    } else {
        C10_THROW_ERROR(ValueError,
            "outer_product_scatter_add is not implemented for device " + A.device().str()
        );
    }

    // one entry per argument of `forward`: A, B, indices_output, output_size
    return {grad_A, grad_B, torch::Tensor(), torch::Tensor()};
}
Oops, something went wrong.