extend DistributedSampler to support group_size (pytorch#1512)
* extend DistributedSampler to support group_size

* Fix lint
stephenyan1231 authored and fmassa committed Oct 22, 2019
1 parent b60cb72 commit 355e9d2
Showing 2 changed files with 71 additions and 4 deletions.
31 changes: 30 additions & 1 deletion test/test_datasets_samplers.py
@@ -5,7 +5,11 @@
import unittest

from torchvision import io
from torchvision.datasets.samplers import RandomClipSampler, UniformClipSampler
from torchvision.datasets.samplers import (
DistributedSampler,
RandomClipSampler,
UniformClipSampler,
)
from torchvision.datasets.video_utils import VideoClips, unfold
from torchvision import get_video_backend

@@ -83,6 +87,31 @@ def test_uniform_clip_sampler_insufficient_clips(self):
            indices = torch.tensor(list(iter(sampler)))
            self.assertTrue(indices.equal(torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11])))

    def test_distributed_sampler_and_uniform_clip_sampler(self):
        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            clip_sampler = UniformClipSampler(video_clips, 3)

            distributed_sampler_rank0 = DistributedSampler(
                clip_sampler,
                num_replicas=2,
                rank=0,
                group_size=3,
            )
            indices = torch.tensor(list(iter(distributed_sampler_rank0)))
            self.assertEqual(len(distributed_sampler_rank0), 6)
            self.assertTrue(indices.equal(torch.tensor([0, 2, 4, 10, 12, 14])))

            distributed_sampler_rank1 = DistributedSampler(
                clip_sampler,
                num_replicas=2,
                rank=1,
                group_size=3,
            )
            indices = torch.tensor(list(iter(distributed_sampler_rank1)))
            self.assertEqual(len(distributed_sampler_rank1), 6)
            self.assertTrue(indices.equal(torch.tensor([5, 7, 9, 0, 2, 4])))


if __name__ == '__main__':
    unittest.main()
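
For reference, the expected indices in the new test can be derived by hand: each of the three 25-frame videos yields five clips of 5 frames (global clip indices 0-4, 5-9, 10-14), and UniformClipSampler(video_clips, 3) picks three evenly spaced clips per video, so the underlying sampler sequence is [0, 2, 4, 5, 7, 9, 10, 12, 14]. With group_size=3 and num_replicas=2, whole per-video groups are sharded round-robin across ranks. A simplified standalone sketch of that arithmetic (plain Python, not the library code; it pads whole groups rather than flat indices, which is equivalent here):

# Sketch: reproduce the sharding in the test above without torchvision.
clip_indices = [0, 2, 4, 5, 7, 9, 10, 12, 14]   # UniformClipSampler output
group_size, num_replicas = 3, 2

# Split into per-video groups of group_size consecutive clip indices.
groups = [clip_indices[i:i + group_size]
          for i in range(0, len(clip_indices), group_size)]
# groups == [[0, 2, 4], [5, 7, 9], [10, 12, 14]]

# Pad by wrapping around so the group count divides evenly across ranks.
num_group_samples = -(-len(groups) // num_replicas)      # ceil(3 / 2) == 2
total_groups = num_group_samples * num_replicas          # 4
groups = groups + groups[:total_groups - len(groups)]    # appends [0, 2, 4] again

# Each rank takes every num_replicas-th group, starting at its own rank.
for rank in range(num_replicas):
    shard = [i for group in groups[rank:total_groups:num_replicas] for i in group]
    print(rank, shard)
# 0 [0, 2, 4, 10, 12, 14]
# 1 [5, 7, 9, 0, 2, 4]
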
44 changes: 41 additions & 3 deletions torchvision/datasets/samplers/clip_sampler.py
@@ -9,9 +9,32 @@ class DistributedSampler(Sampler):
"""
Extension of DistributedSampler, as discussed in
https://github.com/pytorch/pytorch/issues/23430
Example:
dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
num_replicas: 4
shuffle: False
when group_size = 1
RANK | shard_dataset
=========================
rank_0 | [0, 4, 8, 12]
rank_1 | [1, 5, 9, 13]
rank_2 | [2, 6, 10, 0]
rank_3 | [3, 7, 11, 1]
when group_size = 2
RANK | shard_dataset
=========================
rank_0 | [0, 1, 8, 9]
rank_1 | [2, 3, 10, 11]
rank_2 | [4, 5, 12, 13]
rank_3 | [6, 7, 0, 1]
"""

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, group_size=1):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
@@ -20,11 +43,20 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        assert len(dataset) % group_size == 0, (
            "dataset length must be a multiple of group size. "
            "dataset length: %d, group size: %d" % (len(dataset), group_size)
        )
        self.dataset = dataset
        self.group_size = group_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        dataset_group_length = len(dataset) // group_size
        self.num_group_samples = int(
            math.ceil(dataset_group_length * 1.0 / self.num_replicas)
        )
        self.num_samples = self.num_group_samples * group_size
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

@@ -41,8 +73,14 @@ def __iter__(self):
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        total_group_size = self.total_size // self.group_size
        indices = torch.reshape(
            torch.LongTensor(indices), (total_group_size, self.group_size)
        )

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        indices = indices[self.rank:total_group_size:self.num_replicas, :]
        indices = torch.reshape(indices, (-1,)).tolist()
        assert len(indices) == self.num_samples

        if isinstance(self.dataset, Sampler):
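
In short, the new logic in __init__ and __iter__ rounds the number of index groups up to a multiple of num_replicas, pads the flat index list by wrapping around, views it as rows of group_size consecutive indices, and hands each rank every num_replicas-th row starting at its own rank. A minimal standalone sketch of that behavior (a simplified stand-in, not the library code; shuffling and the final Sampler re-mapping in __iter__ are omitted):

import math

import torch


def shard_indices(indices, num_replicas, rank, group_size):
    # Simplified sketch of the group-aware sharding added in this commit.
    indices = list(indices)
    assert len(indices) % group_size == 0

    num_group_samples = math.ceil((len(indices) // group_size) / num_replicas)
    total_size = num_group_samples * group_size * num_replicas

    # Pad by wrapping around so every rank gets the same number of samples.
    indices += indices[:total_size - len(indices)]

    # View as rows of group_size consecutive indices; each rank takes every
    # num_replicas-th row, starting at its own rank.
    grouped = torch.LongTensor(indices).reshape(total_size // group_size, group_size)
    return grouped[rank::num_replicas].reshape(-1).tolist()


# Reproduces the docstring tables, e.g. 14 indices, 4 replicas, group_size=2:
# shard_indices(range(14), 4, rank=2, group_size=2) -> [4, 5, 12, 13]
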

