-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9303ba0
commit 2ea1745
Showing
28 changed files
with
928 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,14 @@ | ||
#ifndef MOPS_HPP | ||
#define MOPS_HPP | ||
|
||
#include "mops/exports.h" // IWYU pragma: export | ||
|
||
#include "mops/capi.hpp" // IWYU pragma: export | ||
#include "mops/opsa.hpp" // IWYU pragma: export | ||
#include "mops/exports.h" // IWYU pragma: export | ||
|
||
#include "mops/capi.hpp" // IWYU pragma: export | ||
#include "mops/hpe.hpp" // IWYU pragma: export | ||
#include "mops/opsa.hpp" // IWYU pragma: export | ||
#include "mops/sap.hpp" // IWYU pragma: export | ||
#include "mops/opsax.hpp" // IWYU pragma: export | ||
#include "mops/sasax.hpp" // IWYU pragma: export | ||
|
||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#ifndef MOPS_OPSAX_H | ||
#define MOPS_OPSAX_H | ||
|
||
#include "mops/exports.h" | ||
#include "mops/tensor.h" | ||
|
||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
/// C API for the CPU version of | ||
/// mops::outer_product_scatter_add_with_weights, for 32-bit floats. | ||
/// | ||
/// Takes the `output` tensor (3D f32 type), the `tensor_a`, `tensor_r` and | ||
/// `tensor_x` tensors (2D f32 types), and the `tensor_i` and `tensor_j` | ||
/// tensors (1D i32 types). | ||
/// NOTE(review): the `int` return value is presumably a status code; | ||
/// confirm against the implementation. | ||
int MOPS_EXPORT mops_outer_product_scatter_add_with_weights_f32( | ||
mops_tensor_3d_f32_t output, | ||
mops_tensor_2d_f32_t tensor_a, | ||
mops_tensor_2d_f32_t tensor_r, | ||
mops_tensor_2d_f32_t tensor_x, | ||
mops_tensor_1d_i32_t tensor_i, | ||
mops_tensor_1d_i32_t tensor_j | ||
); | ||
|
||
|
||
/// C API for the CPU version of | ||
/// mops::outer_product_scatter_add_with_weights, for 64-bit floats. | ||
/// | ||
/// Takes the `output` tensor (3D f64 type), the `tensor_a`, `tensor_r` and | ||
/// `tensor_x` tensors (2D f64 types), and the `tensor_i` and `tensor_j` | ||
/// tensors (1D i32 types). | ||
/// NOTE(review): the `int` return value is presumably a status code; | ||
/// confirm against the implementation. | ||
int MOPS_EXPORT mops_outer_product_scatter_add_with_weights_f64( | ||
mops_tensor_3d_f64_t output, | ||
mops_tensor_2d_f64_t tensor_a, | ||
mops_tensor_2d_f64_t tensor_r, | ||
mops_tensor_2d_f64_t tensor_x, | ||
mops_tensor_1d_i32_t tensor_i, | ||
mops_tensor_1d_i32_t tensor_j | ||
); | ||
|
||
|
||
/// C API for the CUDA version of | ||
/// mops::outer_product_scatter_add_with_weights, for 32-bit floats. | ||
/// | ||
/// Takes the same parameter list as the CPU entry point | ||
/// mops_outer_product_scatter_add_with_weights_f32. | ||
/// NOTE(review): the `int` return value is presumably a status code; | ||
/// confirm against the implementation. | ||
int MOPS_EXPORT mops_cuda_outer_product_scatter_add_with_weights_f32( | ||
mops_tensor_3d_f32_t output, | ||
mops_tensor_2d_f32_t tensor_a, | ||
mops_tensor_2d_f32_t tensor_r, | ||
mops_tensor_2d_f32_t tensor_x, | ||
mops_tensor_1d_i32_t tensor_i, | ||
mops_tensor_1d_i32_t tensor_j | ||
); | ||
|
||
|
||
/// C API for the CUDA version of | ||
/// mops::outer_product_scatter_add_with_weights, for 64-bit floats. | ||
/// | ||
/// Takes the same parameter list as the CPU entry point | ||
/// mops_outer_product_scatter_add_with_weights_f64. | ||
/// NOTE(review): the `int` return value is presumably a status code; | ||
/// confirm against the implementation. | ||
int MOPS_EXPORT mops_cuda_outer_product_scatter_add_with_weights_f64( | ||
mops_tensor_3d_f64_t output, | ||
mops_tensor_2d_f64_t tensor_a, | ||
mops_tensor_2d_f64_t tensor_r, | ||
mops_tensor_2d_f64_t tensor_x, | ||
mops_tensor_1d_i32_t tensor_i, | ||
mops_tensor_1d_i32_t tensor_j | ||
); | ||
|
||
|
||
#ifdef __cplusplus | ||
} | ||
#endif | ||
|
||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#ifndef MOPS_OPSAX_HPP | ||
#define MOPS_OPSAX_HPP | ||
|
||
#include <cstddef> | ||
#include <cstdint> | ||
|
||
#include "mops/exports.h" | ||
#include "mops/tensor.hpp" | ||
|
||
namespace mops { | ||
/// CPU version of the "outer product scatter add with weights" operation, | ||
/// templated on the element type `scalar_t`. | ||
/// | ||
/// Takes a rank-3 `output` tensor, three rank-2 tensors (`tensor_a`, | ||
/// `tensor_r`, `tensor_x`) and two rank-1 `int32_t` tensors (`i`, `j`). | ||
/// | ||
/// TODO: document the computation performed and the expected tensor shapes. | ||
template<typename scalar_t> | ||
void MOPS_EXPORT outer_product_scatter_add_with_weights( | ||
Tensor<scalar_t, 3> output, | ||
Tensor<scalar_t, 2> tensor_a, | ||
Tensor<scalar_t, 2> tensor_r, | ||
Tensor<scalar_t, 2> tensor_x, | ||
Tensor<int32_t, 1> i, | ||
Tensor<int32_t, 1> j | ||
); | ||
|
||
// these templates will be precompiled and provided in the mops library | ||
// (`extern template` keeps including translation units from implicitly | ||
// instantiating the float and double versions themselves) | ||
extern template void outer_product_scatter_add_with_weights( | ||
Tensor<float, 3> output, | ||
Tensor<float, 2> tensor_a, | ||
Tensor<float, 2> tensor_r, | ||
Tensor<float, 2> tensor_x, | ||
Tensor<int32_t, 1> i, | ||
Tensor<int32_t, 1> j | ||
); | ||
|
||
extern template void outer_product_scatter_add_with_weights( | ||
Tensor<double, 3> output, | ||
Tensor<double, 2> tensor_a, | ||
Tensor<double, 2> tensor_r, | ||
Tensor<double, 2> tensor_x, | ||
Tensor<int32_t, 1> i, | ||
Tensor<int32_t, 1> j | ||
); | ||
|
||
namespace cuda { | ||
/// CUDA version of mops::outer_product_scatter_add_with_weights | ||
/// | ||
/// Same parameter list as the CPU version declared above. | ||
/// NOTE(review): unlike the CPU version, no `extern template` declarations | ||
/// are given here for float/double; confirm where the CUDA instantiations | ||
/// are provided. | ||
template<typename scalar_t> | ||
void MOPS_EXPORT outer_product_scatter_add_with_weights( | ||
Tensor<scalar_t, 3> output, | ||
Tensor<scalar_t, 2> tensor_a, | ||
Tensor<scalar_t, 2> tensor_r, | ||
Tensor<scalar_t, 2> tensor_x, | ||
Tensor<int32_t, 1> i, | ||
Tensor<int32_t, 1> j | ||
); | ||
} | ||
} | ||
|
||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
#ifndef MOPS_SASAX_H | ||
#define MOPS_SASAX_H | ||
|
||
#include "mops/exports.h" | ||
#include "mops/tensor.h" | ||
|
||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
/// C API for the CPU version of | ||
/// mops::sparse_accumulation_scatter_add_with_weights, for 32-bit floats. | ||
/// | ||
/// Takes the `output` tensor (3D f32 type), `tensor_a` and `tensor_r` | ||
/// (2D f32 types), `tensor_x` (3D f32 type), `tensor_c` (1D f32 type), and | ||
/// the `tensor_i`, `tensor_j`, `tensor_m_1`, `tensor_m_2` and `tensor_m_3` | ||
/// tensors (1D i32 types). | ||
/// NOTE(review): the `int` return value is presumably a status code; | ||
/// confirm against the implementation. | ||
int MOPS_EXPORT mops_sparse_accumulation_scatter_add_with_weights_f32( | ||
mops_tensor_3d_f32_t output, | ||
mops_tensor_2d_f32_t tensor_a, | ||
mops_tensor_2d_f32_t tensor_r, | ||
mops_tensor_3d_f32_t tensor_x, | ||
mops_tensor_1d_f32_t tensor_c, | ||
mops_tensor_1d_i32_t tensor_i, | ||
mops_tensor_1d_i32_t tensor_j, | ||
mops_tensor_1d_i32_t tensor_m_1, | ||
mops_tensor_1d_i32_t tensor_m_2, | ||
mops_tensor_1d_i32_t tensor_m_3 | ||
); | ||
|
||
|
||
/// C API for the CPU version of | ||
/// mops::sparse_accumulation_scatter_add_with_weights, for 64-bit floats. | ||
/// | ||
/// Takes the `output` tensor (3D f64 type), `tensor_a` and `tensor_r` | ||
/// (2D f64 types), `tensor_x` (3D f64 type), `tensor_c` (1D f64 type), and | ||
/// the `tensor_i`, `tensor_j`, `tensor_m_1`, `tensor_m_2` and `tensor_m_3` | ||
/// tensors (1D i32 types). | ||
/// NOTE(review): the `int` return value is presumably a status code; | ||
/// confirm against the implementation. | ||
int MOPS_EXPORT mops_sparse_accumulation_scatter_add_with_weights_f64( | ||
mops_tensor_3d_f64_t output, | ||
mops_tensor_2d_f64_t tensor_a, | ||
mops_tensor_2d_f64_t tensor_r, | ||
mops_tensor_3d_f64_t tensor_x, | ||
mops_tensor_1d_f64_t tensor_c, | ||
mops_tensor_1d_i32_t tensor_i, | ||
mops_tensor_1d_i32_t tensor_j, | ||
mops_tensor_1d_i32_t tensor_m_1, | ||
mops_tensor_1d_i32_t tensor_m_2, | ||
mops_tensor_1d_i32_t tensor_m_3 | ||
); | ||
|
||
|
||
/// C API for the CUDA version of | ||
/// mops::sparse_accumulation_scatter_add_with_weights, for 32-bit floats. | ||
/// | ||
/// Takes the same parameter list as the CPU entry point | ||
/// mops_sparse_accumulation_scatter_add_with_weights_f32. | ||
/// NOTE(review): the `int` return value is presumably a status code; | ||
/// confirm against the implementation. | ||
int MOPS_EXPORT mops_cuda_sparse_accumulation_scatter_add_with_weights_f32( | ||
mops_tensor_3d_f32_t output, | ||
mops_tensor_2d_f32_t tensor_a, | ||
mops_tensor_2d_f32_t tensor_r, | ||
mops_tensor_3d_f32_t tensor_x, | ||
mops_tensor_1d_f32_t tensor_c, | ||
mops_tensor_1d_i32_t tensor_i, | ||
mops_tensor_1d_i32_t tensor_j, | ||
mops_tensor_1d_i32_t tensor_m_1, | ||
mops_tensor_1d_i32_t tensor_m_2, | ||
mops_tensor_1d_i32_t tensor_m_3 | ||
); | ||
|
||
|
||
/// C API for the CUDA version of | ||
/// mops::sparse_accumulation_scatter_add_with_weights, for 64-bit floats. | ||
/// | ||
/// Takes the same parameter list as the CPU entry point | ||
/// mops_sparse_accumulation_scatter_add_with_weights_f64. | ||
/// NOTE(review): the `int` return value is presumably a status code; | ||
/// confirm against the implementation. | ||
int MOPS_EXPORT mops_cuda_sparse_accumulation_scatter_add_with_weights_f64( | ||
mops_tensor_3d_f64_t output, | ||
mops_tensor_2d_f64_t tensor_a, | ||
mops_tensor_2d_f64_t tensor_r, | ||
mops_tensor_3d_f64_t tensor_x, | ||
mops_tensor_1d_f64_t tensor_c, | ||
mops_tensor_1d_i32_t tensor_i, | ||
mops_tensor_1d_i32_t tensor_j, | ||
mops_tensor_1d_i32_t tensor_m_1, | ||
mops_tensor_1d_i32_t tensor_m_2, | ||
mops_tensor_1d_i32_t tensor_m_3 | ||
); | ||
|
||
|
||
#ifdef __cplusplus | ||
} | ||
#endif | ||
|
||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#ifndef MOPS_SASAX_HPP | ||
#define MOPS_SASAX_HPP | ||
|
||
#include <cstddef> | ||
#include <cstdint> | ||
|
||
#include "mops/exports.h" | ||
#include "mops/tensor.hpp" | ||
|
||
namespace mops { | ||
/// CPU version of the "sparse accumulation scatter add with weights" | ||
/// operation, templated on the element type `scalar_t`. | ||
/// | ||
/// Takes a rank-3 `output` tensor, rank-2 `tensor_a` and `tensor_r`, a | ||
/// rank-3 `tensor_x`, a rank-1 `tensor_c` of `scalar_t`, and five rank-1 | ||
/// integer tensors (`tensor_i`, `tensor_j`, `tensor_m_1`, `tensor_m_2`, | ||
/// `tensor_m_3`). | ||
/// | ||
/// NOTE(review): the integer tensors are declared as `Tensor<int, 1>` here, | ||
/// while the matching C API uses `mops_tensor_1d_i32_t` and mops/opsax.hpp | ||
/// uses `Tensor<int32_t, 1>`; consider `int32_t` for consistency (confirm | ||
/// it matches the explicit instantiations in the library). | ||
/// | ||
/// TODO: document the computation performed and the expected tensor shapes. | ||
template<typename scalar_t> | ||
void MOPS_EXPORT sparse_accumulation_scatter_add_with_weights( | ||
Tensor<scalar_t, 3> output, | ||
Tensor<scalar_t, 2> tensor_a, | ||
Tensor<scalar_t, 2> tensor_r, | ||
Tensor<scalar_t, 3> tensor_x, | ||
Tensor<scalar_t, 1> tensor_c, | ||
Tensor<int, 1> tensor_i, | ||
Tensor<int, 1> tensor_j, | ||
Tensor<int, 1> tensor_m_1, | ||
Tensor<int, 1> tensor_m_2, | ||
Tensor<int, 1> tensor_m_3 | ||
); | ||
|
||
// these templates will be precompiled and provided in the mops library | ||
// (`extern template` keeps including translation units from implicitly | ||
// instantiating the float and double versions themselves) | ||
extern template void sparse_accumulation_scatter_add_with_weights( | ||
Tensor<float, 3> output, | ||
Tensor<float, 2> tensor_a, | ||
Tensor<float, 2> tensor_r, | ||
Tensor<float, 3> tensor_x, | ||
Tensor<float, 1> tensor_c, | ||
Tensor<int, 1> tensor_i, | ||
Tensor<int, 1> tensor_j, | ||
Tensor<int, 1> tensor_m_1, | ||
Tensor<int, 1> tensor_m_2, | ||
Tensor<int, 1> tensor_m_3 | ||
); | ||
|
||
extern template void sparse_accumulation_scatter_add_with_weights( | ||
Tensor<double, 3> output, | ||
Tensor<double, 2> tensor_a, | ||
Tensor<double, 2> tensor_r, | ||
Tensor<double, 3> tensor_x, | ||
Tensor<double, 1> tensor_c, | ||
Tensor<int, 1> tensor_i, | ||
Tensor<int, 1> tensor_j, | ||
Tensor<int, 1> tensor_m_1, | ||
Tensor<int, 1> tensor_m_2, | ||
Tensor<int, 1> tensor_m_3 | ||
); | ||
|
||
namespace cuda { | ||
/// CUDA version of mops::sparse_accumulation_scatter_add_with_weights | ||
/// | ||
/// Same parameter list as the CPU version declared above. | ||
/// NOTE(review): unlike the CPU version, no `extern template` declarations | ||
/// are given here for float/double; confirm where the CUDA instantiations | ||
/// are provided. | ||
template<typename scalar_t> | ||
void MOPS_EXPORT sparse_accumulation_scatter_add_with_weights( | ||
Tensor<scalar_t, 3> output, | ||
Tensor<scalar_t, 2> tensor_a, | ||
Tensor<scalar_t, 2> tensor_r, | ||
Tensor<scalar_t, 3> tensor_x, | ||
Tensor<scalar_t, 1> tensor_c, | ||
Tensor<int, 1> tensor_i, | ||
Tensor<int, 1> tensor_j, | ||
Tensor<int, 1> tensor_m_1, | ||
Tensor<int, 1> tensor_m_2, | ||
Tensor<int, 1> tensor_m_3 | ||
); | ||
} | ||
} | ||
|
||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.