What is necessary in order to get deterministic results on a GPU? I was under the impression that we get determinism "for free" by virtue of using …

... and my own code. Each of these issues/PRs has comments from different people making varying claims about which flags are sufficient for deterministic results. They disagree on exactly what those flags are and whether the flags are deprecated, and some of the comments are relatively old. There's also been discussion of a …

What is the official stance on how to achieve deterministic results with JAX on a GPU? Is there documentation for this? When should I file a determinism bug report?
Hey! It seems like a good time to answer this question! Or at least, better late than never... I want to distinguish two things:

1. `jax.random` generates deterministic random bits as a pure function of the random seed given, and
2. … `jax.random` or pseudorandom number generation.

I think the question here is really about the latter. GPUs have several sources of nondeterminism, as mentioned in the (unmerged) #4824. For example, XLA:GPU might autotune cuDNN conv kernels or Triton kernels, choosing whichever it measures to be fastest on your card at the moment. As another example, any multithreaded reduction operation implemented using atomics will be nondeterministic from execution to execution just by virtue of how GPU hardware works. These have nothing to do with pseudorandom numbers, and could affect any code run on GPU. I'm 95% sure that we get full determinism, at some performance penalty, by setting … I'm working on finding out and then finally merging a version of #4824.

EDIT: The above is sufficient for determinism, but not necessary. All you need for Nvidia GPUs is …
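Why does an atomics-based reduction change from run to run? A likely core reason is that floating-point addition is not associative, so the order in which threads add into a shared accumulator changes the result bits. Below is a toy sketch in plain Python, not JAX, XLA, or CUDA. The helpers `ordered_sum`, `atomic_style_sum`, and `tree_sum` are hypothetical illustrations. So is the seeded shuffle that stands in for GPU thread scheduling. None of this is how XLA actually implements its reductions.

```python
import random
import struct


def bits(x: float) -> bytes:
    """Exact IEEE-754 double bit pattern, so 'same result' means bitwise equal."""
    return struct.pack("<d", x)


def ordered_sum(values, order):
    """Accumulate values[i] into a single accumulator in the given order."""
    acc = 0.0
    for i in order:
        acc += values[i]
    return acc


def atomic_style_sum(values, seed):
    """Toy model of an atomics-based reduction: every 'thread' adds its value
    into one shared accumulator, and the arrival order depends on scheduling.
    Scheduling is modeled by a seeded shuffle (an illustrative assumption;
    real GPU arrival order cannot be controlled by a seed)."""
    order = list(range(len(values)))
    random.Random(seed).shuffle(order)
    return ordered_sum(values, order)


def tree_sum(values):
    """Fixed-order pairwise reduction: inputs are always combined in the same
    order, so the result is bitwise reproducible from run to run."""
    vals = list(values)
    if not vals:
        return 0.0
    while len(vals) > 1:
        nxt = [vals[i] + vals[i + 1] for i in range(0, len(vals) - 1, 2)]
        if len(vals) % 2:
            nxt.append(vals[-1])  # carry the odd element to the next level
        vals = nxt
    return vals[0]


# Floating-point addition is not associative: same inputs, different order.
small = [1e16, 1.0, -1e16, 1.0]
print(ordered_sum(small, [0, 1, 2, 3]))  # 1.0
print(ordered_sum(small, [1, 3, 0, 2]))  # 2.0

# Repeat each reduction 20 times over values with widely mixed magnitudes.
rng = random.Random(0)
data = [rng.uniform(-1.0, 1.0) * 10.0 ** rng.randint(-8, 8) for _ in range(10_000)]
atomic_bits = {bits(atomic_style_sum(data, seed)) for seed in range(20)}
tree_bits = {bits(tree_sum(data)) for _ in range(20)}
print(len(atomic_bits), len(tree_bits))
```

The fixed-order tree is reproducible but not necessarily more accurate. On `small` it returns `0.0`, while the exact sum is `2.0`. Pinning the combination order like this is one common way to make a reduction deterministic, and giving up the freedom to add in arrival order is one plausible source of the performance penalty mentioned above.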
Hi @mattjj, I'm currently using an H100 GPU, and I've noticed that setting … Thanks in advance for your help!