
Stratified #1

Open
wants to merge 12 commits into base: dev-2024
141 changes: 120 additions & 21 deletions megatron/core/datasets/blended_megatron_dataset_builder.py
@@ -3,14 +3,15 @@
import logging
import math
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Iterable, List, Optional, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union, Tuple

import numpy
import torch

from megatron.core.datasets.blended_dataset import BlendedDataset
from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.utils import Split, normalize
from megatron.core.parallel_state import get_virtual_pipeline_model_parallel_rank
from megatron.core.utils import log_single_rank
@@ -25,6 +26,51 @@
TopLevelDataset, MidLevelDataset, LowLevelDataset, torch.utils.data.Dataset
]

class StratifiedDataset:
def __init__(self, datasets: dict):
"""
Initialize the StratifiedDataset with a dictionary of datasets.

Args:
datasets (dict): A dictionary where each key is a prefix and each value is a dictionary
containing 'dataset' (an instance of MegatronDataset) and 'weight' (a float).
"""
self.datasets = datasets

def __len__(self):
"""
Return the total number of samples in all datasets.

Returns:
int: The sum of the number of samples in all datasets.
"""
total_samples = 0
for dataset_info in self.datasets.values():
dataset = dataset_info['dataset']
total_samples += len(dataset)
return total_samples

def __getitems__(self, indices: List[Tuple[str, int]]) -> List[Any]:
"""
Fetch samples from the datasets based on the provided indices.

Args:
indices (List[Tuple[str, int]]): A list of tuples where each tuple contains a dataset prefix
and a sample index within that dataset.

Returns:
List[Any]: A list of samples fetched from the datasets.
"""
samples = []
for prefix, sample_index in indices:
dataset_info = self.datasets.get(prefix)
if dataset_info is not None:
dataset = dataset_info['dataset']
sample = dataset[sample_index]
samples.append(sample)
else:
raise KeyError(f"Dataset with prefix '{prefix}' not found.")
return samples
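
A minimal sketch of how the batched fetch above behaves, using plain lists as stand-ins for MegatronDataset instances; the prefixes, weights, and sample values are illustrative assumptions, not part of the change:

    # Editorial sketch: exercising StratifiedDataset with toy data. Plain lists
    # stand in for MegatronDataset instances; only len() and indexing are needed.
    toy = StratifiedDataset({
        "wiki": {"dataset": ["wiki-0", "wiki-1", "wiki-2"], "weight": 0.75},
        "code": {"dataset": ["code-0", "code-1"], "weight": 0.25},
    })
    assert len(toy) == 5  # 3 + 2 samples in total
    assert toy.__getitems__([("wiki", 1), ("code", 0)]) == ["wiki-1", "code-0"]

In the data loader path, the (prefix, index) tuples come from MegatronPretrainingStratifiedSampler in megatron/legacy/data/data_samplers.py rather than being hand-built.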

class BlendedMegatronDatasetBuilder(object):
"""Builder class for the BlendedDataset and MegatronDataset classes
@@ -121,10 +167,13 @@ def build(self) -> List[Optional[TopLevelDataset]]:
- Build a top-level dataset with no excess mid-level dataset sampling

Returns:
List[Optional[TopLevelDataset]]: A list containing a dataset instance (or None) per split
Union[List[Optional[TopLevelDataset]], List[Dict[str, Dict[str, Union[MegatronDataset, float]]]]]: A list containing a dataset instance (or None) per split, or, when stratified batching is enabled, a per-split dictionary mapping each prefix to its dataset and weight
"""
datasets = self._build_blended_dataset_splits()

if self.config.stratified:
return datasets

for dataset in datasets:
if dataset is not None and len(dataset) > 0:
if isinstance(dataset, BlendedDataset):
@@ -156,13 +205,29 @@ def build(self) -> List[Optional[TopLevelDataset]]:

return datasets

def _build_blended_dataset_splits(self) -> List[Optional[TopLevelDataset]]:
def _build_blended_dataset_splits(
self
) -> Union[
List[Optional[TopLevelDataset]],
List[Dict[str, Dict[str, Union[LowLevelDataset, float]]]]
]:
"""Build all dataset splits according to the provided blend(s)

See the BlendedMegatronDatasetBuilder.build alias for more information.

Returns:
List[Optional[TopLevelDataset]]: A list containing a dataset instance (or None) per split
Union[List[Optional[TopLevelDataset]], List[Dict[str, Dict[str, Union[MegatronDataset, float]]]]]:
- If not using stratified batching:
A list containing a dataset instance (or None) per split
- If using stratified batching:
A list of dictionaries, where each dictionary contains:
{
prefix: {
'dataset': MegatronDataset,
'weight': float
}
}
for each enabled split
"""
##
# Return fake "mock" datasets
@@ -186,52 +251,86 @@ def _build_blended_dataset_splits(self) -> List[Optional[TopLevelDataset]]:

split = self.config.split_matrix

# Blend consists of a single prefix
# If we only have one dataset prefix and no weights specified,
# we can directly build a single dataset without blending
if len(prefixes) == 1 and weights is None:
return self._build_megatron_dataset_splits(prefixes[0], split, self.sizes)

# Build the mid-level datasets
# For multiple datasets, we need to determine the size of each dataset per split
# If no weights are provided, set all sizes to None to use full dataset sizes
if weights is None:
# Create a 2D list of None values with dimensions:
# [num_prefixes][num_splits]
sizes_per_dataset = [[None for split in Split] for prefix in prefixes]
else:
# Calculate the size of each dataset per split based on the provided weights
# This ensures datasets are sampled proportionally according to weights
sizes_per_dataset = _get_size_per_split_per_dataset(weights, self.sizes)

# build each dataset in parallel
# Build all the individual datasets in parallel for efficiency
# Each dataset is built according to its prefix, split configuration,
# and target size per split
megatron_datasets = self._build_megatron_datasets_parallel(
prefixes, split, sizes_per_dataset
)

# Build the top-level datasets
blended_datasets = [None] * len(Split)
for i in range(len(Split)):
if split[i] is not None:
weights_i = weights
# If using stratified batching, return datasets with their proportions
if self.config.stratified:
stratified_datasets = []
# Only iterate through enabled splits based on split matrix
for i in range(len(Split)):
if split[i] is not None:
split_dict = {}
for prefix, dataset, weight in zip(prefixes, megatron_datasets[i], weights):
split_dict[prefix] = {'dataset': dataset, 'weight': weight}
stratified_datasets.append(split_dict)
return stratified_datasets

# Build the top-level datasets by blending multiple datasets together
blended_datasets = [None] * len(Split) # Initialize list to store blended datasets for each split
for i in range(len(Split)): # Iterate through each split (train/val/test)
if split[i] is not None: # Only process if this split is enabled
weights_i = weights # Get weights for blending datasets

# Case 1: We have predefined weights and sizes
if weights_i is not None and self.sizes[i] is not None:
# Get size for each dataset in this split
size_per_dataset = list(zip(*sizes_per_dataset))[i]
size_i = sum(size_per_dataset)
size_i = sum(size_per_dataset) # Total size for this split
# Optionally renormalize weights based on actual dataset sizes
if self.config.renormalize_blend_weights:
weights_i = list(map(lambda _size: _size / size_i, size_per_dataset))

# Case 2: No predefined weights - use dataset lengths as weights
elif weights_i is None:
try:
# Try to get lengths of each dataset to use as weights
weights_i = [
len(megatron_dataset) for megatron_dataset in megatron_datasets[i]
]
except TypeError:
# If lengths are unavailable, fall back to a placeholder weight of 0 per dataset
weights_i = [0 for _ in prefixes]

# Set size to either specified size or sum of dataset lengths
if self.sizes[i] is not None:
size_i = min(self.sizes[i], sum(weights_i))
else:
size_i = None # => the size will be sum(weights_i)
size_i = None # Will default to sum of all dataset lengths

# Case 3: Invalid state - weights without sizes
else:
raise RuntimeError

# Create the blended dataset by combining individual datasets
blended_datasets[i] = self.build_generic_dataset(
BlendedDataset,
self.is_built_on_rank,
True, # synchronize_ranks, default behavior to build on rank-0 first
megatron_datasets[i],
weights_i,
size_i,
self.config,
BlendedDataset, # Class to instantiate
self.is_built_on_rank, # Whether to build on this rank
True, # synchronize_ranks: build on rank-0 first
megatron_datasets[i], # List of datasets to blend
weights_i, # Weights for blending
size_i, # Total size of blended dataset
self.config, # Configuration object
)

return blended_datasets
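
A small worked sketch of the weight renormalization in Case 1 above; the per-dataset sizes are assumed stand-ins for what _get_size_per_split_per_dataset returns, and the real helper may round each target up slightly:

    # Editorial sketch: renormalizing blend weights from the realized per-dataset sizes.
    size_per_dataset = [703, 302]                       # assumed samples drawn per dataset
    size_i = sum(size_per_dataset)                      # 1005 samples in this split overall
    weights_i = [s / size_i for s in size_per_dataset]  # ~[0.6995, 0.3005]
    assert abs(sum(weights_i) - 1.0) < 1e-9             # renormalized weights sum to one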
3 changes: 3 additions & 0 deletions megatron/core/datasets/blended_megatron_dataset_config.py
@@ -69,6 +69,9 @@ class BlendedMegatronDatasetConfig:
tokenizer: Optional[MegatronTokenizer] = None
"""The MegatronTokenizer instance or None. Required for datasets which do online tokenization."""

stratified: bool = False
"""Whether to use stratified batching."""

def __post_init__(self) -> None:
"""Do asserts and set fields post init"""
if self.blend_per_split is not None and any(self.blend_per_split):
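
A hedged sketch of switching the new flag on when constructing the config; the paths, weights, split string, and other field values are illustrative, and the exact set of required constructor arguments may differ between Megatron-LM versions:

    # Editorial sketch: enabling stratified batching on the dataset config.
    # All field values other than `stratified` are illustrative assumptions.
    config = BlendedMegatronDatasetConfig(
        random_seed=1234,
        sequence_length=2048,
        blend=(["path/to/wiki_text_document", "path/to/code_text_document"], [0.75, 0.25]),
        split="949,50,1",
        stratified=True,  # builder returns per-split {prefix: {'dataset', 'weight'}} dicts
    )

With the flag set, BlendedMegatronDatasetBuilder.build() skips BlendedDataset construction and hands these per-split dictionaries to the sampler path added in megatron/legacy/data/data_samplers.py below.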
148 changes: 146 additions & 2 deletions megatron/legacy/data/data_samplers.py
@@ -9,7 +9,8 @@
from torch.utils.data import Dataset
from megatron.training import get_args
from megatron.core import mpu

from typing import List, Tuple, Any
from megatron.core.datasets.blended_megatron_dataset_builder import StratifiedDataset

def build_pretraining_data_loader(dataset, consumed_samples):
"""Build dataloader given an input dataset."""
@@ -38,7 +39,19 @@ def build_pretraining_data_loader(dataset, consumed_samples):
elif args.dataloader_type == "external":
# External dataloaders are passed through. User is expected to provide a
# torch-compatible dataloader and define samplers, if needed.
return dataset
# return dataset  # previous behavior: pass the external dataloader through unchanged
batch_sampler = MegatronPretrainingStratifiedSampler(
dataset_with_weight=dataset,
global_batch_size=args.global_batch_size,
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())

# Wrap the per-prefix datasets for the DataLoader, i.e.
# from: {prefix: {'dataset': MegatronDataset, 'weight': float}}
# to: StratifiedDataset
dataset = StratifiedDataset(dataset)
else:
raise Exception('{} dataloader type is not supported.'.format(
args.dataloader_type))
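
A sketch of how the pieces wired up in this function fit together outside Megatron's argument plumbing; the toy datasets, batch sizes, and rank values are assumptions (in the real code path they come from get_args() and mpu, and the DataLoader is built further down in this function), and the fetch path assumes a PyTorch version whose default fetcher honors __getitems__:

    # Editorial sketch: pairing MegatronPretrainingStratifiedSampler (defined below)
    # with StratifiedDataset. Plain lists stand in for MegatronDataset instances.
    import torch

    dataset_with_weight = {
        "wiki": {"dataset": list(range(12)), "weight": 0.75},
        "code": {"dataset": list(range(4)), "weight": 0.25},
    }
    batch_sampler = MegatronPretrainingStratifiedSampler(
        dataset_with_weight=dataset_with_weight,
        global_batch_size=16,
        micro_batch_size=2,
        consumed_samples=0,
        data_parallel_rank=0,
        data_parallel_size=2,
    )
    loader = torch.utils.data.DataLoader(
        StratifiedDataset(dataset_with_weight),  # fetched via __getitems__ with (prefix, idx) tuples
        batch_sampler=batch_sampler,
        num_workers=0,
        pin_memory=True,
    )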
@@ -190,3 +203,134 @@ def __iter__(self):
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []

class MegatronPretrainingStratifiedSampler:
"""Sampler that handles multiple datasets with specified weights"""

def __init__(
self,
dataset_with_weight: dict, # Dict of {prefix: {'weight': float, 'dataset': MegatronDataset}}
global_batch_size: int,
micro_batch_size: int,
consumed_samples: int,
data_parallel_rank: int,
data_parallel_size: int
):
# Validate dataset_with_weight data type
assert isinstance(dataset_with_weight, dict), "dataset_with_weight must be a dictionary"
for key, value in dataset_with_weight.items():
assert isinstance(key, str), "Each key in dataset_with_weight must be a string"
assert isinstance(value, dict), "Each value in dataset_with_weight must be a dictionary"
assert 'weight' in value, "Each dictionary in dataset_with_weight must contain 'weight' key"
assert isinstance(value['weight'], float), "'weight' must be a float"
assert 'dataset' in value, "Each dictionary in dataset_with_weight must contain 'dataset' key"

self.dataset_with_weight = dataset_with_weight
self.global_batch_size = global_batch_size
self.micro_batch_size = micro_batch_size
self.consumed_samples = consumed_samples
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size

# Validate weights sum to 1
total_weight = sum(info['weight'] for info in dataset_with_weight.values())
assert abs(total_weight - 1.0) < 1e-6, f"weights must sum to 1, got {total_weight}"

# Assert that the global_batch_size is divisible by (micro_batch_size * data_parallel_size)
assert global_batch_size % (micro_batch_size * data_parallel_size) == 0, (
f"global_batch_size ({global_batch_size}) must be divisible by "
f"micro_batch_size ({micro_batch_size}) * data_parallel_size ({data_parallel_size})"
)

# Calculate samples per dataset in the global batch using round
self.dataset_num_samples = {
prefix: round(global_batch_size * info['weight'])
for prefix, info in dataset_with_weight.items()
}

# Adjust rounding errors
total_samples = sum(self.dataset_num_samples.values())
if total_samples != global_batch_size:
# Calculate the difference
difference = global_batch_size - total_samples

# Adjust the dataset with the largest weight
max_weight = 0
max_prop_dataset = None
for dataset, info in dataset_with_weight.items():
if info['weight'] > max_weight:
max_weight = info['weight']
max_prop_dataset = dataset

# Adjust the number of samples for the largest weight dataset
self.dataset_num_samples[max_prop_dataset] += difference
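
A worked example of the rounding adjustment above; the batch size and weights are illustrative:

    # Editorial worked example: round() can overshoot the global batch size by one.
    samples = {"a": round(256 * 0.6), "b": round(256 * 0.3), "c": round(256 * 0.1)}
    assert samples == {"a": 154, "b": 77, "c": 26} and sum(samples.values()) == 257
    samples["a"] += 256 - 257  # the largest-weight dataset absorbs the difference
    assert sum(samples.values()) == 256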

def __len__(self):
# Sum the lengths of all datasets
return sum(len(info['dataset']) for info in self.dataset_with_weight.values())

def __collate_global_batch__(self):
"""Collates samples from all datasets into a global batch.

Returns:
list: List of (prefix, idx) tuples representing the global batch
"""
global_batch_indices = []

# For each dataset, generate its portion of samples for the global batch
for prefix, info in self.dataset_with_weight.items():
num_samples = self.dataset_num_samples[prefix]
total_samples_local = len(info['dataset']) # Get the length of the dataset
weight_local = info['weight']

# Calculate local consumed samples and epoch
consumed_samples_local = int(self.consumed_samples * weight_local)
epoch_local = consumed_samples_local // total_samples_local
bucket_offset_local = consumed_samples_local % total_samples_local

# Generate random permutation for this dataset
g_local = torch.Generator()
g_local.manual_seed(epoch_local)

# Generate indices for this dataset's portion
indices = torch.randperm(total_samples_local, generator=g_local).tolist()[bucket_offset_local:]

# Add (dataset, index) tuples to global batch
global_batch_indices.extend([(prefix, idx) for idx in indices[:num_samples]])

# Permute the global_batch_indices according to the epoch_global
effective_length = (self.__len__() // self.global_batch_size) * self.global_batch_size
epoch_global = self.consumed_samples // effective_length
g_global = torch.Generator()
g_global.manual_seed(epoch_global)
permuted_indices = torch.randperm(len(global_batch_indices), generator=g_global).tolist()
global_batch_indices = [global_batch_indices[i] for i in permuted_indices]

return global_batch_indices
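
A worked example of the per-dataset bookkeeping above, assuming a dataset with weight 0.25 holding 1000 samples after 9000 globally consumed samples (all numbers illustrative):

    # Editorial worked example of consumed_samples_local / epoch_local / bucket_offset_local.
    consumed_samples, weight_local, total_samples_local = 9000, 0.25, 1000
    consumed_samples_local = int(consumed_samples * weight_local)        # 2250 drawn so far
    epoch_local = consumed_samples_local // total_samples_local          # 2 full passes done
    bucket_offset_local = consumed_samples_local % total_samples_local   # 250 into the third pass
    assert (epoch_local, bucket_offset_local) == (2, 250)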

def __iter__(self):
while True:
# Get collated global batch
global_batch_indices = self.__collate_global_batch__()
# print(f"{self.data_parallel_rank}: global_batch_indices regenerated")

# Calculate the number of micro-batches in the global batch
num_micro_batches = self.global_batch_size // self.micro_batch_size
num_micro_batches_per_rank = num_micro_batches // self.data_parallel_size
# print(f"{self.data_parallel_rank}: num_micro_batches = {self.global_batch_size} // {self.micro_batch_size} = {num_micro_batches}")

# Interleave indices for each rank
for i in range(num_micro_batches_per_rank):
# Calculate the starting index for this rank
start_idx = self.data_parallel_rank + i * (self.data_parallel_size * self.micro_batch_size)
# Collect this rank's strided share of the global batch
# print(f"Slicing of rank_indices: {list(range(start_idx, self.global_batch_size, self.data_parallel_size))}")
rank_indices = global_batch_indices[start_idx:self.global_batch_size:self.data_parallel_size]
# Collect indices for this rank's micro-batch
micro_batch_indices=rank_indices[:self.micro_batch_size]
# print(f"{self.data_parallel_rank}: micro_batch_indices = {micro_batch_indices}")

yield micro_batch_indices

# Update consumed samples
self.consumed_samples += self.global_batch_size
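
A worked example of the rank interleaving above with global_batch_size=8, micro_batch_size=2, and data_parallel_size=2; integer positions stand in for the (prefix, index) tuples in global_batch_indices:

    # Editorial worked example: each rank takes a strided share of the global batch.
    global_batch, micro, dp = list(range(8)), 2, 2
    shares = {}
    for rank in range(dp):
        for i in range(len(global_batch) // micro // dp):
            start = rank + i * (dp * micro)
            shares[(rank, i)] = global_batch[start:len(global_batch):dp][:micro]
    # rank 0 yields [0, 2] then [4, 6]; rank 1 yields [1, 3] then [5, 7]
    assert sorted(sum(shares.values(), [])) == global_batch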