When applying a simple MLP with Flax, the same input vector (as a single batch vs. the first row of a larger batch) does not produce bitwise-identical outputs, even when running on CPU and with XLA_FLAGS=--xla_gpu_deterministic_ops=true on GPU. The differences are small (1e-7 level), but they break strict equality checks in my application.
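A simplified sketch of the comparison (using a plain `jax.numpy` matrix multiply as a stand-in for the Flax MLP; the shapes and names here are illustrative, not from the original notebook):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (8, 4))               # stand-in for an MLP layer
x = jax.random.normal(jax.random.PRNGKey(1), (16, 8))

single = x[:1] @ W        # the input row as a batch of size 1
batched = (x @ W)[:1]     # the same row taken from the full batch

# Bitwise equality may fail: different shapes can select different kernels.
print(bool(jnp.array_equal(single, batched)))
# A tolerance check covers the ~1e-7 level differences reported above.
print(bool(jnp.allclose(single, batched, atol=1e-6)))
```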
I think there are a few separate issues being discussed here: (a) deterministic results on GPU, and (b) identical results for unbatched vs. batched computation.
For (a), can you confirm that the GPU results are indeed deterministic? While the batched results differ from the unbatched results, running the batched computation twice should yield the same (batched) answer each time.
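One way to check (a) is to run the same batched computation twice and compare bitwise: with fixed shapes the same kernel is selected, so the bits should match. A minimal sketch, assuming a plain jitted matmul as a stand-in for the MLP:

```python
import jax
import jax.numpy as jnp

@jax.jit
def layer(x, W):
    # Stand-in for the MLP forward pass.
    return x @ W

key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (8, 4))
x = jax.random.normal(jax.random.PRNGKey(1), (16, 8))

out1 = layer(x, W)
out2 = layer(x, W)

# Same shapes, same kernel, same bits: this should print True.
print(bool(jnp.array_equal(out1, out2)))
```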
For (b), it's possible that different array shapes change the numerics slightly (e.g. the underlying kernel may pick different hyperparameters that affect the result). This is true even if you do a matrix multiplication as one single call (A @ B) versus blocked (concat([A1 @ B, A2 @ B])). There isn't really an easy fix here, so I would recommend swapping strict equality checks for tolerance-based checks whenever floating-point arithmetic is involved.
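The underlying reason is that floating-point arithmetic is not associative, so any change in how a computation is blocked (which is what a different kernel tiling amounts to) can change the last bits of the result. A pure-Python illustration, together with the kind of tolerance check recommended above:

```python
import math

# Floating-point addition is not associative: regrouping the same
# operands changes the last bits of the sum.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c    # one grouping, analogous to a single fused kernel
right = a + (b + c)   # another grouping, analogous to a blocked kernel

print(left == right)      # False
print(abs(left - right))  # on the order of 1e-16

# The robust alternative: strict equality swapped for a tolerance check.
print(math.isclose(left, right, rel_tol=1e-9))  # True
```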
(a) This is confirmed to be true. Please check the Colab to see the test.
(b) The last cell in the Colab roughly tests this, and yes, differently sized arrays produce different results.
This is frustrating because my application requires reproducible, high-precision results. Without them, it fails entirely. I specifically chose JAX because I believed it supported strict determinism for such numerical computations.
Could the kernel's parameter selection be stabilised so that it does not vary with batch size? If that is not possible, what alternatives exist for enforcing strict reproducibility in JAX?
Description
Colab notebook to reproduce.
System info (python version, jaxlib version, accelerator, etc.)
Colab CPU:
Colab GPU: