diff --git a/test/test_datasets_samplers.py b/test/test_datasets_samplers.py
index f99c63e65d3..90f3f3806aa 100644
--- a/test/test_datasets_samplers.py
+++ b/test/test_datasets_samplers.py
@@ -5,7 +5,11 @@
 import unittest

 from torchvision import io
-from torchvision.datasets.samplers import RandomClipSampler, UniformClipSampler
+from torchvision.datasets.samplers import (
+    DistributedSampler,
+    RandomClipSampler,
+    UniformClipSampler,
+)
 from torchvision.datasets.video_utils import VideoClips, unfold
 from torchvision import get_video_backend

@@ -83,6 +87,31 @@ def test_uniform_clip_sampler_insufficient_clips(self):
             indices = torch.tensor(list(iter(sampler)))
             self.assertTrue(indices.equal(torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11])))

+    def test_distributed_sampler_and_uniform_clip_sampler(self):
+        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
+            video_clips = VideoClips(video_list, 5, 5)
+            clip_sampler = UniformClipSampler(video_clips, 3)
+
+            distributed_sampler_rank0 = DistributedSampler(
+                clip_sampler,
+                num_replicas=2,
+                rank=0,
+                group_size=3,
+            )
+            indices = torch.tensor(list(iter(distributed_sampler_rank0)))
+            self.assertEqual(len(distributed_sampler_rank0), 6)
+            self.assertTrue(indices.equal(torch.tensor([0, 2, 4, 10, 12, 14])))
+
+            distributed_sampler_rank1 = DistributedSampler(
+                clip_sampler,
+                num_replicas=2,
+                rank=1,
+                group_size=3,
+            )
+            indices = torch.tensor(list(iter(distributed_sampler_rank1)))
+            self.assertEqual(len(distributed_sampler_rank1), 6)
+            self.assertTrue(indices.equal(torch.tensor([5, 7, 9, 0, 2, 4])))
+


 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/datasets/samplers/clip_sampler.py b/torchvision/datasets/samplers/clip_sampler.py
index 3d4c788fc61..b3c01c5e508 100644
--- a/torchvision/datasets/samplers/clip_sampler.py
+++ b/torchvision/datasets/samplers/clip_sampler.py
@@ -9,9 +9,32 @@ class DistributedSampler(Sampler):
     """
     Extension of DistributedSampler, as discussed in
     https://github.com/pytorch/pytorch/issues/23430
+
+    Example:
+        dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+        num_replicas: 4
+        shuffle: False
+
+        when group_size = 1
+            RANK    |  shard_dataset
+            =========================
+            rank_0  |  [0, 4, 8, 12]
+            rank_1  |  [1, 5, 9, 13]
+            rank_2  |  [2, 6, 10, 0]
+            rank_3  |  [3, 7, 11, 1]
+
+        when group_size = 2
+
+            RANK    |  shard_dataset
+            =========================
+            rank_0  |  [0, 1, 8, 9]
+            rank_1  |  [2, 3, 10, 11]
+            rank_2  |  [4, 5, 12, 13]
+            rank_3  |  [6, 7, 0, 1]
+
     """

-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, group_size=1):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
@@ -20,11 +43,20 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
             rank = dist.get_rank()
+        assert len(dataset) % group_size == 0, (
+            "dataset length must be a multiple of group size; "
+            "dataset length: %d, group size: %d" % (len(dataset), group_size)
+        )
         self.dataset = dataset
+        self.group_size = group_size
         self.num_replicas = num_replicas
         self.rank = rank
         self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+        dataset_group_length = len(dataset) // group_size
+        self.num_group_samples = int(
+            math.ceil(dataset_group_length * 1.0 / self.num_replicas)
+        )
+        self.num_samples = self.num_group_samples * group_size
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle

@@ -41,8 +73,14 @@ def __iter__(self):
             indices += indices[:(self.total_size - len(indices))]
         assert len(indices) == self.total_size

+        total_group_size = self.total_size // self.group_size
+        indices = torch.reshape(
+            torch.LongTensor(indices), (total_group_size, self.group_size)
+        )
+
         # subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
+        indices = indices[self.rank:total_group_size:self.num_replicas, :]
+        indices = torch.reshape(indices, (-1,)).tolist()

         assert len(indices) == self.num_samples
         if isinstance(self.dataset, Sampler):
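
For intuition, below is a minimal standalone sketch of the group-aware subsampling this patch introduces. The helper name shard_indices and the driver loop are illustrative only, not part of the torchvision API; the real DistributedSampler additionally maps the sharded positions through the wrapped sampler's own ordering when dataset is itself a Sampler, which is what the new test exercises.

    import math

    import torch


    def shard_indices(indices, num_replicas, rank, group_size=1):
        # Pad by wrapping around so every replica receives the same number
        # of whole groups, mirroring the padding in DistributedSampler.__iter__.
        num_group_samples = math.ceil(len(indices) / group_size / num_replicas)
        total_size = num_group_samples * group_size * num_replicas
        indices = list(indices)
        indices += indices[:total_size - len(indices)]

        # Reshape to (num_groups, group_size) so consecutive indices stay
        # together, then hand every num_replicas-th group to this rank.
        groups = torch.LongTensor(indices).reshape(-1, group_size)
        return groups[rank::num_replicas].reshape(-1).tolist()


    # Reproduces the group_size = 2 table from the docstring above.
    for rank in range(4):
        print(rank, shard_indices(range(14), num_replicas=4, rank=rank, group_size=2))
    # 0 [0, 1, 8, 9]
    # 1 [2, 3, 10, 11]
    # 2 [4, 5, 12, 13]
    # 3 [6, 7, 0, 1]

With group_size=3 and the UniformClipSampler ordering [0, 2, 4, 5, 7, 9, 10, 12, 14] from the new test, the same scheme keeps each video's three clips on one replica, which is the point of sharding by group rather than by individual index.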