gencast_mini_demo.ipynb on AMD CPU #113

Open
dkokron opened this issue Dec 21, 2024 · 7 comments

@dkokron

dkokron commented Dec 21, 2024

I'm attempting to run the gencast_mini_demo.ipynb case on my home workstation without using a GPU. The notebook recognizes that I don't have the correct software to run on the installed GPU and falls back to CPU (which is what I want to happen).

Output from cell 22:
WARNING:2024-12-21 14:22:21,184:jax._src.xla_bridge:969: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
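A quick way to confirm which backend JAX actually selected, or to request CPU explicitly instead of relying on the fallback, is sketched below. It uses JAX's JAX_PLATFORMS environment variable, which only takes effect if it is set before jax is first imported in the process (in a notebook, restart the kernel first). This is a generic JAX sketch and has not been tested against the GenCast notebook itself.

```python
import os

# Restrict JAX to the CPU backend. This must run before jax is imported for the
# first time in the process; changing it after the import has no effect.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax  # noqa: E402  (import deliberately placed after the env var)

print("default backend:", jax.default_backend())  # cpu
print("devices:", jax.devices())
```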

I've attached the stack trace I get from cell 23 ("Autoregressive rollout (loop in python)"):
gencast.failure.txt

Is this expected? Does GenCast require a GPU or TPU to work?

@andrewlkd
Collaborator

Hey,

This looks like a splash-attention-related error. Splash attention is only supported on TPU.

You can try following the GPU instructions to change the attention mechanism; I believe this should work fine on CPU.

Note that without knowing the memory specifications of your device, I can't guarantee it won't run out of memory. We've also never run GenCast on CPU, so we can't make any guarantees about its correctness.

Hope that helps!

Andrew
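The attention swap suggested above could be wrapped in a small helper so the same notebook cell works on TPU and elsewhere. This is a sketch, not part of GenCast: the function name use_non_tpu_attention is hypothetical, while the attribute path and the values "triblockdiag_mha" and "full" are the two settings from the cloud_vm_setup.md GPU instructions (the same ones applied in the next comment). The demo uses a SimpleNamespace stand-in instead of a real checkpoint config, and its starting values ("splash_mha", "lazy") are assumed placeholders.

```python
from types import SimpleNamespace


def use_non_tpu_attention(denoiser_architecture_config, backend: str):
    """Hypothetical helper: replace TPU-only splash attention when not on TPU.

    Mutates denoiser_architecture_config.sparse_transformer_config in place,
    mirroring the two settings from the cloud_vm_setup.md GPU instructions,
    and returns the same config object for convenience.
    """
    if backend != "tpu":
        sparse_cfg = denoiser_architecture_config.sparse_transformer_config
        sparse_cfg.attention_type = "triblockdiag_mha"
        sparse_cfg.mask_type = "full"
    return denoiser_architecture_config


# Stand-in for ckpt.denoiser_architecture_config so the sketch runs without a
# checkpoint. The starting values are placeholders, not necessarily what a real
# GenCast checkpoint contains.
demo_config = SimpleNamespace(
    sparse_transformer_config=SimpleNamespace(
        attention_type="splash_mha", mask_type="lazy"
    )
)

use_non_tpu_attention(demo_config, backend="cpu")
print(demo_config.sparse_transformer_config.attention_type)  # triblockdiag_mha
print(demo_config.sparse_transformer_config.mask_type)  # full
```

In the notebook, the backend string would presumably come from jax.default_backend(), e.g. use_non_tpu_attention(ckpt.denoiser_architecture_config, jax.default_backend()).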

@dkokron
Author

dkokron commented Dec 21, 2024

I will try your suggestion and report back here.

@dkokron
Author

dkokron commented Dec 22, 2024

I followed the suggestion in the "Running Inference on GPU" section of cloud_vm_setup.md:

task_config = ckpt.task_config
sampler_config = ckpt.sampler_config
noise_config = ckpt.noise_config
noise_encoder_config = ckpt.noise_encoder_config
denoiser_architecture_config = ckpt.denoiser_architecture_config
# Replace the TPU-only splash attention with triblockdiag attention and a full mask.
denoiser_architecture_config.sparse_transformer_config.attention_type = "triblockdiag_mha"
denoiser_architecture_config.sparse_transformer_config.mask_type = "full"

The job (4 time steps and 8 ensemble members) ran for about 2 hours 30 minutes using 17 GB of system RAM, with an average CPU load of ~30 (I have 48 cores). Unfortunately, the results are all NaN.

GenCast/graphcast/GenCast/lib/python3.12/site-packages/numpy/lib/_nanfunctions_impl.py:1409: RuntimeWarning: All-NaN slice encountered
return _nanquantile_unchecked(
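For context, this RuntimeWarning is what NumPy emits whenever nanquantile (or nanpercentile, which goes through the same code path) receives a slice made up entirely of NaNs: once the NaNs are dropped there is nothing left, so it warns and returns NaN. The warning is therefore most likely a symptom of all-NaN data reaching the plotting step rather than a separate problem. A minimal reproduction on made-up data:

```python
import warnings

import numpy as np

# Made-up stand-in for a plotted field in which every value is NaN.
all_nan_field = np.full((4, 8), np.nan)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record the warning even if already shown once
    median = np.nanquantile(all_nan_field, 0.5)

print(median)  # nan
print([str(w.message) for w in caught])  # includes 'All-NaN slice encountered'
```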

@andrewlkd
Collaborator

I can't say I've seen this warning before. Could you confirm whether the entire forecast was NaN? Note that we expect NaNs in the sea surface temperature variable, so I wonder whether that is what you're encountering.
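One way to answer this for every variable at once is to compute the fraction of NaNs per variable: sea surface temperature, which is expected to contain some NaNs, would typically show a fraction between 0 and 1, while a completely broken variable shows 1.0. A minimal NumPy sketch; the helper name nan_fraction_by_variable and the example variable names are illustrative assumptions. It accepts either a plain dict of arrays or an xarray Dataset's data_vars mapping.

```python
from collections.abc import Mapping
from typing import Any

import numpy as np


def nan_fraction_by_variable(data_vars: Mapping[str, Any]) -> dict[str, float]:
    """Fraction of NaN values in each variable (0.0 = none, 1.0 = all NaN).

    Accepts a dict of arrays or an xarray Dataset's .data_vars mapping;
    np.asarray converts each DataArray to a NumPy array (loading it into memory).
    """
    fractions = {}
    for name, values in data_vars.items():
        arr = np.asarray(values, dtype=float)
        fractions[name] = float(np.isnan(arr).mean()) if arr.size else 0.0
    return fractions


# Illustrative data only; with the notebook's output one would pass something
# like predictions.data_vars instead (the name "predictions" is an assumption).
example = {
    "2m_temperature": np.full((2, 3), np.nan),
    "sea_surface_temperature": np.array([[1.0, np.nan, 2.0], [np.nan, 3.0, 4.0]]),
    "geopotential": np.ones((2, 3)),
}
fractions = nan_fraction_by_variable(example)
fully_nan = [name for name, frac in fractions.items() if frac == 1.0]
print(fully_nan)  # ['2m_temperature']
```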

@dkokron
Author

dkokron commented Dec 24, 2024

I was plotting 2m_temp for all 8 ensemble members, and every member produced the same warning. I'll need to run it again to look at other variables.

@dkokron
Author

dkokron commented Dec 24, 2024

Specific humidity at 850 and 100 hPa, vertical velocity at 850 hPa, geopotential at 500 hPa, and the u and v components of wind at 925 hPa are also NaN. I did not check the rest.

@dkokron
Author

dkokron commented Dec 29, 2024

Any more ideas on how to investigate this issue?
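One generic JAX tool for this is the jax_debug_nans configuration option, which makes JAX raise FloatingPointError at the first operation that produces a NaN (for a jitted function it re-runs the computation op by op to locate the offending operation). The sketch below shows it on a toy function, not on GenCast; whether it pinpoints the source inside the GenCast rollout on CPU is untested, and because the checks add overhead it is best tried on a short run, such as a single time step and ensemble member. The function toy_step is hypothetical.

```python
import jax
import jax.numpy as jnp

# Raise FloatingPointError as soon as any operation produces a NaN, instead of
# letting NaNs propagate silently to the end of the computation.
jax.config.update("jax_debug_nans", True)


@jax.jit
def toy_step(x):
    # Hypothetical stand-in for a model step: log of a negative input is NaN.
    return 2.0 * jnp.log(x)


nan_detected = False
try:
    toy_step(jnp.array([1.0, -1.0]))
except FloatingPointError as err:
    nan_detected = True
    print("NaN detected:", err)
```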
