
Commit 768702a

Remove unused dataloader

kks32 committed Jun 29, 2024
1 parent a74ced0
Showing 3 changed files with 0 additions and 393 deletions.
301 changes: 0 additions & 301 deletions gns/data_loader.py

This file was deleted.

30 changes: 0 additions & 30 deletions gns/distribute.py
@@ -3,8 +3,6 @@
 from torch.utils import collect_env
 from torch.utils.data.distributed import DistributedSampler
 
-from gns import data_loader
-
 
 def setup(local_rank: int):
     """Initializes distributed training."""
@@ -52,31 +50,3 @@ def spawn_train(train_fxn, flags, world_size, device):
     torch.multiprocessing.spawn(
         train_fxn, args=(flags, world_size, device), nprocs=world_size, join=True
     )
-
-
-def get_data_distributed_dataloader_by_samples(
-    path, input_length_sequence, batch_size, shuffle=True
-):
-    """Returns a distributed dataloader.
-    Args:
-        path (str): Path to dataset.
-        input_length_sequence (int): Length of input sequence.
-        batch_size (int): Batch size.
-        shuffle (bool): Whether to shuffle dataset.
-    """
-    dataset = data_loader.SamplesDataset(path, input_length_sequence)
-    sampler = DistributedSampler(
-        dataset,
-        num_replicas=dist.get_world_size(),
-        rank=dist.get_rank(),
-        shuffle=shuffle,
-    )
-
-    return torch.utils.data.DataLoader(
-        dataset=dataset,
-        sampler=sampler,
-        batch_size=batch_size,
-        pin_memory=True,
-        collate_fn=data_loader.collate_fn,
-    )
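
For reference, the removed helper wrapped a dataset in a DistributedSampler so that each rank draws a disjoint shard of the data. The sketch below shows the same pattern in a self-contained form; ToyDataset is a hypothetical stand-in for the SamplesDataset and collate_fn that lived in the deleted gns/data_loader.py, and it assumes the process group has already been initialized (e.g. by the setup() function above).

    # A minimal sketch of the distributed-loading pattern used by the removed helper.
    # ToyDataset is a hypothetical placeholder, not part of the gns codebase;
    # torch.distributed must be initialized before building the loader.
    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, Dataset
    from torch.utils.data.distributed import DistributedSampler


    class ToyDataset(Dataset):
        """Hypothetical placeholder dataset yielding scalar tensors."""

        def __init__(self, n: int = 1024):
            self.data = torch.arange(n, dtype=torch.float32)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]


    def build_distributed_loader(dataset, batch_size, shuffle=True):
        # Each rank samples a disjoint shard of the dataset.
        sampler = DistributedSampler(
            dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
            shuffle=shuffle,
        )
        return DataLoader(
            dataset=dataset,
            sampler=sampler,
            batch_size=batch_size,
            pin_memory=True,
        )

When shuffle=True, callers should also call loader.sampler.set_epoch(epoch) at the start of each epoch so the per-rank shards are re-shuffled across epochs.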
