
Add tests for out of order with checkpointing #1428

Merged · 3 commits · Jan 30, 2025

Conversation

michael-diggin (Contributor)

Follow-up to #1423

Changes

  • Add tests for state management of both index datasets and iterable datasets when in_order=False
  • Remove warning logs about in_order=False

State management with in_order=False actually just works, which is really nice.

I'm happy to keep the warning logs in for longer if we'd like to err on the side of caution. And please let me know if there are any other test cases I should add.
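As a rough illustration of the test pattern this PR adds - iterate partway, capture a checkpoint, and resume from a fresh loader - here is a minimal pure-Python sketch. `TinyStatefulLoader` is a hypothetical stand-in, not torchdata's `StatefulDataLoader`; it only models the "fast-forward by items yielded" idea discussed below.

```python
# Hypothetical stand-in for a stateful dataloader: it tracks how many
# items it has yielded and fast-forwards a fresh iterator on resume.
class TinyStatefulLoader:
    def __init__(self, dataset):
        self.dataset = dataset        # any deterministic iterable
        self.num_yielded = 0          # items produced so far

    def __iter__(self):
        it = iter(self.dataset)
        for _ in range(self.num_yielded):  # fast-forward on resume
            next(it)
        for item in it:
            self.num_yielded += 1
            yield item

    def state_dict(self):
        return {"num_yielded": self.num_yielded}

    def load_state_dict(self, state):
        self.num_yielded = state["num_yielded"]

# Iterate partway, checkpoint, then resume from a fresh loader.
loader = TinyStatefulLoader(range(10))
output = []
for i, x in enumerate(loader):
    output.append(x)
    if i == 3:
        state = loader.state_dict()
        break

resumed = TinyStatefulLoader(range(10))
resumed.load_state_dict(state)
output.extend(resumed)

assert sorted(output) == list(range(10))  # no samples lost or duplicated
```

The real tests exercise the same shape of scenario, but through the actual dataloader with multiple workers and a deliberately slow sample.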

@facebook-github-bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) on Jan 28, 2025
andrewkho (Contributor) commented Jan 28, 2025

How can this work for out-of-order iteration? eg let's say the first sample blocks until the end of the epoch, but you take a checkpoint in the middle, what happens then? Without replaying all of history, how will sample 0 be returned later?

pytorch-bot bot commented Jan 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/data/1428

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cea6bdb with merge base e25df94:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

michael-diggin (Contributor, Author)

> How can this work for out-of-order iteration? e.g. let's say the first sample blocks until the end of the epoch, but you take a checkpoint in the middle - what happens then? Without replaying all of history, how will sample 0 be returned later?

That's exactly what the tests I've added are checking - the first sample for one of the workers is slow.
For index datasets it looks like it uses the fact that it has progressed a certain number of steps since the last checkpoint (which is part of the state dict) to skip forward, and for the iterable dataset it looks to be using the num_yielded field from the state dict.
If this isn't currently expected to work, I can add back in the warning logs while we fully determine whether it should work as is?
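The map-style (index dataset) mechanism described above can be sketched in a few lines: the checkpoint records how many sampler steps were consumed, and on resume the deterministic sampler order is replayed with that many steps skipped. The names here are illustrative, not torchdata's actual internals.

```python
# Sketch of "skip forward by steps since the last checkpoint" for a
# map-style dataset with a fixed, deterministic sampler order.
def indices_after_resume(order, steps_done):
    """Replay a fixed sampler order, skipping already-consumed steps."""
    it = iter(order)
    for _ in range(steps_done):
        next(it)            # fast-forward past consumed indices
    return list(it)

order = list(range(10))     # deterministic sampler order
consumed = order[:4]        # indices handed out before the checkpoint
remaining = indices_after_resume(order, steps_done=4)

# Consumed plus remaining covers the full epoch with no duplicates.
assert consumed + remaining == order
```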

andrewkho (Contributor)

@michael-diggin let's leave the warnings in for now just in case, while we think through this a bit more. Thanks again!

michael-diggin (Contributor, Author)

> @michael-diggin let's leave the warnings in for now just in case, while we think through this a bit more. Thanks again!

@andrewkho no problem - added back in just now.

```python
    self.assertEqual(sorted(output), list(range(10)))

def test_out_of_order_iterable_ds(self):
    dataset = _TestSlowIterableDataset(start=0, end=10)
```
Contributor

For an iterable dataset the slow worker will just lead to a straggler; on resume, the individual workers will resume and continue, though they'll be limited to single-worker performance. I think I can see why this might "just work" for iterable datasets.

Contributor Author

I'm not entirely sure - I've added a second case that breaks before either worker finishes, so both workers resume after the restart, which gives the same correct results.
I think the fast-forwarding part of resuming is what allows this to work: since the dataset is deterministic (i.e. the slow samples don't change), fast-forwarding by X samples brings it back to the same point.
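The determinism argument above can be made concrete with a small sketch: if the stream replays identically, recreating the iterator and skipping num_yielded items lands exactly where the checkpoint left off. This is an illustration of the reasoning, not torchdata code; `stream` is a made-up deterministic dataset.

```python
import random

def stream(seed):
    """A deterministic (seeded) shuffled stream of 10 samples."""
    rng = random.Random(seed)
    data = list(range(10))
    rng.shuffle(data)
    yield from data

def resume(seed, num_yielded):
    """Recreate the stream and fast-forward past already-yielded items."""
    it = stream(seed)
    for _ in range(num_yielded):
        next(it)
    return list(it)

first_pass = list(stream(seed=0))
resumed = resume(seed=0, num_yielded=4)

# With the same seed, the resumed stream continues exactly where the
# checkpoint left off. If the stream were non-deterministic (e.g. an
# unseeded shuffle), this alignment would not hold.
assert first_pass[4:] == resumed
```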

```python
output = []
for i, data in enumerate(dataloader):
    output.append(data)
    if i == 5:
```
Contributor

Can we try to set this lower than 5? I think for end=10, slow_index=0, and num_workers=2, i == 5 will be index 0, since indices 0 through 4 will come from worker 1, and worker 0 will be blocked until index 0 is released.

Contributor Author

I'm not sure that's true since we added the re-distribution of work - I checked this test's output and 0 was returned at index 8, so after the break point.
I've dropped the break point down to 3, however, and also added an assertion that 0 isn't in the output at that point.
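The interleaving discussed in this thread can be simulated with a toy version of out-of-order workers: two workers pull indices from a shared queue, so a slow item (index 0) doesn't block the others and simply shows up late in the output. This is an illustration of the scheduling behaviour, not torchdata's actual worker implementation.

```python
import queue
import threading
import time

task_q = queue.Queue()
out_q = queue.Queue()
for idx in range(10):
    task_q.put(idx)

def worker():
    # Pull indices until the shared queue is drained (work redistribution:
    # whichever worker is free takes the next index).
    while True:
        try:
            idx = task_q.get_nowait()
        except queue.Empty:
            return
        if idx == 0:
            time.sleep(0.5)      # the deliberately slow sample
        out_q.put(idx)

threads = [threading.Thread(target=worker) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

output = []
while not out_q.empty():
    output.append(out_q.get())

assert sorted(output) == list(range(10))   # everything still arrives...
assert output[-1] == 0                     # ...but index 0 arrives last
```

The ordering of the fast items is nondeterministic, which is exactly why the real tests compare sorted output rather than exact order.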

andrewkho (Contributor) left a comment

Let's land this once CI finishes and figure out details for the next release

@ramanishsingh ramanishsingh merged commit fcdc8b9 into pytorch:main Jan 30, 2025
39 checks passed