Also generate jax-cuda-plugin and jax-cuda-pjrt in CUDA builds and bump the CUDA version used at build time to 12.6 #288
base: main
Conversation
Hi! This is the friendly automated conda-forge-linting service. I just wanted to let you know that I linted all conda-recipes in your PR.
@conda-forge-admin, please rerender
…nda-forge-pinning 2024.11.18.19.00.37
I started a build of a
I tried to build several hours ago but got the following error
Indeed the same for me:
The full log: log-jaxlib-cuda.txt. Probably some headers somehow try to use the internal cudnn.
Probably we need to add
The cudnn fix worked fine; the new error is:
The only occurrence of a similar problem is in conda-forge/bazel-feedstock#188 (comment), but there the affected user reports that the problem was solved without saying what the corresponding change was (see https://xkcd.com/979/, except in this case the user is myself :D ).
Actually, now that I think about it, I probably made a patch that was later rebased away to clean up the PR. The related patch is probably something like https://github.com/conda-forge/bazel-feedstock/blob/764ac0bb362224f0e8deb53b1a6a3f441b6ead7d/recipe/patches/0002-Build-with-native-dependencies.patch#L179-L189 .
The linker command seems to contain some absl libraries, but not all the ones required:
After a bit of a hack (passing all the missing linker flags as part of an unrelated absl target that I know is linked), the compilation ends successfully, but the produced jaxlib crashes at runtime:
The backtrace is:
Related to this part of the code: https://github.com/openxla/xla/blob/7fd2196f3f21f67bd1bbde9adfe819117454acb3/xla/pjrt/c/pjrt_c_api_gpu.cc#L25-L30 .
xref: abseil/abseil-cpp#1656 .
Indeed, that issue seems to describe exactly this problem. In a nutshell, apparently two parts of the code call
Probably this does not happen with the PyPI packages, as there the
Possible solutions: use static abseil (at least for
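As a quick way to check whether a build really ends up with two abseil copies loaded in the same process, something along these lines can help (just a diagnostic sketch, assuming Linux and that the extension modules suspected of bundling abseil have already been imported; the module name is illustrative):

```python
# Diagnostic sketch: list every abseil shared object mapped into the current process.
# Seeing libabsl_* both from the conda env and from inside a wheel/plugin directory
# would match the "two abseil copies" situation described above. Linux-only (/proc assumed).
import jaxlib  # illustrative: import whatever extensions are suspected of bundling abseil

libs = set()
with open("/proc/self/maps") as maps:
    for line in maps:
        parts = line.split()
        if len(parts) >= 6 and "absl" in parts[-1]:
            libs.add(parts[-1])

print("\n".join(sorted(libs)) or "no abseil shared objects mapped")
```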
recipe/build.sh
Outdated
```diff
@@ -78,7 +78,7 @@ build --verbose_failures
 build --toolchain_resolution_debug
 build --define=PREFIX=${PREFIX}
 build --define=PROTOBUF_INCLUDE_PATH=${PREFIX}/include
-build --local_cpu_resources=${CPU_COUNT}
+build --local_cpu_resources=120
```
Sorry, this was not supposed to be committed, my bad.
That's a workaround I would be happy with for now. I would expect that the packages will always be imported one after another,
…nda-forge-pinning 2024.11.22.09.17.35
Cleaned up a bit and implemented the suggestion. @traversaro Can you check whether this fixes your problem?
Hi! This is the friendly automated conda-forge-linting service. I just wanted to let you know that I linted all conda-recipes in your PR. I do have some suggestions for making it better though... For recipe/meta.yaml:
This message was generated by GitHub Actions workflow run https://github.com/conda-forge/conda-forge-webservices/actions/runs/11993479866. Examine the logs at this URL for more detail.
Thanks, I will check it now.
```diff
@@ -399,6 +350,10 @@ index 0000000..6ff4e1d
++ "-labsl_log_internal_check_op",
++ "-labsl_log_internal_message",
++ "-labsl_log_internal_nullguard",
```
Need -labsl_log_internal_globals to fix undefined reference to `absl::lts_20240722::log_internal::IsInitialized()'
I built it locally, and JAX does find my GPUs, as posted below. I uploaded the package to my channel if anyone wants to test it: https://anaconda.org/njzjz/jaxlib/files
I tried to run a simple example like
The reason is that somehow it can't find the
However, I am now trying a more complex example (based on https://github.com/ami-iit/jaxsim/blob/main/examples/jaxsim_as_physics_engine.ipynb), and it fails with this backtrace:
I was able to replicate the same segfault with an official jax example, https://github.com/jax-ml/jax/blob/jax-v0.4.34/examples/mnist_classifier.py .
Although I didn't get such errors, when I set
Interesting, how did you test jax? Via https://github.com/jax-ml/jax/blob/jax-v0.4.34/examples/mnist_classifier.py or something else?
For
But it's reasonable - the XLA code is:

```c++
#if CUDA_VERSION >= 12030
  VLOG(2) << "Beginning stream " << stream << " capture in "
          << StreamCaptureModeToString(mode) << " mode to graph " << graph;
  return cuda::ToStatus(
      cuStreamBeginCaptureToGraph(stream, graph,
                                  /*dependencies=*/nullptr,
                                  /*dependencyData=*/nullptr,
                                  /*numDependencies=*/0, cu_mode),
      "Failed to begin stream capture to graph");
#else
  return absl::UnimplementedError(
      "StreamBeginCaptureToGraph is not implemented");
#endif  // CUDA_VERSION >= 12030
```

We used CUDA 12.0 to build JAX.
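To double-check what the installed jax/jaxlib and the local driver stack report, a snippet like this helps (a sketch; it assumes jax.print_environment_info() is available in the jax version built here, which is the case for recent releases):

```python
# Sketch: print what jax reports about its own build and the visible devices.
import jax

print("jax", jax.__version__)
# Prints jax/jaxlib/python versions, the devices jax sees, and nvidia-smi output if present.
jax.print_environment_info()
```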
I don't get errors with
With
We may need to migrate to CUDA 12.6, see conda-forge/conda-forge-pinning-feedstock#6630
…nda-forge-pinning 2024.11.23.20.32.37
You're in luck, that PR was just merged a few hours ago. ;-)
Unfortunately, it looks like the cross-compilation on aarch runs into the same error that I observed here (also due to
The bad news is that I have no idea what's happening there, though the silver lining is that it should go away on restart, once the builds from conda-forge/c-ares-feedstock#43 are through the CDN.
To save resources, I've cancelled the aarch builds for now. I'll restart them once the new c-ares is available.
@conda-forge/jaxlib, this looks green, but in this case you might still want to run further tests?
I did a couple of tests. One was on an Ubuntu 22.04 cluster node, and everything worked out of the box, so this is a net improvement over the current state of the CUDA packages. However, I suspect this is because CUDA is also installed at the system level (and I can't uninstall it, as this is a shared cluster), and so for example

Another test was on an Ubuntu 24.04 WSL2 machine, on which CUDA was not installed at the system level. In that case, I still get exactly the failure that I reported in #288 (comment). Instead, everything started working there once CUDA was installed at the system level (almost: I had to manually copy

However, WSL2 + Nvidia GPU support is listed as experimental (see https://github.com/jax-ml/jax/blob/b372ce4b1ab0bee7a1da495b098ff3948a6c0d4d/README.md?plain=1#L388), so this is probably not a great test. Probably the ideal tests would be:
Anyhow, having jaxlib working with CUDA on a system where CUDA is installed at the system level is still a net improvement over the current status quo, where the CUDA packages are always broken, so personally I am not against merging while I continue the investigation. @njzjz, if you did further tests, feel free to report them, thanks!
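For reference, a minimal GPU smoke test along these lines (just a sketch, not the exact script used in the tests above) is usually enough to tell the "falls back to CPU" case apart from the crashes described earlier:

```python
# Minimal sketch of a GPU smoke test for the built jaxlib + CUDA plugin.
import jax
import jax.numpy as jnp

print(jax.devices())  # should list a CUDA device, not only the CPU device
x = jnp.ones((1024, 1024))
y = jnp.dot(x, x).block_until_ready()  # forces execution on the default backend
print(float(y[0, 0]))  # 1024.0 if the matmul ran correctly
```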
TensorFlow has the same issue with XLA; see conda-forge/tensorflow-feedstock#296 (comment). I believe @hmaarrfk has done some work with XLA but I am not sure whether the problem has been resolved.
Thanks a lot for the pointer @njzjz, this is really useful. From that thread it seems the idea is to set the absolute location of
Related code (the commits are random, not the ones actually used by jaxlib):
Ok, I replicated the segfault from #288 (comment) even on a physical Linux machine. Fortunately, creating a Docker instance that exposes the host CUDA without actually installing CUDA in the Docker image is as simple as:
Yes, that segfault. The related part of the code is https://github.com/openxla/xla/blob/626f1d2aadd2bb6d2217ffdcf6dba3933cffa183/xla/stream_executor/cuda/cuda_blas.cc#L188-L208 . I need to understand how to investigate this better, but my guess is that the following is happening: somehow cuBLAS is not found/not initialized (and this is the real problem), while if cuBLAS is installed at the system level it is correctly found/initialized. Then an error message would be printed, but using the log results in a segfault.
Ok, for now I have just inspected the code, but I think I understand what is going on (no, that was the wrong system). However, the CUDA XLA plugin calls cuBLAS via a trampoline, and the trampoline is quite picky about the version of cuBLAS installed, explicitly trying to load the exact version used at build time:
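To see which cuBLAS the dynamic loader can actually resolve in a given environment, a small check like the following can be useful (a sketch; the SONAMEs listed are illustrative and not necessarily the exact ones the XLA trampoline requests):

```python
# Sketch: try to dlopen cuBLAS by a few candidate SONAMEs and report which ones resolve.
# If only the generic name resolves but the versioned one the trampoline wants does not,
# that would be consistent with the "picky about the cuBLAS version" behaviour above.
import ctypes

for name in ("libcublas.so.12", "libcublas.so.11", "libcublas.so"):
    try:
        ctypes.CDLL(name)
    except OSError as exc:
        print(f"{name}: not loadable ({exc})")
    else:
        print(f"{name}: loaded")
```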
Ok, I noticed that also
Attempt to fix #285 and conda-forge/jax-feedstock#162.
Checklist
- Reset the build number to 0 (if the version changed)
- Re-rendered with the latest conda-smithy (use the phrase @conda-forge-admin, please rerender in a comment in this PR for automated rerendering)