Epoch size default behavior (#374)
* epoch size default behavior fixed

* epoch size default behavior fixed

* Delete .DS_Store

* added comments, removed print statements from test

* added ds store to gitignore
snarayan21 authored Aug 10, 2023
1 parent 42c0c61 commit 62e4906
Showing 7 changed files with 135 additions and 16 deletions.
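
The change in a nutshell: when a StreamingDataset is built from streams that carry no `proportion`, `repeat`, or `choose` weights but an `epoch_size` is given, the epoch is now split across the streams in proportion to their sizes. A minimal usage sketch of that behavior (paths and sample counts are hypothetical; the imports mirror those used in the tests below):

from streaming.base import Stream, StreamingDataset

# Two unweighted streams, assumed to hold 200 and 300 samples respectively.
stream1 = Stream(local='/tmp/stream1', remote='s3://bucket/stream1')
stream2 = Stream(local='/tmp/stream2', remote='s3://bucket/stream2')

# With no per-stream weights but epoch_size set, each epoch now draws roughly
# 40 samples from stream1 and 60 from stream2 (a 2:3 split, matching stream sizes).
dataset = StreamingDataset(streams=[stream1, stream2], epoch_size=100, batch_size=4)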
3 changes: 3 additions & 0 deletions .gitignore
@@ -140,3 +140,6 @@ dmypy.json

# pycharm
.idea/

# OS X
.DS_Store
5 changes: 0 additions & 5 deletions streaming/base/dataset.py
@@ -331,16 +331,11 @@ def __init__(self,
default = Stream(remote=remote,
local=local,
split=split,
choose=epoch_size_value,
download_retry=download_retry,
download_timeout=download_timeout,
validate_hash=validate_hash,
keep_zip=keep_zip)
streams = [default]
# reset `epoch_size_value` to None when we initialize StreamingDataset with no
# streams so that when we `apply_weights` over this single stream we use the
# epoch size to absolutely weight the single stream.
epoch_size_value = None

# Validate the stream weighting scheme (relative or absolute) to catch errors before we go
# to the trouble of loading them.
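With the comment and reset above removed, a dataset created without explicit streams now passes `epoch_size` straight to the implicit default Stream as its `choose` value, instead of resetting `epoch_size_value` to None and weighting afterwards. A rough sketch of the user-facing call this affects (paths are hypothetical):

from streaming.base import StreamingDataset

# Single implicit stream: epoch_size is forwarded to the default Stream as
# `choose`, so each epoch should draw about 1000 samples from it.
dataset = StreamingDataset(remote='s3://bucket/data', local='/tmp/data', epoch_size=1000)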
5 changes: 3 additions & 2 deletions streaming/base/shuffle/naive.py
@@ -20,8 +20,9 @@ def get_shuffle_naive(shard_sizes: NDArray[np.int64],
block_size: int = 1 << 18) -> NDArray[np.int64]:
"""Get the shuffled global ordering of samples for an epoch.
The assignment of shards to nodes is fixed across epochs, but each grouping of shards is
processed concurrently in a different order by each node's workers each epoch.
The assignment of shards to nodes is fixed across epochs, but each grouping
of shards is processed concurrently in a different order by each node's
workers each epoch.
Args:
shard_sizes (NDArray[np.int64]): Number of samples contained in each shard, in order.
22 changes: 18 additions & 4 deletions streaming/base/stream.py
@@ -7,7 +7,7 @@
import json
import os
import tempfile
from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Tuple

import numpy as np
from numpy.typing import NDArray
@@ -191,7 +191,7 @@ def apply_default(self, default: dict) -> None:
self.safe_keep_zip = default['keep_zip'] or self.remote in {None, self.local}

@classmethod
def validate_weights(cls, streams: Sequence[Self]) -> bool:
def validate_weights(cls, streams: Sequence[Self]) -> Tuple[bool, bool]:
"""Validate stream weights, returning whether relative or absolute weighting was used.
Args:
@@ -202,6 +202,7 @@ def validate_weights(cls, streams: Sequence[Self]) -> bool:
"""
# Validate stream weights ("proportion", "repeat", "choose", or none).
is_proportional = hasattr(streams[0], 'proportion')
is_unspecified = True
for stream_id, stream in enumerate(streams):
has_proportion = hasattr(stream, 'proportion')
has_repeat = hasattr(stream, 'repeat')
@@ -213,7 +214,9 @@ def validate_weights(cls, streams: Sequence[Self]) -> bool:
raise ValueError(f'Relative (`proportion`) and absolute (`repeat`, `choose`, ' +
f'none) stream weights are incompatible with each other (error ' +
f'in stream {stream_id})')
return is_proportional
if has_proportion or has_repeat or has_choose:
is_unspecified = False
return is_proportional, is_unspecified

@classmethod
def apply_weights(cls, streams: Sequence[Self], samples_per_stream: NDArray[np.int64],
@@ -232,7 +235,7 @@ def apply_weights(cls, streams: Sequence[Self], samples_per_stream: NDArray[np.i
int: Number of samples to draw per epoch.
"""
# Validate provided weights, determining whether they are relative or absolute.
are_weights_relative = cls.validate_weights(streams)
are_weights_relative, are_weights_unspecified = cls.validate_weights(streams)

# Derive weights.
if are_weights_relative:
@@ -247,6 +250,17 @@ def apply_weights(cls, streams: Sequence[Self], samples_per_stream: NDArray[np.i
indices = rng.choice(len(streams), shortfall, False)
choose_per_stream[indices] += 1
repeat_per_stream = choose_per_stream / samples_per_stream
elif are_weights_unspecified and choose_per_epoch:
# Weights are unspecified, but an epoch size (choose_per_epoch) is provided.
# Sample from each stream in proportion to that stream's total number of samples.
proportion_per_stream = samples_per_stream.copy().astype(np.float64)
proportion_per_stream /= proportion_per_stream.sum()
choose_per_stream = (choose_per_epoch * proportion_per_stream).astype(np.int64)
shortfall = choose_per_epoch - choose_per_stream.sum()
rng = np.random.default_rng(seed)
indices = rng.choice(len(streams), shortfall, False)
choose_per_stream[indices] += 1
repeat_per_stream = choose_per_stream / samples_per_stream
else:
# Absolute.
if choose_per_epoch:
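The new elif branch above handles the case where stream weights are unspecified but an epoch size (`choose_per_epoch`) is provided: the epoch is split across streams in proportion to their sizes, and any rounding shortfall is handed to randomly chosen streams. A standalone sketch of that arithmetic, mirroring (not calling) the library code, with values matching the tests below:

import numpy as np

samples_per_stream = np.array([200, 300], dtype=np.int64)
choose_per_epoch = 10

# Per-stream proportions: [0.4, 0.6].
proportion_per_stream = samples_per_stream.astype(np.float64)
proportion_per_stream /= proportion_per_stream.sum()

# Integer allocation: [4, 6] samples per epoch.
choose_per_stream = (choose_per_epoch * proportion_per_stream).astype(np.int64)

# Distribute any rounding shortfall (0 in this example) to randomly chosen streams.
shortfall = choose_per_epoch - choose_per_stream.sum()
rng = np.random.default_rng(2222)
indices = rng.choice(len(samples_per_stream), shortfall, False)
choose_per_stream[indices] += 1

print(choose_per_stream)                        # [4 6]
print(choose_per_stream / samples_per_stream)   # repeat factors: [0.02 0.02]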
13 changes: 10 additions & 3 deletions tests/common/datasets.py
@@ -14,13 +14,20 @@ class SequenceDataset:
Args:
size (int): number of samples. Defaults to 100.
column_names (List[str]): A list of feature and target column names. Defaults to ['id', 'sample'].
offset (int): Offset added to each sample value. Defaults to 0.
"""

def __init__(self, size: int = 100, column_names: List[str] = ['id', 'sample']) -> None:
def __init__(
self,
size: int = 100,
column_names: List[str] = ['id', 'sample'],
offset: int = 0,
) -> None:
self.size = size
self.column_encodings = ['str', 'int']
self.column_sizes = [None, 8]
self.column_names = column_names
self.offset = offset
self._index = 0

def __len__(self) -> int:
@@ -30,7 +37,7 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
if index < self.size:
return {
self.column_names[0]: f'{index:06}',
self.column_names[1]: 3 * index,
self.column_names[1]: (3 * index) + self.offset,
}
raise IndexError('Index out of bounds')

@@ -41,7 +48,7 @@ def __next__(self) -> Dict[str, Any]:
if self._index >= self.size:
raise StopIteration
id = f'{self._index:06}'
data = 3 * self._index
data = (3 * self._index) + self.offset
self._index += 1
return {
self.column_names[0]: id,
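The new `offset` argument shifts every generated sample value, which the test below uses to tell the two streams' samples apart (stream 2 is written with offset=600). A quick sketch of the resulting values, assuming the import path matches the repo layout:

from tests.common.datasets import SequenceDataset

# Sample values are 3 * index + offset, so offset=600 gives 600, 603, 606, ...
ds = SequenceDataset(size=3, offset=600)
print([ds[i]['sample'] for i in range(len(ds))])  # [600, 603, 606]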
7 changes: 6 additions & 1 deletion tests/common/utils.py
@@ -49,12 +49,17 @@ def convert_to_mds(**kwargs: Any):
dataset_name = kwargs['dataset_name'].lower()
out_root = kwargs['out_root']
num_samples = kwargs.get('num_samples', 117)
offset = kwargs.get('offset', 0)
keep_local = kwargs.get('keep_local', False)
compression = kwargs.get('compression', None)
hashes = kwargs.get('hashes', None)
size_limit = kwargs.get('size_limit', 1 << 8)

dataset = dataset_mapping[dataset_name](num_samples)
if (dataset_name == 'sequencedataset' and offset != 0):
dataset = dataset_mapping[dataset_name](num_samples, offset=offset)
else:
dataset = dataset_mapping[dataset_name](num_samples)

columns = dict(zip(dataset.column_names, dataset.column_encodings))

with MDSWriter(out=out_root,
96 changes: 95 additions & 1 deletion tests/test_streaming.py
@@ -2,13 +2,14 @@
# SPDX-License-Identifier: Apache-2.0

import math
import os
import shutil
from typing import Any, Tuple

import pytest
from torch.utils.data import DataLoader

from streaming.base import StreamingDataLoader, StreamingDataset
from streaming.base import Stream, StreamingDataLoader, StreamingDataset
from tests.common.utils import convert_to_mds


@@ -58,6 +59,99 @@ def test_dataloader_epoch_size_no_streams(local_remote_dir: Tuple[str,
assert samples_seen == epoch_size


@pytest.mark.parametrize('batch_size', [4])
@pytest.mark.parametrize('seed', [2222])
@pytest.mark.parametrize('shuffle', [False])
@pytest.mark.parametrize('drop_last', [False, True])
@pytest.mark.parametrize('num_workers', [3, 6])
@pytest.mark.parametrize('num_canonical_nodes', [4, 8])
@pytest.mark.parametrize('epoch_size', [10, 200])
@pytest.mark.usefixtures('local_remote_dir')
def test_dataloader_epoch_size_multiple_streams_default(local_remote_dir: Tuple[str, str],
batch_size: int, seed: int, shuffle: bool,
drop_last: bool, num_workers: int,
num_canonical_nodes: int, epoch_size: int):
# create mock datasets for 2 streams. Second one has 1.5x the samples
local, remote = local_remote_dir
local1 = os.path.join(local, 'stream1')
local2 = os.path.join(local, 'stream2')
remote1 = os.path.join(remote, 'stream1')
remote2 = os.path.join(remote, 'stream2')

# Stream 1 has sample values in the range [0, 600).
convert_to_mds(out_root=remote1,
dataset_name='sequencedataset',
num_samples=200,
size_limit=1 << 8)
# Stream 2 has sample values of 600 and above, which lets us tell each stream's samples apart.
convert_to_mds(out_root=remote2,
dataset_name='sequencedataset',
num_samples=300,
offset=600,
size_limit=1 << 8)

stream1 = Stream(local=local1, remote=remote1)
stream2 = Stream(local=local2, remote=remote2)

# Build StreamingDataset
dataset = StreamingDataset(streams=[stream1, stream2],
shuffle=shuffle,
batch_size=batch_size,
shuffle_seed=seed,
num_canonical_nodes=num_canonical_nodes,
epoch_size=epoch_size)

# Build DataLoader
dataloader = StreamingDataLoader(dataset=dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=drop_last)

# track the number of samples seen overall in the epoch,
# and also track the number of samples seen from each stream.
# we expect the number of samples from each stream in the epoch
# to be proportional to the number of total samples in the stream,
# in the case when proportion, repeat, and choose are all unspecified.
samples_seen_stream1 = 0
samples_seen_stream2 = 0
samples_seen = 0
for batch in dataloader:
samples = batch['sample']
samples_seen += samples.size(dim=0)
stream1_seen = (samples < 600).sum().item()
stream2_seen = (samples >= 600).sum().item()
samples_seen_stream1 += stream1_seen
samples_seen_stream2 += stream2_seen

# If the epoch size is not divisible by the number of canonical nodes, the partition algorithm
# repeats some samples, so the number of samples seen will be within some tolerance of the epoch size.
# In all cases, though, stream 1 and stream 2 samples should be in roughly a 2:3 ratio,
# matching the number of samples in each stream (stream 1: 200, stream 2: 300).
if epoch_size % num_canonical_nodes != 0:
assert samples_seen == (math.ceil(epoch_size / num_canonical_nodes) * num_canonical_nodes)
assert samples_seen_stream1 == int(
samples_seen * 0.4) or samples_seen_stream1 == int(samples_seen * 0.4) + 1
assert samples_seen_stream2 == int(
samples_seen * 0.6) or samples_seen_stream2 == int(samples_seen * 0.6) + 1
else:
# if drop_last is True, we will drop incomplete batches, so samples_seen can
# be less than epoch_size
if drop_last:
assert samples_seen == epoch_size - (epoch_size % batch_size)
assert samples_seen_stream1 == int(
samples_seen * 0.4) or samples_seen_stream1 == int(samples_seen * 0.4) + 1
assert samples_seen_stream2 == int(
samples_seen * 0.6) or samples_seen_stream2 == int(samples_seen * 0.6) + 1
# drop_last is false, and epoch_size is divisible by num_canonical_nodes, so samples_seen
# should be the same as epoch_size
else:
assert samples_seen == epoch_size
assert samples_seen_stream1 == int(
samples_seen * 0.4) or samples_seen_stream1 == int(samples_seen * 0.4) + 1
assert samples_seen_stream2 == int(
samples_seen * 0.6) or samples_seen_stream2 == int(samples_seen * 0.6) + 1
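
A back-of-the-envelope check of the assertions above (this only reproduces the arithmetic the test expects; the +/-1 slack in the per-stream counts comes from integer rounding):

import math

# Stream sizes are 200 and 300, so the expected split is 2:3 (40% / 60%).
for epoch_size in (10, 200):
    for num_canonical_nodes in (4, 8):
        if epoch_size % num_canonical_nodes != 0:
            # Partitioning pads the epoch up to a multiple of the canonical nodes.
            seen = math.ceil(epoch_size / num_canonical_nodes) * num_canonical_nodes
        else:
            seen = epoch_size  # before any drop_last truncation
        print(epoch_size, num_canonical_nodes, seen, int(seen * 0.4), int(seen * 0.6))

# e.g. epoch_size=10, num_canonical_nodes=4 -> 12 samples seen,
# 4 or 5 from stream 1 and 7 or 8 from stream 2.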


@pytest.mark.parametrize('batch_size', [4])
@pytest.mark.parametrize('seed', [2222])
@pytest.mark.parametrize('shuffle', [False, True])
