diff --git a/ci/pipeline.yml b/ci/pipeline.yml index 285c34da..2d0a300b 100644 --- a/ci/pipeline.yml +++ b/ci/pipeline.yml @@ -35,5 +35,5 @@ test_job: SLURM_JOB_NUM_NODES: 1 SLURM_PARTITION: normal SLURM_NTASKS: 1 - SLURM_TIMELIMIT: '00:40:00' + SLURM_TIMELIMIT: '02:30:00' GIT_STRATEGY: fetch diff --git a/setup.py b/setup.py index 4af4b874..c9ccba30 100644 --- a/setup.py +++ b/setup.py @@ -139,7 +139,7 @@ def run(self): "sphericart": [ "sphericart/lib/*", "sphericart/include/*", - ] + ], }, extras_require=extras_require, ) diff --git a/sphericart-jax/CMakeLists.txt b/sphericart-jax/CMakeLists.txt index 8842acfd..9a00a818 100644 --- a/sphericart-jax/CMakeLists.txt +++ b/sphericart-jax/CMakeLists.txt @@ -60,7 +60,7 @@ if(CMAKE_CUDA_COMPILER AND SPHERICART_ENABLE_CUDA) enable_language(CUDA) include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) - set(CUDA_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/sphericart_jax_cuda.cu ${CMAKE_CURRENT_LIST_DIR}/src/sphericart_jax_cuda.cpp) + set(CUDA_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/sphericart_jax_cuda.cpp) pybind11_add_module(sphericart_jax_cuda ${CUDA_SOURCES}) set_target_properties(sphericart_jax_cuda PROPERTIES CUDA_ARCHITECTURES native) install(TARGETS sphericart_jax_cuda DESTINATION sphericart_jax_cuda) diff --git a/sphericart-jax/include/sphericart/sphericart_jax_cuda.hpp b/sphericart-jax/include/sphericart/sphericart_jax_cuda.hpp deleted file mode 100644 index 7bb0b109..00000000 --- a/sphericart-jax/include/sphericart/sphericart_jax_cuda.hpp +++ /dev/null @@ -1,47 +0,0 @@ -// This file is needed as a workaround for pybind11 not accepting cuda files. -// Note that all the templated functions are split into separate functions so -// that they can be compiled in the `.cu` file. - -#ifndef _SPHERICART_JAX_CUDA_HPP_ -#define _SPHERICART_JAX_CUDA_HPP_ - -#include -#include - -struct SphDescriptor { - std::int64_t n_samples; - std::int64_t lmax; -}; - -namespace sphericart_jax { - -namespace cuda { - -void cuda_spherical_f32(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len); - -void cuda_spherical_f64(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len); - -void cuda_dspherical_f32(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len); - -void cuda_dspherical_f64(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len); - -void cuda_ddspherical_f32(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len); - -void cuda_ddspherical_f64(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len); - -void cuda_solid_f32(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len); - -void cuda_solid_f64(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len); - -void cuda_dsolid_f32(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len); - -void cuda_dsolid_f64(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len); - -void cuda_ddsolid_f32(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len); - -void cuda_ddsolid_f64(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len); - -} // namespace cuda -} // namespace sphericart_jax - -#endif diff --git a/sphericart-jax/setup.py b/sphericart-jax/setup.py index 9100b520..11fa1bb6 100644 --- a/sphericart-jax/setup.py +++ b/sphericart-jax/setup.py @@ -93,6 +93,6 @@ def run(self): "sphericart-jax": [ "sphericart/jax/lib/*", "sphericart/jax/include/*", - ] + ], }, ) diff --git a/sphericart-jax/src/sphericart_jax_cpu.cpp b/sphericart-jax/src/sphericart_jax_cpu.cpp index 5571ff3e..b5335805 100644 --- a/sphericart-jax/src/sphericart_jax_cpu.cpp +++ b/sphericart-jax/src/sphericart_jax_cpu.cpp @@ -90,7 +90,6 @@ void cpu_sph_with_hessians(void* out_tuple, const void** in) { } // Registration of the custom calls with pybind11 - pybind11::dict Registrations() { pybind11::dict dict; dict["cpu_spherical_f32"] = EncapsulateFunction(cpu_sph); diff --git a/sphericart-jax/src/sphericart_jax_cuda.cpp b/sphericart-jax/src/sphericart_jax_cuda.cpp index 61824f53..23ba7f76 100644 --- a/sphericart-jax/src/sphericart_jax_cuda.cpp +++ b/sphericart-jax/src/sphericart_jax_cuda.cpp @@ -2,38 +2,116 @@ // devices. It is exposed as a standard pybind11 module defining "capsule" // objects containing our methods. For simplicity, we export a separate capsule // for each supported dtype. -// This file is separated from `sphericart_jax_cuda.cu` because pybind11 does -// not accept cuda files. #include #include #include #include +#include "sphericart_cuda.hpp" #include "sphericart/pybind11_kernel_helpers.hpp" -#include "sphericart/sphericart_jax_cuda.hpp" -using namespace sphericart_jax; -using namespace sphericart_jax::cuda; +struct SphDescriptor { + std::int64_t n_samples; + std::int64_t lmax; +}; -namespace { +namespace sphericart_jax { +namespace cuda { -// Registration of the custom calls with pybind11 +template