Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move compilation of CUDA code to NVRTC #131

Merged
merged 153 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from 152 commits
Commits
Show all changes
153 commits
Select commit Hold shift + click to select a range
9a0a53c
NVRTC support.
nickjbrowning Jul 30, 2024
04c9a0c
added stream back in.
nickjbrowning Jul 30, 2024
9ce24b6
caching was creating issues...
nickjbrowning Jul 30, 2024
8dc3734
fixed stub.
nickjbrowning Jul 30, 2024
5c22c89
fixed cuda stubs.
nickjbrowning Jul 30, 2024
8d2742a
linting.
nickjbrowning Jul 30, 2024
9a21479
update with jittify
nickjbrowning Aug 29, 2024
3e8afce
added some simple temporary tests
nickjbrowning Aug 29, 2024
4679794
update
nickjbrowning Sep 5, 2024
2d22541
x
nickjbrowning Sep 5, 2024
11ac20d
backwards is broken...
nickjbrowning Sep 5, 2024
d91dfb0
fixed backwards.
nickjbrowning Sep 5, 2024
5e4237b
fixed backwards + smem update.
nickjbrowning Sep 5, 2024
5fb3c89
removfed smem caching in sphericart::cuda class
nickjbrowning Sep 5, 2024
a8aeb25
basic dynamic loading working, issues with some tests...
nickjbrowning Sep 5, 2024
1b441dc
added nvrtcGetErrorString
nickjbrowning Sep 5, 2024
022efb2
demangling ifdef.
nickjbrowning Sep 5, 2024
9b560b1
removed some comments.
nickjbrowning Sep 5, 2024
2aec4c1
fixed cuContext issue.
nickjbrowning Sep 5, 2024
53f1525
minor updates to add macros.
nickjbrowning Sep 5, 2024
8d65e5d
fixed macro choice
nickjbrowning Sep 6, 2024
45d98d5
Merge branch 'main' into nvrtc
frostedoyster Sep 6, 2024
11d6977
fixed linting
nickjbrowning Sep 6, 2024
4af48ab
Fix const conflict
frostedoyster Sep 6, 2024
d1ce794
Merge branch 'nvrtc' of https://github.com/lab-cosmo/sphericart into …
frostedoyster Sep 6, 2024
ea91723
Fix CUDA stubs
frostedoyster Sep 6, 2024
80043f2
removed faulty macros
nickjbrowning Sep 6, 2024
82f8f2d
formatting
nickjbrowning Sep 6, 2024
70e5f9c
stubs...
nickjbrowning Sep 6, 2024
e51085e
jax update.
nickjbrowning Sep 6, 2024
839fe19
updated the include and cuda_src paths
nickjbrowning Sep 6, 2024
82bc168
static linking
nickjbrowning Sep 6, 2024
2e66eb9
removed static linking for now.
nickjbrowning Sep 6, 2024
393b821
Print error code when failing to load the kernels
ceriottm Sep 6, 2024
66fb009
Merge branch 'nvrtc' of github.com:lab-cosmo/sphericart into nvrtc
ceriottm Sep 6, 2024
b0c8c74
added explict architecture spec.
nickjbrowning Sep 6, 2024
3f0dd1d
Merge branch 'nvrtc' of https://github.com/lab-cosmo/sphericart into …
nickjbrowning Sep 6, 2024
4b7f4fd
added sync default back.
nickjbrowning Sep 6, 2024
b167f44
added in dyn_cuda for architecture compute
nickjbrowning Sep 6, 2024
83cd27a
rename files to cpp, added cudart dl links
nickjbrowning Sep 9, 2024
3e9a6aa
comments.
nickjbrowning Sep 9, 2024
79d5793
comments
nickjbrowning Sep 9, 2024
86db8ee
removed stubs, conditioned out cuda in cpp
nickjbrowning Sep 9, 2024
44cc745
re-adding the CPU-only stubs to get around CI
nickjbrowning Sep 9, 2024
407b81f
comments and removed cuda function from macro.
nickjbrowning Sep 9, 2024
ac5bddb
comment.
nickjbrowning Sep 9, 2024
e4977fc
removed unused header
nickjbrowning Sep 9, 2024
2438fae
removed some unecessary files
nickjbrowning Sep 9, 2024
0c22605
updated timeout to 1H for CI
nickjbrowning Sep 9, 2024
5e7d7fe
comments and timeouts
nickjbrowning Sep 9, 2024
16dabf8
missing a handle close.
nickjbrowning Sep 9, 2024
8d8b31d
updated to distribute cuda src and search for it.
nickjbrowning Sep 10, 2024
657f520
x
nickjbrowning Sep 10, 2024
a0836ad
formatting.
nickjbrowning Sep 10, 2024
9406198
ad package_data to whl deployment.
nickjbrowning Sep 10, 2024
bc5a7a4
package_data not required, handled by install.
nickjbrowning Sep 10, 2024
3db55a2
use rfind instead
nickjbrowning Sep 10, 2024
66d70d5
Install sphericart for torch-tests in tox
Luthaf Sep 11, 2024
53a8c99
removed linking against cuda libs, header-only now.
nickjbrowning Sep 12, 2024
ac7b4c4
Merge branch 'nvrtc' of https://github.com/lab-cosmo/sphericart into …
nickjbrowning Sep 12, 2024
79c2531
some preliminary stuff for string literal-based work.
nickjbrowning Sep 13, 2024
d485ff7
cmake update
nickjbrowning Sep 13, 2024
47a465c
some updates
nickjbrowning Sep 13, 2024
8233a70
still non-functional.
nickjbrowning Sep 13, 2024
1f0f74d
formatting
nickjbrowning Sep 13, 2024
7eba1d6
formatting
nickjbrowning Sep 13, 2024
27cc680
comment and cleanup
nickjbrowning Sep 13, 2024
8d1f361
fixed issue building sphericart.torch
nickjbrowning Sep 13, 2024
988a3c1
updates
nickjbrowning Sep 13, 2024
4e06d2b
added a very ugly hack to get around autograd context issues.
nickjbrowning Sep 13, 2024
a2c6f7a
update to fix autograd issues.
nickjbrowning Sep 18, 2024
d742b3c
forgot to re-add smem check.
nickjbrowning Sep 18, 2024
acd30bb
stub issues
nickjbrowning Sep 19, 2024
c550bee
linting
nickjbrowning Sep 19, 2024
2cbbb56
update bench
nickjbrowning Sep 19, 2024
acd15ee
minor changes.
nickjbrowning Sep 19, 2024
3e85aeb
separated DynamicCUDA into 3 classes representing RT, driver, nvrtc
nickjbrowning Sep 19, 2024
48f27d3
changed to compile on first launch
nickjbrowning Sep 19, 2024
0cf1809
Merge branch 'main' into nvrtc
nickjbrowning Sep 19, 2024
93bd107
changed data packing to vector
nickjbrowning Sep 19, 2024
9acd1cf
should not be compatible with cuda 11
nickjbrowning Sep 19, 2024
ffca329
added a possible parent class for dynamic loading, unused atm.
nickjbrowning Sep 19, 2024
1642440
removed message
nickjbrowning Sep 19, 2024
bab09c8
removed experimental dynamic loader class
nickjbrowning Sep 30, 2024
e828ea9
remove dunused wrapper
nickjbrowning Sep 30, 2024
92cae0b
removed comments
nickjbrowning Sep 30, 2024
0345556
arch native default to OFF
nickjbrowning Sep 30, 2024
f5a280d
Update sphericart/CMakeLists.txt
nickjbrowning Oct 2, 2024
90ff9f1
Update sphericart/CMakeLists.txt
nickjbrowning Oct 2, 2024
ea80574
Update sphericart/CMakeLists.txt
nickjbrowning Oct 2, 2024
b9b34d7
Update sphericart/CMakeLists.txt
nickjbrowning Oct 2, 2024
972eba6
issue in sphericart_impl, removed intermediary header
nickjbrowning Oct 2, 2024
9973a6d
Merge branch 'nvrtc' of https://github.com/lab-cosmo/sphericart into …
nickjbrowning Oct 2, 2024
6d82aac
removed --define-macro from nvrtc call.
nickjbrowning Oct 2, 2024
2b25654
arch-native back to ON by default.
nickjbrowning Oct 2, 2024
97883a1
Update sphericart/CMakeLists.txt
nickjbrowning Oct 8, 2024
b979f57
x
nickjbrowning Oct 8, 2024
01b6898
Merge branch 'nvrtc' of https://github.com/lab-cosmo/sphericart into …
nickjbrowning Oct 8, 2024
b827ab0
Update sphericart/include/cuda_cache.hpp
nickjbrowning Oct 8, 2024
2285c91
Update sphericart/include/cuda_cache.hpp
nickjbrowning Oct 8, 2024
cac407e
Update sphericart/include/cuda_cache.hpp
nickjbrowning Oct 8, 2024
61c8ca6
Update sphericart/include/cuda_cache.hpp
nickjbrowning Oct 8, 2024
36fe3c5
Update sphericart/include/cuda_cache.hpp
nickjbrowning Oct 8, 2024
109d14b
Update sphericart/include/cuda_cache.hpp
nickjbrowning Oct 8, 2024
393f9a8
Update sphericart/include/cuda_cache.hpp
nickjbrowning Oct 8, 2024
07a74c3
Update sphericart/include/cuda_cache.hpp
nickjbrowning Oct 8, 2024
a82a490
Update sphericart/include/dynamic_cuda.hpp
nickjbrowning Oct 8, 2024
f9fdc07
Update sphericart/include/dynamic_cuda.hpp
nickjbrowning Oct 8, 2024
bf8e648
Update sphericart/include/dynamic_cuda.hpp
nickjbrowning Oct 8, 2024
518464c
Update sphericart/include/dynamic_cuda.hpp
nickjbrowning Oct 8, 2024
bbfddeb
Merge branch 'nvrtc' of https://github.com/lab-cosmo/sphericart into …
nickjbrowning Oct 8, 2024
a7eae9a
update
nickjbrowning Oct 9, 2024
d4c5be4
removed ifdef
nickjbrowning Oct 9, 2024
58590f4
stubs matter
nickjbrowning Oct 9, 2024
25560c8
I forgot another stub
nickjbrowning Oct 9, 2024
b4c4b30
I forgot something else in the stub
nickjbrowning Oct 9, 2024
22fdc4c
reverted back to demangling.
nickjbrowning Oct 9, 2024
f0de86c
removed comment
nickjbrowning Oct 10, 2024
4d58d39
fixed error
nickjbrowning Oct 10, 2024
b2c8e4c
header shuffling
nickjbrowning Oct 10, 2024
264b789
update to fix CI?
nickjbrowning Oct 10, 2024
a7f3430
update git ignore.
nickjbrowning Oct 10, 2024
20971bd
added pyc back to ignore, updated tox and cuda_base formatting.
nickjbrowning Oct 10, 2024
1f43725
String quoting
ceriottm Oct 10, 2024
08d4452
Merge branch 'nvrtc' of github.com:lab-cosmo/sphericart into nvrtc
ceriottm Oct 10, 2024
e797787
Make sure the jax examples are run on GPU (if available)
frostedoyster Oct 11, 2024
3f5917d
x
nickjbrowning Oct 11, 2024
ce9a829
Merge branch 'main' into nvrtc
nickjbrowning Oct 11, 2024
ab71510
formatting.
nickjbrowning Oct 11, 2024
800d295
removed block comment
nickjbrowning Oct 11, 2024
8cf7810
removed warning code.
nickjbrowning Oct 11, 2024
5a578d9
formatting.
nickjbrowning Oct 11, 2024
f3b3c1f
tox update test
nickjbrowning Oct 11, 2024
4794894
added in nifty counter to get correct constructor/destructor ordering.
nickjbrowning Oct 11, 2024
e823f7f
added dlclose calls back in.
nickjbrowning Oct 11, 2024
7d3fea6
slight mod, same result.
nickjbrowning Oct 11, 2024
eb309f5
comment.
nickjbrowning Oct 11, 2024
15850df
formatting.
nickjbrowning Oct 11, 2024
1ee85fe
removed comment.
nickjbrowning Oct 11, 2024
b433b2e
removed sphericart_jax_cuda.cu, added consts
nickjbrowning Oct 11, 2024
c2be239
forgot about the stubs again
nickjbrowning Oct 11, 2024
46718ef
Clean up sphericart-jax-cuda even more
frostedoyster Oct 12, 2024
028b274
Update sphericart/CMakeLists.txt
nickjbrowning Oct 15, 2024
2bac526
Update sphericart/include/cuda_cache.hpp
nickjbrowning Oct 15, 2024
d42a9be
Update sphericart/include/dynamic_cuda.hpp
nickjbrowning Oct 15, 2024
9f62d29
updates.
nickjbrowning Oct 15, 2024
4afc4cd
Merge branch 'nvrtc' of https://github.com/lab-cosmo/sphericart into …
nickjbrowning Oct 15, 2024
c896ff0
added comments.
nickjbrowning Oct 15, 2024
fc90df4
more changes to fix review comments.
nickjbrowning Oct 15, 2024
09709c7
updated comment on remove_pointer
nickjbrowning Oct 15, 2024
0ad7f92
comment update
nickjbrowning Oct 15, 2024
64b1c97
typo.
nickjbrowning Oct 15, 2024
fbcfd92
updated runtime error.
nickjbrowning Oct 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ test_job:
SLURM_JOB_NUM_NODES: 1
SLURM_PARTITION: normal
SLURM_NTASKS: 1
SLURM_TIMELIMIT: '00:40:00'
SLURM_TIMELIMIT: '02:30:00'
GIT_STRATEGY: fetch
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def run(self):
"sphericart": [
"sphericart/lib/*",
"sphericart/include/*",
]
],
},
extras_require=extras_require,
)
2 changes: 1 addition & 1 deletion sphericart-jax/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ if(CMAKE_CUDA_COMPILER AND SPHERICART_ENABLE_CUDA)

enable_language(CUDA)
include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
set(CUDA_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/sphericart_jax_cuda.cu ${CMAKE_CURRENT_LIST_DIR}/src/sphericart_jax_cuda.cpp)
set(CUDA_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/sphericart_jax_cuda.cpp)
pybind11_add_module(sphericart_jax_cuda ${CUDA_SOURCES})
set_target_properties(sphericart_jax_cuda PROPERTIES CUDA_ARCHITECTURES native)
install(TARGETS sphericart_jax_cuda DESTINATION sphericart_jax_cuda)
Expand Down
47 changes: 0 additions & 47 deletions sphericart-jax/include/sphericart/sphericart_jax_cuda.hpp

This file was deleted.

2 changes: 1 addition & 1 deletion sphericart-jax/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,6 @@ def run(self):
"sphericart-jax": [
"sphericart/jax/lib/*",
"sphericart/jax/include/*",
]
],
},
)
1 change: 0 additions & 1 deletion sphericart-jax/src/sphericart_jax_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ void cpu_sph_with_hessians(void* out_tuple, const void** in) {
}

// Registration of the custom calls with pybind11

pybind11::dict Registrations() {
pybind11::dict dict;
dict["cpu_spherical_f32"] = EncapsulateFunction(cpu_sph<sphericart::SphericalHarmonics, float>);
Expand Down
119 changes: 99 additions & 20 deletions sphericart-jax/src/sphericart_jax_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,116 @@
// devices. It is exposed as a standard pybind11 module defining "capsule"
// objects containing our methods. For simplicity, we export a separate capsule
// for each supported dtype.
// This file is separated from `sphericart_jax_cuda.cu` because pybind11 does
// not accept cuda files.

#include <cstdlib>
#include <map>
#include <mutex>
#include <tuple>

#include "sphericart_cuda.hpp"
#include "sphericart/pybind11_kernel_helpers.hpp"
#include "sphericart/sphericart_jax_cuda.hpp"

using namespace sphericart_jax;
using namespace sphericart_jax::cuda;
struct SphDescriptor {
std::int64_t n_samples;
std::int64_t lmax;
};

namespace {
namespace sphericart_jax {
namespace cuda {

// Registration of the custom calls with pybind11
template <template <typename> class C, typename T>
using CacheMapCUDA = std::map<size_t, std::unique_ptr<C<T>>>;

template <template <typename> class C, typename T>
std::unique_ptr<C<T>>& _get_or_create_sph_cuda(size_t l_max) {
// Static map to cache instances based on parameters
static CacheMapCUDA<C, T> sph_cache;
static std::mutex cache_mutex;

// Check if instance exists in cache, if not create and store it
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@frostedoyster should we purge the instances after some point? Or do we let this map grow unbounded? (This would be for a separate PR!)

std::lock_guard<std::mutex> lock(cache_mutex);
auto it = sph_cache.find(l_max);
if (it == sph_cache.end()) {
it = sph_cache.insert({l_max, std::make_unique<C<T>>(l_max)}).first;
}
return it->second;
}

template <template <typename> class C, typename T>
inline void cuda_sph(void* stream, void** in, const char* opaque, std::size_t opaque_len) {
// Parse the inputs
const T* xyz = reinterpret_cast<const T*>(in[0]);
T* sph = reinterpret_cast<T*>(in[1]);

const SphDescriptor& d = *UnpackDescriptor<SphDescriptor>(opaque, opaque_len);
const std::int64_t n_samples = d.n_samples;
const std::int64_t lmax = d.lmax;

auto& calculator = _get_or_create_sph_cuda<C, T>(lmax);

calculator->compute(xyz, n_samples, sph, stream);
}

template <template <typename> class C, typename T>
inline void cuda_sph_with_gradients(
void* stream, void** in, const char* opaque, std::size_t opaque_len
) {
// Parse the inputs
const T* xyz = reinterpret_cast<const T*>(in[0]);
T* sph = reinterpret_cast<T*>(in[1]);
T* dsph = reinterpret_cast<T*>(in[2]);

const SphDescriptor& d = *UnpackDescriptor<SphDescriptor>(opaque, opaque_len);
const std::int64_t n_samples = d.n_samples;
const std::int64_t lmax = d.lmax;

auto& calculator = _get_or_create_sph_cuda<C, T>(lmax);
calculator->compute_with_gradients(xyz, n_samples, sph, dsph, stream);
}

template <template <typename> class C, typename T>
inline void cuda_sph_with_hessians(
void* stream, void** in, const char* opaque, std::size_t opaque_len
) {
// Parse the inputs
const T* xyz = reinterpret_cast<const T*>(in[0]);
T* sph = reinterpret_cast<T*>(in[1]);
T* dsph = reinterpret_cast<T*>(in[2]);
T* ddsph = reinterpret_cast<T*>(in[3]);

const SphDescriptor& d = *UnpackDescriptor<SphDescriptor>(opaque, opaque_len);
const std::int64_t n_samples = d.n_samples;
const std::int64_t lmax = d.lmax;

auto& calculator = _get_or_create_sph_cuda<C, T>(lmax);
calculator->compute_with_hessians(xyz, n_samples, sph, dsph, ddsph, stream);
}

// Registration of the custom calls with pybind11
pybind11::dict Registrations() {
pybind11::dict dict;
dict["cuda_spherical_f32"] = EncapsulateFunction(cuda_spherical_f32);
dict["cuda_spherical_f64"] = EncapsulateFunction(cuda_spherical_f64);
dict["cuda_dspherical_f32"] = EncapsulateFunction(cuda_dspherical_f32);
dict["cuda_dspherical_f64"] = EncapsulateFunction(cuda_dspherical_f64);
dict["cuda_ddspherical_f32"] = EncapsulateFunction(cuda_ddspherical_f32);
dict["cuda_ddspherical_f64"] = EncapsulateFunction(cuda_ddspherical_f64);
dict["cuda_solid_f32"] = EncapsulateFunction(cuda_solid_f32);
dict["cuda_solid_f64"] = EncapsulateFunction(cuda_solid_f64);
dict["cuda_dsolid_f32"] = EncapsulateFunction(cuda_dsolid_f32);
dict["cuda_dsolid_f64"] = EncapsulateFunction(cuda_dsolid_f64);
dict["cuda_ddsolid_f32"] = EncapsulateFunction(cuda_ddsolid_f32);
dict["cuda_ddsolid_f64"] = EncapsulateFunction(cuda_ddsolid_f64);
dict["cuda_spherical_f32"] =
EncapsulateFunction(cuda_sph<sphericart::cuda::SphericalHarmonics, float>);
dict["cuda_spherical_f64"] =
EncapsulateFunction(cuda_sph<sphericart::cuda::SphericalHarmonics, double>);
dict["cuda_dspherical_f32"] =
EncapsulateFunction(cuda_sph_with_gradients<sphericart::cuda::SphericalHarmonics, float>);
dict["cuda_dspherical_f64"] =
EncapsulateFunction(cuda_sph_with_gradients<sphericart::cuda::SphericalHarmonics, double>);
dict["cuda_ddspherical_f32"] =
EncapsulateFunction(cuda_sph_with_hessians<sphericart::cuda::SphericalHarmonics, float>);
dict["cuda_ddspherical_f64"] =
EncapsulateFunction(cuda_sph_with_hessians<sphericart::cuda::SphericalHarmonics, double>);
dict["cuda_solid_f32"] = EncapsulateFunction(cuda_sph<sphericart::cuda::SolidHarmonics, float>);
dict["cuda_solid_f64"] = EncapsulateFunction(cuda_sph<sphericart::cuda::SolidHarmonics, double>);
dict["cuda_dsolid_f32"] =
EncapsulateFunction(cuda_sph_with_gradients<sphericart::cuda::SolidHarmonics, float>);
dict["cuda_dsolid_f64"] =
EncapsulateFunction(cuda_sph_with_gradients<sphericart::cuda::SolidHarmonics, double>);
dict["cuda_ddsolid_f32"] =
EncapsulateFunction(cuda_sph_with_hessians<sphericart::cuda::SolidHarmonics, float>);
dict["cuda_ddsolid_f64"] =
EncapsulateFunction(cuda_sph_with_hessians<sphericart::cuda::SolidHarmonics, double>);
return dict;
}

Expand All @@ -44,4 +122,5 @@ PYBIND11_MODULE(sphericart_jax_cuda, m) {
});
}

} // namespace
} // namespace cuda
} // namespace sphericart_jax
148 changes: 0 additions & 148 deletions sphericart-jax/src/sphericart_jax_cuda.cu

This file was deleted.

Loading
Loading