
tf.data.experimental.sample_from_datasets non-deterministic in multi-gpu. #53846

Closed
yanniskar opened this issue Jan 21, 2022 · 19 comments
Labels
comp:dist-strat (Distribution Strategy related issues), TF 2.4 (for issues related to TF 2.4), type:bug (Bug)

Comments

@yanniskar

See NVIDIA/framework-reproducibility#39

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: Does not apply
  • TensorFlow installed from (source or binary): pip
  • TensorFlow version (use command below): 2.4.1
  • Python version: 3.7
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: 11.2
  • GPU model and memory:
@yanniskar
Author

@duncanriach

@duncanriach
Contributor

@reedwm, bringing your attention to this potential issue for the 2.8 release.

@yanniskar, I see that you're using version 2.4.1; are you able to see if this issue exists in the 2.8.0-rc0 pre-release (e.g. by using the tensorflow/tensorflow:2.8.0rc0-gpu docker image or pip install tensorflow==2.8.0rc0)? This issue may have been fixed since 2.4.
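For reference, a minimal sanity check (assuming one of the install paths above) to confirm which version the environment is actually picking up before re-running your reproduction:

# Quick check that the pre-release is the TensorFlow version actually in use.
import tensorflow as tf

print(tf.__version__)  # expected to report the 2.8.0 release candidate, e.g. "2.8.0-rc0"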

@sushreebarsa sushreebarsa added comp:data tf.data related issues TF 2.4 for issues related to TF 2.4 stat:awaiting response Status - Awaiting response from author labels Jan 21, 2022
@reedwm
Member

reedwm commented Jan 21, 2022

Agree with @duncanriach that you should try on 2.8.0rc0. I'm not sure why this would be nondeterministic on TF 2.4, but it's possible this was fixed since then.

@sushreebarsa sushreebarsa removed the stat:awaiting response Status - Awaiting response from author label Jan 22, 2022
@sushreebarsa
Contributor

@yanniskar Could you please try with TF v2.8.0rc0 and let us know the outcome? Please refer to the above comments as well. Thanks!

@sushreebarsa sushreebarsa added the stat:awaiting response Status - Awaiting response from author label Jan 22, 2022
@yanniskar
Author

Thanks for the response, @sushreebarsa @duncanriach. I need to stay on the TensorFlow version my org uses, so I cannot upgrade to 2.8 to test this; my org does not support cloud training on that version. If it is hard for you to verify this in 2.8, I can look into a workaround for testing my code on 2.8, though given other work priorities it might take me some time to get to it.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Jan 29, 2022
@duncanriach
Contributor

duncanriach commented Jan 29, 2022

@yanniskar, if you cannot confirm whether this issue is present in the latest version of TensorFlow, then we cannot either: we do not have access to the code you're running, so we cannot run it on the latest version ourselves.

Another way to move this issue forward is for you to attempt to provide a self-contained reproducer, a simple piece of code that you can run and also share with us. That way, we can look at exactly what you're looking at, observe it on version 2.4.1, test if it's still present on the latest version, and then be able to potentially debug it. There are too many variables in these systems to be able to debug something that we cannot examine.

The following reproducer code demonstrates the kind of minimal example that could be provided to reproduce the observed multi-device (distributed) nondeterminism. This example distributes the dataset to both the GPU and CPU, with the two-element batch being split into one element per device. For this kind of problem (probably related to dataset distribution between devices), I doubt it matters what kind of devices are used.

The intention should be to recreate the basic configuration as accurately and minimally as possible. For example, it would probably be important to capture the distribution strategy used and when and how the dataset(s) are distributed (such as before or after applying sample_from_datasets).

With determinism, the devices should print the same sequence of values on each run, as they currently do in this example.

The following code can be run and modified in a copy of this colab notebook.

import tensorflow as tf

# Two source datasets of two-element rows; sample_from_datasets interleaves
# them with a fixed seed, so the sampling order should be reproducible.
dataset1 = tf.data.Dataset.from_tensor_slices([[10, 11], [12, 13], [14, 15], [16, 17]])
dataset2 = tf.data.Dataset.from_tensor_slices([[21, 22], [23, 24], [25, 26], [27, 28]])
sample_dataset = tf.data.experimental.sample_from_datasets(
  [dataset1, dataset2], weights=[0.5, 0.5], seed=43)

# Distribute across two devices; each two-element row is split one element per replica.
my_strategy = tf.distribute.MirroredStrategy(["GPU:0", "CPU:0"])
with my_strategy.scope():
  @tf.function
  def distribute_train_epoch(dataset):
    for x in dataset:
      my_strategy.run(print, args=(x,))

  dist_dataset = my_strategy.experimental_distribute_dataset(sample_dataset)

# With deterministic behavior, both passes should print the same sequence of values.
for _ in range(2):
  print("------------------")
  distribute_train_epoch(dist_dataset)

Output:

------------------
[10]
[11]
[21]
[22]
[23]
[24]
[12]
[13]
[25]
[26]
[14]
[15]
[27]
[28]
[16]
[17]
------------------
[10]
[11]
[21]
[22]
[23]
[24]
[12]
[13]
[25]
[26]
[14]
[15]
[27]
[28]
[16]
[17]

@sushreebarsa sushreebarsa added the stat:awaiting response Status - Awaiting response from author label Jan 29, 2022
@duncanriach
Contributor

@yanniskar, are you using XLA (e.g. @tf.function(jit_compile=True))?
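For clarity, explicit XLA use would look roughly like the following (a minimal, hypothetical sketch; train_step is just an illustrative name):

import tensorflow as tf

# Hypothetical illustration: jit_compile=True asks TensorFlow to compile the
# function with XLA instead of executing it op-by-op.
@tf.function(jit_compile=True)
def train_step(x):
  return x * 2.0

print(train_step(tf.constant([1.0, 2.0])))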

@kabilan6

kabilan6 commented Feb 3, 2022

When the dataset contains anywhere from 1,000 to millions of elements, it is non-deterministic. We noticed a similar problem with TF 2.7.

@duncanriach
Contributor

duncanriach commented Feb 3, 2022

Hi @kabilan6,

Thanks for that. For clarity, please confirm or refute the following five points:

  1. You have a model that trains deterministically on a single GPU.
  2. When you use more than one GPU (including only two GPUs), you get nondeterminism.
  3. You're using tf.data.experimental.sample_from_datasets.
  4. When you remove only tf.data.experimental.sample_from_datasets, the nondeterminism goes away.
  5. The newest version of TensorFlow that you have reproduced this issue on is 2.7.

Please answer the following question:

Are you using XLA (e.g. @tf.function(jit_compile=True))?

@kabilan6

kabilan6 commented Feb 3, 2022

Hello, this determinism issue occurs across the board and is not specific to tf.data.experimental.sample_from_datasets. Please refer to a similar issue I created:
#54259.

Please find my responses below:

  1. You have a model that trains deterministically on a single GPU. --> I'm not sure about training, but I think it's not related to training.
  2. When you use more than one GPU (including only two GPUs), you get nondeterminism. --> I noticed nondeterminism with a single GPU as well as multi-GPU (my example used the batch function).
  3. You're using tf.data.experimental.sample_from_datasets. --> No, I used experimental_distribute_dataset.
  4. When you remove only tf.data.experimental.sample_from_datasets, the nondeterminism goes away. --> I still noticed the issue; I think it's tied to tf.data.batch.
  5. The newest version of TensorFlow that you have reproduced this issue on is 2.7. --> Yes, for my use case I noticed it with version 2.7.

@duncanriach
Contributor

Okay, @kabilan6. From looking at your answers, I'm almost certain that you're dealing with a different issue because (1) your issue occurs with a single GPU and (2) you're not using sample_from_datasets. Thanks for opening #54259.

@google-ml-butler

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Feb 10, 2022
@yanniskar
Author

yanniskar commented Feb 11, 2022

@duncanriach sorry for going radio silent for two weeks. Work has really picked up lately so I have not had time to come back to this. Here are my responses to your comments in order:

  1. I will try to reproduce the issue using a self-contained reproducer like you suggested. I will also do this using the latest TensorFlow version. Once I do that, I will report my findings back in this thread.
  2. As far as I know, I am not using XLA. For more context, I am using a simple Keras model.fit training loop with MirroredStrategy for this problem (rough sketch below).

How does that plan sound? Let me know if you want more info on the XLA matter.
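To give a rough picture of the setup (a simplified, hypothetical sketch; the real model, datasets, and weights are internal), it is essentially sample_from_datasets feeding a plain Keras model.fit loop under MirroredStrategy:

import tensorflow as tf

# Simplified, hypothetical stand-in for the real pipeline: two datasets mixed
# with sample_from_datasets, trained via model.fit under MirroredStrategy.
dataset_a = tf.data.Dataset.from_tensor_slices((tf.zeros([8, 4]), tf.zeros([8, 1])))
dataset_b = tf.data.Dataset.from_tensor_slices((tf.ones([8, 4]), tf.ones([8, 1])))
mixed = tf.data.experimental.sample_from_datasets(
  [dataset_a, dataset_b], weights=[0.5, 0.5], seed=42).batch(4)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer="sgd", loss="mse")

model.fit(mixed, epochs=2)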

@google-ml-butler google-ml-butler bot removed stale This label marks the issue/pr stale - to be closed automatically if no activity stat:awaiting response Status - Awaiting response from author labels Feb 11, 2022
@duncanriach
Contributor

Sounds great. Thanks, @yanniskar.

@plentx

plentx commented Feb 11, 2022 via email

@sushreebarsa sushreebarsa removed the comp:data tf.data related issues label Feb 22, 2022
@sushreebarsa sushreebarsa added the comp:dist-strat Distribution Strategy related issues label Feb 22, 2022
@gadagashwini
Contributor

@yanniskar, did you try @duncanriach's suggestion?
Please let us know if this is still an issue. Thanks!

@gadagashwini gadagashwini added the stat:awaiting response Status - Awaiting response from author label Feb 24, 2022
@yanniskar
Author

@yanniskar, did you try @duncanriach's suggestion? Please let us know if this is still an issue. Thanks!

No luck, unfortunately. Work has really picked up, so I have not had time to investigate this, given competing priorities and the fact that this is not a blocking issue for development. Feel free to close this issue, and I will circle back once I find some time (probably on my next PTO) to run the investigation I proposed. I don't think it is fair to consider this an active issue given that I have not verified it on the latest version of TensorFlow.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Feb 28, 2022
@gadagashwini
Contributor

Thanks for confirming, @yanniskar.
If you face the same issue on the latest version, please feel free to reopen. Thanks!

@google-ml-butler
