Update comments in data parallel example to use sampler #7914
base: master
Conversation
if xr.world_size() > 1:
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xr.world_size(),
nit: dist.world_size
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=xr.world_size(),
    rank=xr.global_ordinal(),
nit: dist.get_rank
# want each process to handle different parts of the data.
'''
train_sampler = None
if xr.world_size() > 1:
also dist.world_size
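Taken together, the three nits above suggest querying torch.distributed for the replica count and rank rather than the torch_xla runtime helpers. A minimal sketch of the reviewed block after that change, assuming the process group has already been initialized (e.g. with dist.init_process_group):

import torch
import torch.distributed as dist

# Sketch of the suggested change: source the replica count and rank from
# torch.distributed instead of the torch_xla runtime (xr) helpers.
train_sampler = None
if dist.get_world_size() > 1:
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,                       # the example's dataset
        num_replicas=dist.get_world_size(),  # was xr.world_size()
        rank=dist.get_rank())                # was xr.global_ordinal()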
# below code is commented out because in this example we used a fake data
# loader that does not take sampler. However this logic is needed if you
# want each process to handle different parts of the data.
'''
Is it clearer to just apply this sampler to the fake dataset anyway?
The fake dataset is an XLA util (https://github.com/pytorch/xla/blob/master/examples/train_resnet_base.py#L25-L29); it does not take a sampler.
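For context, this is roughly how the linked example builds its fake loader with torch_xla's SampleGenerator; the batch shape and sample count below are illustrative assumptions, not the repo's exact values:

import torch
import torch_xla.utils.utils as xu
import torch_xla.runtime as xr

batch_size = 128  # illustrative value

# SampleGenerator yields the same fake (input, label) batch over and over.
# It stands in for both the dataset and the DataLoader, so there is no
# hook where a DistributedSampler could be attached.
train_loader = xu.SampleGenerator(
    data=(torch.zeros(batch_size, 3, 224, 224),
          torch.zeros(batch_size, dtype=torch.int64)),
    sample_count=1200000 // batch_size // xr.world_size())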
Oh, I just kind of assumed that SampleGenerator was an idiomatic Dataset. Can you try just making it inherit from IterableDataset, since it has __iter__ and __len__ already? You should then be able to wrap it in a standard sampler and data loader.
If that doesn't work, then one of us can follow up. Our examples should be as close to PyTorch as possible.
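A minimal, untested sketch of the inheritance change being suggested; note that PyTorch's DataLoader only accepts a sampler for map-style datasets, so wiring in DistributedSampler may still need the follow-up mentioned above:

import torch
from torch.utils.data import DataLoader, IterableDataset

class SampleGenerator(IterableDataset):
    # Fake-data generator, now declared as an iterable-style dataset.

    def __init__(self, data, sample_count):
        self._data = data
        self._sample_count = sample_count

    def __len__(self):
        return self._sample_count

    def __iter__(self):
        for _ in range(self._sample_count):
            yield self._data

# It now wraps in a standard DataLoader. batch_size=None keeps the
# pre-batched fake samples as-is instead of re-collating them.
fake_batch = (torch.zeros(8, 3, 224, 224),
              torch.zeros(8, dtype=torch.int64))
loader = DataLoader(SampleGenerator(fake_batch, sample_count=100),
                    batch_size=None)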
Branch force-pushed from 852eb2b to 669e59f
fix #7904