Loss curve spikes on amalgamated datasets - need full-scale shuffler in dataloader #128

Closed
lessw2020 opened this issue Mar 12, 2024 · 6 comments
Labels
enhancement New feature or request

Comments

@lessw2020
Contributor

As part of e2e training, I encountered wild loss curve spikes:

[Screenshot, 2024-03-07 8:40:55 PM: training loss curve with large spikes]

After additional hyperparameter tuning and further investigation, the root cause turned out to be that we read the dataset sequentially. The model sees data type A, learns and improves, then hits data type B, is surprised (the loss spikes), learns and improves again, and the cycle repeats.

By training on a single-source dataset, in this case openwebtext, we get a smooth loss curve in e2e training, which shows that the issue is the lack of shuffling:
[Screenshot, 2024-03-12 9:50:57 AM: e2e training loss curve on openwebtext, without spikes]
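A toy illustration of the failure mode described above, assuming a dataset made of two sources stored back to back (the source labels and sizes are made up). Reading sequentially produces one long single-source stretch per source, which is what lets the model "settle" on A and then get surprised by B; a global shuffle breaks those stretches up:

```python
import random
from itertools import groupby


def longest_single_source_run(order):
    """Length of the longest stretch of consecutive samples from the same source."""
    return max(len(list(group)) for _, group in groupby(order))


# Hypothetical amalgamated dataset: source A followed by source B.
sources = ["A"] * 1000 + ["B"] * 1000

# Sequential reading: the model sees all of A, then all of B.
sequential = list(sources)

# Global shuffle with a fixed seed: A and B are interleaved throughout.
shuffled = list(sources)
random.Random(0).shuffle(shuffled)

print("sequential:", longest_single_source_run(sequential))  # 1000
print("shuffled:  ", longest_single_source_run(shuffled))
```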

@tianyu-l tianyu-l added the enhancement New feature or request label May 3, 2024
@XinDongol

@tianyu-l @lessw2020 FYI, I am using this trick.

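  import time

  hf_ds = HuggingFaceDataset(
      dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
  )
  if shuffle:
      # Seed mixes the rank with the current wall-clock time, so every run shuffles differently
      hf_ds._data = hf_ds._data.shuffle(seed=int(rank*10007+int(time.time())))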

@TJ-Solergibert

@XinDongol Why would you shuffle the dataset with that seed? Stateful DataLoaders will be merged soon, and then you won't be able to resume training properly after a crash, because you won't know how the dataset was shuffled.

Random seeds are used to make results reproducible; in this case, that seed does exactly the opposite.

@tianyu-l
Contributor

@XinDongol For a map-style dataset, this works as expected. However, for an IterableDataset, shuffling uses a buffer, so randomness is only applied within that buffer. The issue won't be fixed if the buffer is not (or cannot be) large enough to span the different amalgamated datasets.
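A minimal sketch of why a shuffle buffer does not help here, written from scratch in plain Python for illustration. It follows the general fill-then-swap idea behind buffer shuffling (as with the `buffer_size` argument of `IterableDataset.shuffle` in HF `datasets`), but it is not that library's code, and the function name and sizes are hypothetical. With a 10-sample buffer over two 1000-sample sources, no sample from source B can be emitted until B's first sample has entered the buffer:

```python
import random


def buffer_shuffle(items, buffer_size, seed):
    """Approximate shuffle of a stream: only samples currently in the buffer get mixed."""
    rng = random.Random(seed)
    buffer = []
    for x in items:
        if len(buffer) < buffer_size:
            buffer.append(x)  # fill the buffer first
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]  # emit a random buffered sample...
        buffer[i] = x    # ...and replace it with the incoming one
    rng.shuffle(buffer)  # flush whatever is left at the end of the stream
    yield from buffer


# Two amalgamated sources: samples 0-999 come from A, 1000-1999 from B.
out = list(buffer_shuffle(range(2000), buffer_size=10, seed=0))

# Count B samples among the first 990 outputs: the buffer only held A samples then.
print(sum(x >= 1000 for x in out[:990]))  # 0
```

The output is a valid permutation, but locally it still looks like "all of A, then all of B"; the buffer would need to be on the order of a whole source block for the sources to actually mix.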

@TJ-Solergibert Checkpointing the random seeds used to shuffle the dataset would solve the problem. FYI, it is on our roadmap.
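A minimal sketch of what checkpointing the shuffle seed could look like, assuming a map-style dataset addressed by sample index and using Python's `random` module rather than any particular library's shuffle (the function names, seed scheme, and checkpoint keys are hypothetical). The seed is built only from values stored in the checkpoint (a base seed, the rank, and the epoch), so after a crash the same permutation can be rebuilt and training continues from the saved position:

```python
import random


def epoch_order(num_samples, base_seed, rank, epoch):
    """Per-rank, per-epoch permutation of sample indices that can be recomputed on resume."""
    # Hypothetical seeding scheme: a string seed built only from checkpointed values
    # (no wall-clock time), so the same inputs always give the same order.
    rng = random.Random(f"{base_seed}:{rank}:{epoch}")
    order = list(range(num_samples))
    rng.shuffle(order)
    return order


def order_from(ckpt, num_samples=1000):
    """Rebuild the permutation from the values saved in a checkpoint."""
    return epoch_order(num_samples, ckpt["base_seed"], ckpt["rank"], ckpt["epoch"])


# Hypothetical state saved alongside the model checkpoint.
checkpoint = {"base_seed": 1234, "rank": 0, "epoch": 2, "position": 500}

original = order_from(checkpoint)  # order used before the crash
restored = order_from(checkpoint)  # order rebuilt after restarting
remaining = restored[checkpoint["position"]:]  # continue where training stopped
```

Like the rank-offset seed above, the rank still gives each worker its own order, but nothing depends on when the job was launched.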

@TJ-Solergibert

Thanks for your answer, @tianyu-l, it makes sense 😅

I was wondering: is there any way to avoid using .skip() when resuming training? In my setup (and on Colab), skipping 10,000,000 samples took approximately 90 s.

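from datasets import load_dataset

# Stream C4 (English) instead of downloading the full dataset
ds = load_dataset("allenai/c4", name="en", split="train", streaming=True)
# On a streaming dataset, skip() is lazy: the skipped examples are consumed once iteration starts
ds = ds.skip(10000000)
ds = iter(ds)
next(ds)  # the ~90 s skipping cost is paid here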

@tianyu-l
Contributor

@TJ-Solergibert

  1. We should use .skip() when resuming training. In fact, it has been incorporated into "Use stateful dataloader to checkpoint data iteration order and token buffer" #279.
  2. That doesn't mean it is the ideal solution. For example, the C4 en subset has more than 300M entries, which, at the rate in your example (10M samples in ~90 s), means over 45 min of skipping if we stop somewhere near the end of the dataset. Ideally, even for a streaming=True IterableDataset, skip() should be able to seek directly to the right file position. As far as we know, this is something HF is working on.
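To make the seek-versus-skip point concrete: if the resume state is a single global sample count, resuming means re-reading everything before it, which is where the 45 min estimate comes from (300M / 10M = 30, and 30 × 90 s = 2700 s = 45 min). A minimal sketch, using in-memory lists as stand-ins for data files, that stores a (shard, offset) pair instead, so resuming only touches the shard it stopped in. The `ShardedStream` class and its fields are hypothetical, not torchtitan's or HF's API; the method names just follow PyTorch's `state_dict()` / `load_state_dict()` convention:

```python
class ShardedStream:
    """Toy resumable stream that tracks (shard, offset) instead of a global sample count."""

    def __init__(self, shards):
        self.shards = shards  # list of lists standing in for data files
        self.shard_idx = 0
        self.offset = 0

    def __iter__(self):
        while self.shard_idx < len(self.shards):
            shard = self.shards[self.shard_idx]
            while self.offset < len(shard):
                sample = shard[self.offset]
                self.offset += 1  # advance before yielding so the state covers emitted samples
                yield sample
            self.shard_idx += 1  # move to the next shard and restart the offset
            self.offset = 0

    def state_dict(self):
        return {"shard_idx": self.shard_idx, "offset": self.offset}

    def load_state_dict(self, state):
        self.shard_idx = state["shard_idx"]
        self.offset = state["offset"]


shards = [list(range(i * 100, (i + 1) * 100)) for i in range(5)]

stream = ShardedStream(shards)
it = iter(stream)
seen = [next(it) for _ in range(250)]  # consume 2.5 shards, then "crash"
state = stream.state_dict()

resumed = ShardedStream(shards)
resumed.load_state_dict(state)
rest = list(resumed)  # starts at sample 250 without re-reading shards 0 and 1
```

In a real file-backed version, `offset` could be a byte position passed to `seek()`; that is an assumption about how one might build it, not a description of an existing implementation.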

@tianyu-l
Contributor

Shuffling at the level of the entire dataset should be part of data preprocessing, not data loading, so I'm closing this task for now.
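A minimal sketch of doing the full shuffle once during preprocessing, assuming the amalgamated sources fit in memory as Python lists (the function name, source names, sizes, and shard count are made up). Every output shard ends up with a mix of all sources, so even a loader that reads shards sequentially sees interleaved data:

```python
import random


def preshuffle_into_shards(sources, num_shards, seed):
    """Offline step: globally shuffle all sources once, then split into output shards."""
    # Tag each sample with its source so the mix can be inspected afterwards.
    tagged = [(name, sample) for name, samples in sources.items() for sample in samples]
    random.Random(seed).shuffle(tagged)  # one full-dataset shuffle with a fixed seed
    # Round-robin split keeps shard sizes within one sample of each other.
    return [tagged[i::num_shards] for i in range(num_shards)]


sources = {
    "web": [f"web-{i}" for i in range(600)],
    "code": [f"code-{i}" for i in range(300)],
    "books": [f"books-{i}" for i in range(100)],
}
shards = preshuffle_into_shards(sources, num_shards=4, seed=0)
for shard in shards:
    print(len(shard), sorted({name for name, _ in shard}))
```

For corpora too large for memory, the same idea can be applied to an index of sample locations instead of the samples themselves.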


5 participants