Jax compilation hangs indefinitely #26767

Open
n-gao opened this issue Feb 26, 2025 · 0 comments
Labels
bug Something isn't working

n-gao commented Feb 26, 2025

Description

JAX compilation freezes in versions 0.5.0 and 0.5.1 under seemingly unpredictable conditions. I was unable to produce a minimal failing example, but the following instructions should suffice to reproduce the issue (tested on 3 different machines with different CUDA versions and drivers).
Setup:

git clone 
cd neural-pfaffian

The following command hangs indefinitely:

uv add "jax==0.5.1"
uv run --with "jax==0.5.1" neural_pfaffian with 'vmc={'"'"'mcmc'"'"': {'"'"'blocks'"'"': 3, '"'"'steps'"'"': 30}}' 'systems={'"'"'molecules'"'"': [['"'"'deeperwin_molecule'"'"', {'"'"'name'"'"': '"'"'cumulene_C8H4_0deg_singlet'"'"'}]], '"'"'num_walker_per_mol'"'"': 4096}' config/models/psiformer.yaml

Using an older version works, though:

uv add "jax==0.4.38"
uv run --with "jax==0.4.38" neural_pfaffian with 'vmc={'"'"'mcmc'"'"': {'"'"'blocks'"'"': 3, '"'"'steps'"'"': 30}}' 'systems={'"'"'molecules'"'"': [['"'"'deeperwin_molecule'"'"', {'"'"'name'"'"': '"'"'cumulene_C8H4_0deg_singlet'"'"'}]], '"'"'num_walker_per_mol'"'"': 4096}' config/models/psiformer.yaml

Weirdly enough, choosing a different system also works (in both JAX versions). The two systems differ only in some tensor shapes.

uv add "jax==0.5.1"
uv run neural_pfaffian with config/models/psiformer.yaml '''systems.molecules=[["deeperwin_molecule", {"name": "cumulene_C4H4_0deg_singlet"}]]''' systems.num_walker_per_mol=4096

I attached zips with the XLA dumps for all three experiments.
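
For reference, XLA dumps of this kind can be collected by setting the `XLA_FLAGS` environment variable to `--xla_dump_to=<dir>` before the process starts. The following is a minimal sketch under stated assumptions, not the exact procedure used for the attached zips: the helper name `run_with_xla_dump`, the dump directory, and the timeout are hypothetical, and the timeout is only a heuristic for classifying a run as hanging.

```python
import os
import subprocess
import sys


def run_with_xla_dump(cmd, dump_dir, timeout_s):
    """Run `cmd` with XLA dumping enabled and report whether it finished.

    Returns "finished" if the process exits within `timeout_s` seconds and
    "timed out" otherwise. A timeout does not prove a hang; it only flags
    runs that are worth inspecting.
    """
    env = dict(os.environ)
    # Append to any flags that are already set instead of overwriting them.
    existing = env.get("XLA_FLAGS", "")
    env["XLA_FLAGS"] = f"{existing} --xla_dump_to={dump_dir}".strip()
    try:
        subprocess.run(cmd, env=env, timeout=timeout_s, check=False)
    except subprocess.TimeoutExpired:
        # subprocess.run kills the direct child before re-raising.
        return "timed out"
    return "finished"


if __name__ == "__main__":
    # Stand-in command (hypothetical); replace it with the real invocation.
    status = run_with_xla_dump([sys.executable, "-c", "pass"], "xla_dump", 60)
    print(status)
```

Only the direct child is killed on timeout, so a launcher that spawns further processes may leave them running; those would need to be cleaned up separately.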

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.1
jaxlib: 0.5.1
numpy:  2.2.3
python: 3.12.5 (main, Aug 14 2024, 05:08:31) [Clang 18.1.8 ]
device info: NVIDIA A100-SXM4-40GB-4, 4 local devices
process_count: 1
platform: uname_result(system='Linux', node='mcml-dgx-008', release='5.15.0-1071-nvidia', version='#72-Ubuntu SMP Thu Jan 16 00:47:54 UTC 2025', machine='x86_64')


$ nvidia-smi
Wed Feb 26 15:26:14 2025
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.230.02             Driver Version: 535.230.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-40GB          On  | 00000000:07:00.0 Off |                    0 |
| N/A   28C    P0              58W / 400W |    429MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          On  | 00000000:47:00.0 Off |                    0 |
| N/A   29C    P0              58W / 400W |    425MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM4-40GB          On  | 00000000:87:00.0 Off |                    0 |
| N/A   32C    P0              62W / 400W |    425MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM4-40GB          On  | 00000000:B7:00.0 Off |                    0 |
| N/A   31C    P0              62W / 400W |    425MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   3882149      C   ...2/neural-pfaffian/.venv/bin/python3      420MiB |
|    1   N/A  N/A   3882149      C   ...2/neural-pfaffian/.venv/bin/python3      416MiB |
|    2   N/A  N/A   3882149      C   ...2/neural-pfaffian/.venv/bin/python3      416MiB |
|    3   N/A  N/A   3882149      C   ...2/neural-pfaffian/.venv/bin/python3      416MiB |
+---------------------------------------------------------------------------------------+
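
The system information above has the shape of the output printed by `jax.print_environment_info()`. A minimal sketch for collecting a comparable report on another machine follows; the wrapper name `environment_report` and the standard-library fallback for environments without JAX are illustrative assumptions, not part of JAX.

```python
import importlib.util
import platform
import sys


def environment_report():
    """Return a short environment summary as a string.

    Uses jax.print_environment_info when JAX is importable; otherwise falls
    back to interpreter and platform details from the standard library.
    """
    if importlib.util.find_spec("jax") is not None:
        import jax

        # return_string=True returns the report instead of printing it.
        return jax.print_environment_info(return_string=True)
    # Fallback (assumption): a reduced report without any JAX details.
    python_version = sys.version.replace("\n", " ")
    lines = [
        "jax:    not installed",
        f"python: {python_version}",
        f"platform: {platform.uname()}",
    ]
    return "\n".join(lines)


if __name__ == "__main__":
    print(environment_report())
```

Running this once per JAX version (0.4.38 and 0.5.1) makes it easier to confirm that the JAX and jaxlib versions are the only difference between a working and a hanging setup.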

xla_dump1.zip
xla_dump2.zip
xla_dump3.zip

n-gao added the bug label Feb 26, 2025