Deterministic execution produces different results for single vs. batch input on both CPU and GPU #26795

Open
jashvira opened this issue Feb 27, 2025 · 2 comments
Labels: bug (Something isn't working)

Comments

jashvira commented Feb 27, 2025

Description

When applying a simple MLP with Flax, the same input vector (passed as a batch of one vs. as the first row of a larger batch) does not produce bitwise-identical outputs, whether running on CPU or on GPU with XLA_FLAGS=--xla_gpu_deterministic_ops=true. The differences are small (around 1e-7), but they break strict equality checks in my application.

Colab notebook to reproduce.
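
Roughly, the comparison has this shape (a minimal sketch, not the notebook's exact code; the layer sizes, seeds, and shapes are arbitrary):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(32)(x)
        x = nn.relu(x)
        return nn.Dense(1)(x)

model = MLP()
x_batch = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
params = model.init(jax.random.PRNGKey(1), x_batch)

out_single = model.apply(params, x_batch[:1])   # first row passed as a batch of one
out_batch = model.apply(params, x_batch)[:1]    # first row of the full batch

# Expected to be bitwise identical, but differs at the ~1e-7 level.
print(jnp.array_equal(out_single, out_batch))
print(jnp.max(jnp.abs(out_single - out_batch)))
```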

System info (python version, jaxlib version, accelerator, etc.)

Colab CPU:

jax:    0.4.33
jaxlib: 0.4.33
numpy:  1.26.4
python: 3.11.11 (main, Dec  4 2024, 08:55:07) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='79425bdf5d7f', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')

Colab GPU:

jax:    0.4.33
jaxlib: 0.4.33
numpy:  1.26.4
python: 3.11.11 (main, Dec  4 2024, 08:55:07) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='7f58aca50f90', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')

jashvira added the bug label on Feb 27, 2025

justinjfu (Collaborator) commented Feb 27, 2025

I think there are two separate issues being discussed here: (a) deterministic results on GPU, and (b) identical results for unbatched vs. batched computation.

For (a), can you confirm that the GPU results are indeed deterministic? While the batched results differ from the unbatched results, running the batched computation twice should yield the same answer (one that differs from the unbatched result).
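
A quick check would look something like this sketch (shapes and the jitted function are illustrative, not the Colab reproduction):

```python
import jax
import jax.numpy as jnp

W = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
x = jax.random.normal(jax.random.PRNGKey(1), (16, 8))

f = jax.jit(lambda a: a @ W)

# (a) Determinism: repeating the same batched call should be bitwise stable.
print(jnp.array_equal(f(x), f(x)))          # expected True

# (b) Unbatched vs. batched: the same row can still differ slightly.
print(jnp.array_equal(f(x[:1]), f(x)[:1]))  # may be False
```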

For (b), it's possible that different array shapes change the numerics slightly (e.g. the underlying kernel could pick different hyperparameters that affect the result). This is true even for a matrix multiplication done as a single call (A @ B) versus blocked (concat([A1 @ B, A2 @ B])). There isn't really an easy fix here, so I would recommend replacing strict equality checks on floating-point results with tolerance-based checks.
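
As an illustration (made-up shapes, not the exact case from the notebook), the blocked product can drift from the single call at the last-bit level, and a tolerance-based comparison absorbs that:

```python
import jax
import jax.numpy as jnp

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (8, 4))
B = jax.random.normal(key_b, (4, 4))

full = A @ B
blocked = jnp.concatenate([A[:4] @ B, A[4:] @ B])

print(jnp.array_equal(full, blocked))          # may be False
print(jnp.allclose(full, blocked, atol=1e-6))  # robust to differences at the ~1e-7 level
```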

jashvira (Author) commented

(a) Confirmed: the GPU results are deterministic across repeated runs; see the test in the Colab.
(b) The last cell in the Colab roughly tests this, and yes, differently sized arrays cause the problem.

This is frustrating because my application requires reproducible, high-precision results. Without them, it fails entirely. I specifically chose JAX because I believed it supported strict determinism for such numerical computations.

Could the choices the kernel makes be stabilised across input lengths? If that is not possible, what alternatives exist for enforcing strict reproducibility in JAX?
