Add tests for out of order with checkpointing #1428
Conversation
How can this work for out-of-order iteration? E.g., let's say the first sample blocks until the end of the epoch, but you take a checkpoint in the middle: what happens then? Without replaying all of history, how will sample 0 be returned later?
Dr. CI: ✅ No failures as of commit cea6bdb with merge base e25df94. See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/data/1428. (Note: links to docs will display an error until the docs builds have completed.)
That's exactly what the tests I've added are checking: the first sample for one of the workers is slow.
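For reference, a minimal sketch of a map-style dataset with one slow sample; the class name and the `slow_index`/`delay` parameters here are illustrative rather than the PR's actual helper:

```python
import time

from torch.utils.data import Dataset


class _TestSlowDataset(Dataset):
    """Map-style dataset where producing one index is artificially slow."""

    def __init__(self, start, end, slow_index=0, delay=1.0):
        self.data = list(range(start, end))
        self.slow_index = slow_index
        self.delay = delay  # seconds to block on the slow sample

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if idx == self.slow_index:
            time.sleep(self.delay)  # simulate the straggler sample
        return self.data[idx]
```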
@michael-diggin let's leave the warnings in for now just in case, while we think through this a bit more. Thanks again!
@andrewkho no problem, added back in just now.
```python
        self.assertEqual(sorted(output), list(range(10)))

    def test_out_of_order_iterable_ds(self):
        dataset = _TestSlowIterableDataset(start=0, end=10)
```
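For context, a minimal sketch of what a slow iterable dataset like this might look like; the actual `_TestSlowIterableDataset` helper in the PR may differ, and `slow_value`/`delay` are illustrative. Each worker iterates its own shard, so one slow value stalls only that worker:

```python
import time

from torch.utils.data import IterableDataset, get_worker_info


class _TestSlowIterableDataset(IterableDataset):
    """Iterable dataset sharded across workers, with one slow value."""

    def __init__(self, start, end, slow_value=0, delay=1.0):
        self.start, self.end = start, end
        self.slow_value = slow_value
        self.delay = delay

    def __iter__(self):
        info = get_worker_info()
        worker_id = info.id if info else 0
        num_workers = info.num_workers if info else 1
        # Round-robin shard: each worker yields a disjoint subset.
        for value in range(self.start + worker_id, self.end, num_workers):
            if value == self.slow_value:
                time.sleep(self.delay)  # only this worker stalls
            yield value
```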
For an iterable dataset the slow worker will just lead to a straggler; on resume the individual workers will resume and continue, though be limited to single-worker performance. I think I can see why this might "just work" for iterable datasets.
I'm not entirely sure. I've added a second case that breaks before either worker finishes, so they both resume after the restart, which gives the same correct results.
I think the fast-forwarding part of resuming is what allows this to work: since the dataset is deterministic (i.e. the slow samples don't change), fast-forwarding by X samples brings each worker back to the same point.
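A toy illustration of that determinism argument (names here are illustrative, not torchdata APIs): skipping the first N samples of a fresh deterministic iterator lands exactly where the checkpointed iterator stopped.

```python
def fast_forward(iterator, num_yielded):
    for _ in range(num_yielded):
        next(iterator)  # discard samples that were consumed pre-checkpoint
    return iterator

source = iter(range(10))                     # deterministic sample stream
consumed = [next(source) for _ in range(4)]  # yielded before the checkpoint
resumed = fast_forward(iter(range(10)), len(consumed))
assert next(resumed) == 4                    # same next sample after restart
```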
```python
        output = []
        for i, data in enumerate(dataloader):
            output.append(data)
            if i == 5:
```
Can we try to set this lower than 5? I think for end=10, slow_index=0, and num_workers=2, the sample at i = 5 will be index 0, since i = 0 through 4 will come from worker 1, and worker 0 will be blocked until index 0 is released.
I'm not sure that's true since we added the re-distribution of work. I checked this test's output and 0 was returned at index 8, so after the break point.
I've dropped the break point down to 3, however, and also added an assertion to check that 0 isn't in the output at that point as well.
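Something like the following shape, hedged since the actual test body differs; `dataloader` here is the loader built over the slow dataset in the test setup:

```python
output = []
for i, data in enumerate(dataloader):
    output.append(data)
    if i == 3:  # break before the slow sample can have been produced
        state_dict = dataloader.state_dict()  # checkpoint mid-epoch
        break

assert 0 not in output  # the slow index hasn't been yielded yet
```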
Let's land this once CI finishes and figure out details for the next release.
Follow-up to #1423

Changes

- Adds a test for checkpointing a map-style dataset with in_order=False
- Adds a test for checkpointing an iterable dataset with in_order=False

State management with in_order=False actually just works, which is really nice. I'm happy to keep the warning logs in for longer if we'd like to err on the side of caution. And please let me know if there are any other test cases I should add.
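To make that concrete, here is a hedged end-to-end sketch of the pattern these tests exercise; it assumes StatefulDataLoader accepts the in_order flag from #1423 and reuses the hypothetical `_TestSlowDataset` sketched earlier in the thread:

```python
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = _TestSlowDataset(start=0, end=10, slow_index=0, delay=1.0)
loader = StatefulDataLoader(dataset, num_workers=2, in_order=False)

output = []
for i, data in enumerate(loader):
    output.append(int(data))
    if i == 3:
        state = loader.state_dict()  # sample 0 is still blocked in its worker
        break

# Load the checkpoint into a fresh loader; workers fast-forward and
# continue yielding the remaining samples out of order.
resumed = StatefulDataLoader(dataset, num_workers=2, in_order=False)
resumed.load_state_dict(state)
output += [int(data) for data in resumed]

assert sorted(output) == list(range(10))  # nothing dropped, no duplicates
```

The key property being tested is the final assertion: even though the checkpoint was taken while the slow sample was still pending, resuming yields every sample exactly once.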