You must set batch size. There is no other way. (#624)
* batch size MUST BE SET

* throw error on iter

* added test
snarayan21 authored Mar 19, 2024
1 parent 5dd1109 commit d8bf491
Showing 29 changed files with 132 additions and 83 deletions.
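In practice, this change means `batch_size` must now be supplied to `StreamingDataset` (and its subclasses) and should match the per-device batch size given to the DataLoader. A minimal sketch of the intended usage, with a hypothetical local dataset path and an illustrative batch size:

    from torch.utils.data import DataLoader
    from streaming import StreamingDataset

    # Hypothetical path to an MDS-formatted dataset; the value 32 is illustrative.
    dataset = StreamingDataset(local='./my_dataset', batch_size=32)

    # The DataLoader should use the same per-device batch size.
    loader = DataLoader(dataset, batch_size=32)

    for batch in loader:
        pass  # training step goes here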
4 changes: 2 additions & 2 deletions docs/source/getting_started/user_guide.md
@@ -141,8 +141,8 @@ To load the same dataset files that were created in the above steps, create a `C
 from streaming import StreamingDataset
 
 class CustomDataset(StreamingDataset):
-    def __init__(self, local, remote):
-        super().__init__(local=local, remote=remote)
+    def __init__(self, local, remote, batch_size):
+        super().__init__(local=local, remote=remote, batch_size=batch_size)
 
     def __getitem__(self, idx: int) -> Any:
         obj = super().__getitem__(idx)
7 changes: 4 additions & 3 deletions simulation/core/sim_dataset.py
@@ -81,8 +81,9 @@ class SimulationDataset(StreamingDataset):
             For sequential sample ordering, set ``shuffle`` to ``False`` and
             ``num_canonical_nodes`` to the number of physical nodes of the initial run.
-        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
-            partitioned over the workers. Defaults to ``None``.
+        batch_size (int, optional): Per-device batch size, the same as what is passed to the
+            DataLoader. This affects how the dataset is partitioned over the workers and is
+            necessary for deterministic resumption and optimal performance. Defaults to ``None``.
         shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
             ``False``.
         shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
@@ -191,7 +192,7 @@ def __init__(self,
         elif self.predownload is None:
             self.predownload = 8 * self.batch_size if self.batch_size is not None else 64
 
-        self.batch_size = batch_size or 1
+        self.batch_size = batch_size
 
         # Convert epoch size from string to int, if needed. Cannot be negative.
         epoch_size_value = None
3 changes: 2 additions & 1 deletion streaming/base/batching/per_stream.py
@@ -49,7 +49,8 @@ def generate_work_per_stream_batching(dataset: StreamingDataset, world: World, e
     # nodes/devices/workers. We handle sample_in_epoch (for resumption) at the end.
     partition_per_stream = []
 
-    batch_size = dataset.batch_size or 1
+    batch_size = dataset.batch_size
+    assert isinstance(batch_size, int), f'Batch size must be an integer. Got {type(batch_size)}.'
 
     for stream_id, stream in enumerate(dataset.streams):
         shuffle_units, small_per_big = dataset.resample_streams(epoch, stream_id)
5 changes: 4 additions & 1 deletion streaming/base/batching/random.py
@@ -48,12 +48,15 @@ def generate_work_random_batching(dataset: StreamingDataset, world: World, epoch
     # sample ID to its underlying "small" sample ID.
     shuffle_units, small_per_big = dataset.resample_streams(epoch)
 
+    batch_size = dataset.batch_size
+    assert isinstance(batch_size, int), f'Batch size must be an integer. Got {type(batch_size)}.'
+
     # Partition the global sample space (of resampled "big" sample IDs) into a tensor of shape
     # (num physical nodes, ranks per node, workers per rank, batches per worker, samples per
     # batch) such that we have an elastically deterministic sample order.
     big_ids = get_partitions(dataset.partition_algo, dataset.epoch_size,
                              dataset.num_canonical_nodes, world.num_nodes, world.ranks_per_node,
-                             world.workers_per_rank, dataset.batch_size, sample_in_epoch,
+                             world.workers_per_rank, batch_size, sample_in_epoch,
                              dataset.initial_physical_nodes)
 
     # If we need to shuffle, shuffle in a node-aware and *underlying* shard-aware way.
4 changes: 3 additions & 1 deletion streaming/base/batching/stratified.py
@@ -49,7 +49,9 @@ def generate_work_stratified_batching(dataset: StreamingDataset, world: World, e
     # Then, we also partition each stream's samples over nodes/devices/workers.
     # We handle sample_in_epoch (for resumption) at the end.
 
-    batch_size = dataset.batch_size or 1
+    batch_size = dataset.batch_size
+    assert isinstance(batch_size, int), f'Batch size must be an integer. Got {type(batch_size)}.'
+
     global_batch_size = batch_size * world.ranks_per_node * world.num_nodes
     partition_per_stream = []
    batch_portion_per_stream = []
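For intuition, the global batch size computed above is just the per-device batch size scaled by the world size; a tiny sketch with purely illustrative numbers:

    # Illustrative values only.
    batch_size = 8            # per-device batch size (now required to be an int)
    ranks_per_node = 8
    num_nodes = 4
    global_batch_size = batch_size * ranks_per_node * num_nodes  # 256 samples per global batch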
12 changes: 10 additions & 2 deletions streaming/base/dataset.py
@@ -286,8 +286,9 @@ class StreamingDataset(Array, IterableDataset):
             For sequential sample ordering, set ``shuffle`` to ``False`` and
             ``num_canonical_nodes`` to the number of physical nodes of the initial run.
-        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
-            partitioned over the workers. Defaults to ``None``.
+        batch_size (int, optional): Per-device batch size, the same as what is passed to the
+            DataLoader. This affects how the dataset is partitioned over the workers and is
+            necessary for deterministic resumption and optimal performance. Defaults to ``None``.
         shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
             ``False``.
         shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
@@ -1000,6 +1001,13 @@ def _get_work(self, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]:
                     f"for sample replication when using StreamingDataset's " +
                     f'`state_dict` method for deterministic resumption. Otherwise, ' +
                     f'you will resume training from the wrong sample.')
+        # Ensure that batch_size is passed in, and is an integer. This is necessary for
+        # deterministic resumption and optimal performance.
+        if not isinstance(self.batch_size, int):
+            raise ValueError(f'Please pass `batch_size` to StreamingDataset. It should be ' +
+                             f'set the same as the DataLoader, and is the number of samples ' +
+                             f'per batch, for each device. It is necessary for ' +
+                             f'deterministic resumption and optimal performance.')
         epoch_sample_ids = generate_work(self.batching_method, self, p_world, epoch,
                                          sample_in_epoch)
         shape_shm, data_shm = self._share_work(epoch_sample_ids)
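The effect of this new check is that forgetting to set `batch_size` now fails loudly when iteration begins (when the work partition is generated), rather than silently falling back to a batch size of 1. A rough sketch of the failure mode, with a hypothetical dataset path:

    from streaming import StreamingDataset

    # Hypothetical local MDS dataset; note that batch_size is not passed.
    dataset = StreamingDataset(local='./my_dataset')

    # Iterating triggers work generation, which now raises ValueError because
    # self.batch_size is None rather than an int.
    for sample in dataset:
        pass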
6 changes: 3 additions & 3 deletions streaming/base/partition/__init__.py
@@ -23,7 +23,7 @@ def get_partitions(algo: str,
                    num_physical_nodes: int,
                    ranks_per_node: int,
                    workers_per_rank: int,
-                   batch_size: Optional[int] = None,
+                   batch_size: int,
                    drop_first: int = 0,
                    initial_physical_nodes: Optional[int] = None) -> NDArray[np.int64]:
     """Partition the given number of samples to nodes, ranks, and workers.
@@ -41,8 +41,8 @@
         num_physical_nodes (int): Number of physical nodes.
         ranks_per_node (int): Number of ranks per node.
         workers_per_rank (int): Number of worker partitions per rank.
-        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
-            partitioned over the workers. Defaults to ``None``.
+        batch_size (int): Batch size of DataLoader and dataset, which affects how the dataset is
+            partitioned over the workers.
         drop_first (int): Number of samples seen already, which are dropped. Defaults to ``0``.
         initial_physical_nodes (int, optional): Number of physical nodes at the start of training.
             Defaults to ``None``.
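For reference, a sketch of calling `get_partitions` directly with the now-required `batch_size` (all numbers illustrative):

    from streaming.base.partition import get_partitions

    # 4096 samples over 4 canonical nodes, 2 physical nodes, 8 ranks per node,
    # 2 dataloader workers per rank, and a per-device batch size of 4.
    ids = get_partitions('orig', 4096, 4, 2, 8, 2, 4)

    # Documented shape: (physical nodes, ranks per node, workers per rank,
    # batches per worker, batch size).
    print(ids.shape)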
8 changes: 3 additions & 5 deletions streaming/base/partition/orig.py
@@ -19,7 +19,7 @@ def get_partitions_orig(num_samples: int,
                         num_physical_nodes: int,
                         ranks_per_node: int,
                         workers_per_rank: int,
-                        batch_size: Optional[int] = None,
+                        batch_size: int,
                         drop_first: int = 0,
                         initial_physical_nodes: Optional[int] = None) -> NDArray[np.int64]:
     """Partition the given number of samples to nodes, ranks, and workers.
@@ -36,8 +36,8 @@
         num_physical_nodes (int): Number of physical nodes.
         ranks_per_node (int): Number of ranks per node.
         workers_per_rank (int): Number of worker partitions per rank.
-        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
-            partitioned over the workers. Defaults to ``None``.
+        batch_size (int): Batch size of DataLoader and dataset, which affects how the dataset is
+            partitioned over the workers.
         drop_first (int): Number of samples seen already, which are dropped. Defaults to ``0``.
         initial_physical_nodes (int, optional): Number of physical nodes at the start of training.
             Defaults to ``None``.
@@ -61,8 +61,6 @@
                          'the other, otherwise striping slices of shards over nodes may ' +
                          'lead to each node downloading all shards')
 
-    batch_size = batch_size or 1
-
     # If drop_first isn't a multiple of num_physical_nodes, round down to make it divisible.
     if drop_first % num_physical_nodes:
         logger.warning(
7 changes: 3 additions & 4 deletions streaming/base/partition/relaxed.py
@@ -19,7 +19,7 @@ def get_partitions_relaxed(num_samples: int,
                            num_physical_nodes: int,
                            ranks_per_node: int,
                            workers_per_rank: int,
-                           batch_size: Optional[int] = None,
+                           batch_size: int,
                            drop_first: int = 0,
                            initial_physical_nodes: Optional[int] = None) -> NDArray[np.int64]:
     """Partition the given number of samples to nodes, ranks, and workers.
@@ -39,8 +39,8 @@
         num_physical_nodes (int): Number of physical nodes.
         ranks_per_node (int): Number of ranks per node.
         workers_per_rank (int): Number of worker partitions per rank.
-        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
-            partitioned over the workers. Defaults to ``None``.
+        batch_size (int): Batch size of DataLoader and dataset, which affects how the dataset is
+            partitioned over the workers.
         drop_first (int): Number of samples seen already, which are dropped. Defaults to ``0``.
         initial_physical_nodes (int, optional): Number of physical nodes at the start of training.
             Defaults to ``None``.
@@ -65,7 +65,6 @@
         return get_partitions_orig(num_samples, num_canonical_nodes, num_physical_nodes,
                                    ranks_per_node, workers_per_rank, batch_size, drop_first)
     else:
-        batch_size = batch_size or 1
         # First, make a partition over the initial number of physical nodes and device batch size.
         # We assume that ranks_per_node and workers_per_rank stay constant during resumptions.
         global_batch_size = num_physical_nodes * ranks_per_node * batch_size
@@ -215,7 +215,7 @@ def main(args: Namespace) -> None:
         dirname = os.path.join(args.mds_root, name)
         stream = Stream(local=dirname, proportion=1 / len(subsets_present))
         streams.append(stream)
-    dataset = StreamingDataset(streams=streams, epoch_size=50)
+    dataset = StreamingDataset(streams=streams, epoch_size=50, batch_size=1)
 
     # Print the size of each sub-dataset.
     for name, num_samples in zip(sorted(subsets_present), dataset.samples_per_stream):
@@ -74,7 +74,7 @@ def main(args: Namespace) -> None:
         args (Namespace): Command-line arguments.
     """
     hashes = args.hashes.split(',') if args.hashes else []
-    dataset = StreamingDataset(local=getattr(args, 'in'))
+    dataset = StreamingDataset(local=getattr(args, 'in'), batch_size=1)
     with MDSWriter(out=args.out_mds,
                    columns=out_columns,
                    compression=args.compression,
10 changes: 6 additions & 4 deletions streaming/multimodal/webvid.py
@@ -61,8 +61,9 @@ class StreamingInsideWebVid(StreamingDataset):
             For sequential sample ordering, set ``shuffle`` to ``False`` and
             ``num_canonical_nodes`` to the number of physical nodes of the initial run.
-        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
-            partitioned over the workers. Defaults to ``None``.
+        batch_size (int, optional): Per-device batch size, the same as what is passed to the
+            DataLoader. This affects how the dataset is partitioned over the workers and is
+            necessary for deterministic resumption and optimal performance. Defaults to ``None``.
         shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
             ``False``.
         shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``.
@@ -134,8 +135,9 @@ class StreamingOutsideGIWebVid(StreamingDataset):
             For sequential sample ordering, set ``shuffle`` to ``False`` and
             ``num_canonical_nodes`` to the number of physical nodes of the initial run.
-        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
-            partitioned over the workers. Defaults to ``None``.
+        batch_size (int, optional): Per-device batch size, the same as what is passed to the
+            DataLoader. This affects how the dataset is partitioned over the workers and is
+            necessary for deterministic resumption and optimal performance. Defaults to ``None``.
         shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
             ``False``.
         shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``.
5 changes: 3 additions & 2 deletions streaming/text/c4.py
@@ -63,8 +63,9 @@ class StreamingC4(StreamingDataset):
             For sequential sample ordering, set ``shuffle`` to ``False`` and
             ``num_canonical_nodes`` to the number of physical nodes of the initial run.
-        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
-            partitioned over the workers. Defaults to ``None``.
+        batch_size (int, optional): Per-device batch size, the same as what is passed to the
+            DataLoader. This affects how the dataset is partitioned over the workers and is
+            necessary for deterministic resumption and optimal performance. Defaults to ``None``.
         shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
             ``False``.
         shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``.
5 changes: 3 additions & 2 deletions streaming/text/enwiki.py
@@ -59,8 +59,9 @@ class StreamingEnWiki(StreamingDataset):
             For sequential sample ordering, set ``shuffle`` to ``False`` and
             ``num_canonical_nodes`` to the number of physical nodes of the initial run.
-        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
-            partitioned over the workers. Defaults to ``None``.
+        batch_size (int, optional): Per-device batch size, the same as what is passed to the
+            DataLoader. This affects how the dataset is partitioned over the workers and is
+            necessary for deterministic resumption and optimal performance. Defaults to ``None``.
         shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
             ``False``.
         shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``.
5 changes: 3 additions & 2 deletions streaming/text/pile.py
@@ -63,8 +63,9 @@ class StreamingPile(StreamingDataset):
             For sequential sample ordering, set ``shuffle`` to ``False`` and
             ``num_canonical_nodes`` to the number of physical nodes of the initial run.
-        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
-            partitioned over the workers. Defaults to ``None``.
+        batch_size (int, optional): Per-device batch size, the same as what is passed to the
+            DataLoader. This affects how the dataset is partitioned over the workers and is
+            necessary for deterministic resumption and optimal performance. Defaults to ``None``.
         shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
             ``False``.
         shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``.
5 changes: 3 additions & 2 deletions streaming/vision/ade20k.py
@@ -61,8 +61,9 @@ class StreamingADE20K(StreamingDataset):
             For sequential sample ordering, set ``shuffle`` to ``False`` and
             ``num_canonical_nodes`` to the number of physical nodes of the initial run.
-        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
-            partitioned over the workers. Defaults to ``None``.
+        batch_size (int, optional): Per-device batch size, the same as what is passed to the
+            DataLoader. This affects how the dataset is partitioned over the workers and is
+            necessary for deterministic resumption and optimal performance. Defaults to ``None``.
         shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
             ``False``.
         shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``.
2 changes: 1 addition & 1 deletion streaming/vision/base.py
@@ -122,7 +122,7 @@ def __init__(self,
                 cache_limit: Optional[int] = None,
                 partition_algo: str = 'orig',
                 num_canonical_nodes: Optional[int] = None,
-                 batch_size: Optional[int] = None,
+                 batch_size: int,
                 shuffle: bool = False,
                 shuffle_algo: str = 'py1s',
                 shuffle_seed: int = 9176,
5 changes: 3 additions & 2 deletions streaming/vision/coco.py
@@ -61,8 +61,9 @@ class StreamingCOCO(StreamingDataset):
             For sequential sample ordering, set ``shuffle`` to ``False`` and
             ``num_canonical_nodes`` to the number of physical nodes of the initial run.
-        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
-            partitioned over the workers. Defaults to ``None``.
+        batch_size (int, optional): Per-device batch size, the same as what is passed to the
+            DataLoader. This affects how the dataset is partitioned over the workers and is
+            necessary for deterministic resumption and optimal performance. Defaults to ``None``.
         shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
             ``False``.
         shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``.
2 changes: 1 addition & 1 deletion tests/test_compression.py
@@ -274,7 +274,7 @@ def test_dataset_compression(compressed_local_remote_dir: Tuple[str, str, str],
                    size_limit=size_limit,
                    compression=None)
 
-    dataset = StreamingDataset(local=local, remote=compressed, shuffle=shuffle)
+    dataset = StreamingDataset(local=local, remote=compressed, shuffle=shuffle, batch_size=1)
 
     for _ in dataset:
         pass  # download sample