Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update comments in data parallel example to use sampler #7914

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

JackCaoG
Copy link
Collaborator

@JackCaoG JackCaoG commented Aug 26, 2024

fix #7904

@JackCaoG JackCaoG marked this pull request as ready for review August 26, 2024 20:10
if xr.world_size() > 1:
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=xr.world_size(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: dist.world_size

train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=xr.world_size(),
rank=xr.global_ordinal(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: dist.get_rank

# want each process to handle different parts of the data.
'''
train_sampler = None
if xr.world_size() > 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also dist.world_size

# below code is commented out because in this example we used a fake data
# loader that does not take sampler. However this logic is needed if you
# want each process to handle different parts of the data.
'''
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it clearer to just apply this sampler to the fake dataset anyway?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fake dataset is a XLA util https://github.com/pytorch/xla/blob/master/examples/train_resnet_base.py#L25-L29, it does not take sampler.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I just kind of assumed that SampleGenerator was an idiomatic Dataset. Can you try just making it inherit from IterableDataset since it has __iter__ and __len__ already? You should then be able to wrap it in a standard sampler and data loader.

If that doesn't work, then one of us can follow up. Our examples should be as close to PyTorch as possible.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

torch_xla.distributed.parallel_loader doesn't shard data
2 participants