Skip to content

Commit

Permalink
Move compilation of CUDA code to NVRTC (#131)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Filippo Bigi <[email protected]>
Co-authored-by: frostedoyster <[email protected]>
Co-authored-by: Michele Ceriotti <[email protected]>
Co-authored-by: Guillaume Fraux <[email protected]>
Co-authored-by: Guillaume Fraux <[email protected]>
  • Loading branch information
6 people authored Oct 17, 2024
1 parent eb98c8b commit 877370a
Show file tree
Hide file tree
Showing 26 changed files with 2,010 additions and 1,252 deletions.
2 changes: 1 addition & 1 deletion ci/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ test_job:
SLURM_JOB_NUM_NODES: 1
SLURM_PARTITION: normal
SLURM_NTASKS: 1
SLURM_TIMELIMIT: '00:40:00'
SLURM_TIMELIMIT: '02:30:00'
GIT_STRATEGY: fetch
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def run(self):
"sphericart": [
"sphericart/lib/*",
"sphericart/include/*",
]
],
},
extras_require=extras_require,
)
2 changes: 1 addition & 1 deletion sphericart-jax/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ if(CMAKE_CUDA_COMPILER AND SPHERICART_ENABLE_CUDA)

enable_language(CUDA)
include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
set(CUDA_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/sphericart_jax_cuda.cu ${CMAKE_CURRENT_LIST_DIR}/src/sphericart_jax_cuda.cpp)
set(CUDA_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/sphericart_jax_cuda.cpp)
pybind11_add_module(sphericart_jax_cuda ${CUDA_SOURCES})
set_target_properties(sphericart_jax_cuda PROPERTIES CUDA_ARCHITECTURES native)
install(TARGETS sphericart_jax_cuda DESTINATION sphericart_jax_cuda)
Expand Down
47 changes: 0 additions & 47 deletions sphericart-jax/include/sphericart/sphericart_jax_cuda.hpp

This file was deleted.

2 changes: 1 addition & 1 deletion sphericart-jax/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,6 @@ def run(self):
"sphericart-jax": [
"sphericart/jax/lib/*",
"sphericart/jax/include/*",
]
],
},
)
1 change: 0 additions & 1 deletion sphericart-jax/src/sphericart_jax_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ void cpu_sph_with_hessians(void* out_tuple, const void** in) {
}

// Registration of the custom calls with pybind11

pybind11::dict Registrations() {
pybind11::dict dict;
dict["cpu_spherical_f32"] = EncapsulateFunction(cpu_sph<sphericart::SphericalHarmonics, float>);
Expand Down
119 changes: 99 additions & 20 deletions sphericart-jax/src/sphericart_jax_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,116 @@
// devices. It is exposed as a standard pybind11 module defining "capsule"
// objects containing our methods. For simplicity, we export a separate capsule
// for each supported dtype.
// This file is separated from `sphericart_jax_cuda.cu` because pybind11 does
// not accept cuda files.

#include <cstdlib>
#include <map>
#include <mutex>
#include <tuple>

#include "sphericart_cuda.hpp"
#include "sphericart/pybind11_kernel_helpers.hpp"
#include "sphericart/sphericart_jax_cuda.hpp"

using namespace sphericart_jax;
using namespace sphericart_jax::cuda;
// Parameters of a single custom call. The `cuda_sph*` functions below recover
// this struct from the opaque byte string via `UnpackDescriptor<SphDescriptor>`.
// NOTE(review): the field order and widths presumably have to match whatever
// packs the descriptor on the calling side -- confirm before changing them.
struct SphDescriptor {
// Sample count, forwarded as `n_samples` to the calculator's compute methods.
std::int64_t n_samples;
// Maximum degree; used as the key when looking up the cached calculator.
std::int64_t lmax;
};

namespace {
namespace sphericart_jax {
namespace cuda {

// Helpers shared by the custom calls that are registered with pybind11 below
// Associates a maximum degree with the calculator instance built for it.
template <template <typename> class C, typename T>
using CacheMapCUDA = std::map<size_t, std::unique_ptr<C<T>>>;

// Returns the cached `C<T>` calculator for `l_max`, constructing it with
// `C<T>(l_max)` the first time that value is requested.
//
// Each `C`/`T` instantiation owns its own function-local static cache and the
// mutex guarding it. The returned reference points into the map; references
// to `std::map` elements stay valid when further entries are inserted.
template <template <typename> class C, typename T>
std::unique_ptr<C<T>>& _get_or_create_sph_cuda(size_t l_max) {
    static CacheMapCUDA<C, T> instances;
    static std::mutex instances_mutex;

    // Lookup and insertion happen under the same lock so that two concurrent
    // callers cannot both construct a calculator for the same `l_max`.
    const std::lock_guard<std::mutex> guard(instances_mutex);

    auto entry = instances.find(l_max);
    if (entry != instances.end()) {
        return entry->second;
    }

    auto inserted = instances.emplace(l_max, std::make_unique<C<T>>(l_max));
    return inserted.first->second;
}

template <template <typename> class C, typename T>
inline void cuda_sph(void* stream, void** in, const char* opaque, std::size_t opaque_len) {
// Parse the inputs
const T* xyz = reinterpret_cast<const T*>(in[0]);
T* sph = reinterpret_cast<T*>(in[1]);

const SphDescriptor& d = *UnpackDescriptor<SphDescriptor>(opaque, opaque_len);
const std::int64_t n_samples = d.n_samples;
const std::int64_t lmax = d.lmax;

auto& calculator = _get_or_create_sph_cuda<C, T>(lmax);

calculator->compute(xyz, n_samples, sph, stream);
}

template <template <typename> class C, typename T>
inline void cuda_sph_with_gradients(
void* stream, void** in, const char* opaque, std::size_t opaque_len
) {
// Parse the inputs
const T* xyz = reinterpret_cast<const T*>(in[0]);
T* sph = reinterpret_cast<T*>(in[1]);
T* dsph = reinterpret_cast<T*>(in[2]);

const SphDescriptor& d = *UnpackDescriptor<SphDescriptor>(opaque, opaque_len);
const std::int64_t n_samples = d.n_samples;
const std::int64_t lmax = d.lmax;

auto& calculator = _get_or_create_sph_cuda<C, T>(lmax);
calculator->compute_with_gradients(xyz, n_samples, sph, dsph, stream);
}

template <template <typename> class C, typename T>
inline void cuda_sph_with_hessians(
void* stream, void** in, const char* opaque, std::size_t opaque_len
) {
// Parse the inputs
const T* xyz = reinterpret_cast<const T*>(in[0]);
T* sph = reinterpret_cast<T*>(in[1]);
T* dsph = reinterpret_cast<T*>(in[2]);
T* ddsph = reinterpret_cast<T*>(in[3]);

const SphDescriptor& d = *UnpackDescriptor<SphDescriptor>(opaque, opaque_len);
const std::int64_t n_samples = d.n_samples;
const std::int64_t lmax = d.lmax;

auto& calculator = _get_or_create_sph_cuda<C, T>(lmax);
calculator->compute_with_hessians(xyz, n_samples, sph, dsph, ddsph, stream);
}

// Registration of the custom calls with pybind11
// Builds the dictionary that maps each custom-call name to the object returned
// by `EncapsulateFunction` for the matching entry point. The keys cover every
// combination of harmonics kind (spherical/solid), derivative order (values,
// `d` = gradients, `dd` = hessians) and dtype suffix (f32/f64).
//
// NOTE(review): every key is assigned twice in this listing -- first from a
// free function named like the key, then from a `cuda_sph*` template
// instantiation. The first group looks like the pre-change registrations kept
// in by the diff rendering; confirm, and drop it if those functions no longer
// exist.
pybind11::dict Registrations() {
pybind11::dict dict;
// First group: entries bound to free functions named after their key.
dict["cuda_spherical_f32"] = EncapsulateFunction(cuda_spherical_f32);
dict["cuda_spherical_f64"] = EncapsulateFunction(cuda_spherical_f64);
dict["cuda_dspherical_f32"] = EncapsulateFunction(cuda_dspherical_f32);
dict["cuda_dspherical_f64"] = EncapsulateFunction(cuda_dspherical_f64);
dict["cuda_ddspherical_f32"] = EncapsulateFunction(cuda_ddspherical_f32);
dict["cuda_ddspherical_f64"] = EncapsulateFunction(cuda_ddspherical_f64);
dict["cuda_solid_f32"] = EncapsulateFunction(cuda_solid_f32);
dict["cuda_solid_f64"] = EncapsulateFunction(cuda_solid_f64);
dict["cuda_dsolid_f32"] = EncapsulateFunction(cuda_dsolid_f32);
dict["cuda_dsolid_f64"] = EncapsulateFunction(cuda_dsolid_f64);
dict["cuda_ddsolid_f32"] = EncapsulateFunction(cuda_ddsolid_f32);
dict["cuda_ddsolid_f64"] = EncapsulateFunction(cuda_ddsolid_f64);
// Second group: the same keys bound to instantiations of the `cuda_sph`,
// `cuda_sph_with_gradients` and `cuda_sph_with_hessians` templates for
// `sphericart::cuda::SphericalHarmonics` and `sphericart::cuda::SolidHarmonics`.
dict["cuda_spherical_f32"] =
EncapsulateFunction(cuda_sph<sphericart::cuda::SphericalHarmonics, float>);
dict["cuda_spherical_f64"] =
EncapsulateFunction(cuda_sph<sphericart::cuda::SphericalHarmonics, double>);
dict["cuda_dspherical_f32"] =
EncapsulateFunction(cuda_sph_with_gradients<sphericart::cuda::SphericalHarmonics, float>);
dict["cuda_dspherical_f64"] =
EncapsulateFunction(cuda_sph_with_gradients<sphericart::cuda::SphericalHarmonics, double>);
dict["cuda_ddspherical_f32"] =
EncapsulateFunction(cuda_sph_with_hessians<sphericart::cuda::SphericalHarmonics, float>);
dict["cuda_ddspherical_f64"] =
EncapsulateFunction(cuda_sph_with_hessians<sphericart::cuda::SphericalHarmonics, double>);
dict["cuda_solid_f32"] = EncapsulateFunction(cuda_sph<sphericart::cuda::SolidHarmonics, float>);
dict["cuda_solid_f64"] = EncapsulateFunction(cuda_sph<sphericart::cuda::SolidHarmonics, double>);
dict["cuda_dsolid_f32"] =
EncapsulateFunction(cuda_sph_with_gradients<sphericart::cuda::SolidHarmonics, float>);
dict["cuda_dsolid_f64"] =
EncapsulateFunction(cuda_sph_with_gradients<sphericart::cuda::SolidHarmonics, double>);
dict["cuda_ddsolid_f32"] =
EncapsulateFunction(cuda_sph_with_hessians<sphericart::cuda::SolidHarmonics, float>);
dict["cuda_ddsolid_f64"] =
EncapsulateFunction(cuda_sph_with_hessians<sphericart::cuda::SolidHarmonics, double>);
return dict;
}

Expand All @@ -44,4 +122,5 @@ PYBIND11_MODULE(sphericart_jax_cuda, m) {
});
}

} // namespace
} // namespace cuda
} // namespace sphericart_jax
148 changes: 0 additions & 148 deletions sphericart-jax/src/sphericart_jax_cuda.cu

This file was deleted.

Loading

0 comments on commit 877370a

Please sign in to comment.