Multipack padding fix (#19)
For non-dolomite models, packing_max_batch_len was being calculated incorrectly. This change calculates the increase needed so that the average effective batch size stays close to the specified effective_batch_size.

---------

Signed-off-by: Mustafa Eyceoz <[email protected]>
Maxusmusti authored Jun 20, 2024
1 parent 204bcb5 commit c6bffd4
Showing 4 changed files with 253 additions and 94 deletions.
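
As a rough sketch of the relationship the commit message describes (illustrative only, not part of the diff, all numbers made up): with padding, each sample in a packed micro-batch effectively costs its length plus the average padding, so the per-GPU packing budget has to grow by that average for the effective batch size to stay on target.

# Illustrative sketch, not part of the diff: every value below is hypothetical.
avg_sample_len = 850          # mean tokenized sample length
padding_addition = 120        # average padding per sample, found by the new search
effective_batch_size = 3840   # requested samples per optimizer step
num_gpus = 8
grad_accum = 8

# Without the addition the budget assumes zero padding and packs too few samples;
# with it, the budget matches the padded cost of the target micro-batch.
packing_max_batch_len = int(
    (avg_sample_len + padding_addition) * ((effective_batch_size / num_gpus) / grad_accum)
)
print(packing_max_batch_len)  # 58200 tokens of budget per packed micro-batch per GPU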
4 changes: 4 additions & 0 deletions src/instructlab/training/main_ds.py
@@ -461,6 +461,10 @@ def main(args):
avg_sample_len=dataset.get_lengths().mean(),
effective_batch_size=args.effective_batch_size,
max_batch_len_per_gpu=args.max_batch_len,
is_padding=not args.is_granite,
dataset=dataset,
pad_id=tokenizer.pad_token_id,
seed=args.seed,
)
args.samples_per_gpu = (
args.effective_batch_size // grad_accum // torch.distributed.get_world_size()
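
A quick arithmetic check of the line above (hedged, made-up numbers, not part of the diff): samples_per_gpu is simply the effective batch size split across gradient-accumulation steps and ranks.

# Illustrative only: hypothetical values consistent with the sketch above.
effective_batch_size = 3840
grad_accum = 8
world_size = 8
samples_per_gpu = effective_batch_size // grad_accum // world_size
print(samples_per_gpu)  # 60 samples per GPU per micro-batch, on average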
163 changes: 159 additions & 4 deletions src/instructlab/training/multipack_sampler.py
@@ -25,16 +25,156 @@

# Standard
from typing import List, Optional
import os

# Third Party
from torch.utils.data import Sampler
from torch.utils.data import DataLoader, Sampler
import numba
import numpy as np
import torch.distributed as dist

# First Party
from instructlab.training.utils import make_collate_fn


def guess_starting_avg_padding(base_avg, goal, num_gpus, grad_accum, sorted_lengths):
"""
Return a starting midpoint for the binary search
(which finds the optimal addition to packing_max_batch_len
to account for padding).
Uses the largest initial bucket to approximate an
upper bound for the average padding, so the guess should overshoot.
"""
addition = 0
packing_max_batch_len = int(
(base_avg + addition) * ((goal / num_gpus) / grad_accum)
)

bucket_zero = []
max = sorted_lengths[0]
sum = 0
for length in sorted_lengths:
if sum + max <= packing_max_batch_len:
sum += max
bucket_zero.append(length)
else:
break

total_pad = 0
for length in bucket_zero:
total_pad += max - length
addition = round(total_pad / len(bucket_zero))
return addition
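
A hedged worked example of this starting guess (made-up lengths, not part of the diff; it assumes the function above is importable). The first bucket is costed at the maximum length per slot, and the guess is just the average padding that bucket would need.

# base_avg=100, goal=64 samples, num_gpus=8, grad_accum=1
#   -> packing_max_batch_len = int(100 * (64 / 8) / 1) = 800
# Each slot in the first bucket costs max=200 tokens, so it fits 4 samples
# (lengths 200, 180, 150, 120) whose padding would be 0 + 20 + 50 + 80 = 150 tokens.
sorted_lengths = [200, 180, 150, 120, 100, 90]
addition = guess_starting_avg_padding(
    base_avg=100, goal=64, num_gpus=8, grad_accum=1, sorted_lengths=sorted_lengths
)
print(addition)  # round(150 / 4) == 38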


def simulate_buckets(
base_avg,
goal,
num_gpus,
grad_accum,
pad_id,
max_batch_len,
lengths,
seed,
dataset,
addition,
):
"""
Given an addition to packing_max_batch_len, simulate the
packing to find the updated average effective batch size.
"""
packing_max_batch_len = int(
(base_avg + addition) * ((goal / num_gpus) / grad_accum)
)

collate_fn = make_collate_fn(pad_id, is_granite=False, max_batch_len=max_batch_len)
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

sampler = MultipackDistributedBatchSampler(
batch_max_length=packing_max_batch_len,
lengths=lengths,
num_replicas=world_size,
rank=rank,
seed=seed,
padding=True,
)
simulation_loader = DataLoader(
dataset,
batch_sampler=sampler,
num_workers=8,
collate_fn=collate_fn,
)

avg_ebs = len(dataset) / len(simulation_loader)
return avg_ebs
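
A hedged usage sketch (not part of the diff): dataset and pad_token_id are assumed to already exist (for example, a TokenDataset and its tokenizer's pad id), and RANK/WORLD_SIZE must be set because simulate_buckets reads them from the environment.

# Single-process dry run with assumed dataset and pad_token_id.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

lengths = dataset.get_lengths()
avg_ebs = simulate_buckets(
    base_avg=lengths.mean(),
    goal=3840,               # hypothetical effective batch size
    num_gpus=1,
    grad_accum=1,
    pad_id=pad_token_id,
    max_batch_len=60000,
    lengths=lengths,
    seed=42,
    dataset=dataset,
    addition=0,              # no compensation: expect the result to fall short of 3840
)
print(f"simulated average effective batch size: {avg_ebs:.1f}")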


def find_padding_max_batch_len_addition(
base_avg, goal, dataset, num_gpus, grad_accum, pad_id, max_batch_len, seed
):
"""
Perform a modified binary search to find the optimal padding addition for
packing_max_batch_len. Start with an upper-bound guess and keep raising the
upper bound until the guess overshoots, then run a standard binary search
until the average effective batch size is within a threshold of the goal.
"""
lengths = dataset.get_lengths()
sorted_lengths = list(lengths)
sorted_lengths.sort(reverse=True)

# Use first default bucket avg padding as starting value for addition
addition = guess_starting_avg_padding(
base_avg, goal, num_gpus, grad_accum, sorted_lengths
)

# binary search for the correct addition value, starting from the initial guess
first_over_hit = False
l = 0
r = 2 * addition
while r - l > 1:
avg_ebs = simulate_buckets(
base_avg,
goal,
num_gpus,
grad_accum,
pad_id,
max_batch_len,
lengths,
seed,
dataset,
addition,
)

# check if simulation resulted in batch sizes close enough to goal and adjust if needed
if abs(avg_ebs - goal) <= max(10, round(goal * 0.02)):
break

if avg_ebs > goal:
first_over_hit = True
r = addition
elif avg_ebs < goal:
if not first_over_hit:
# If the starting midpoint failed to overshoot, increase the bounds of the search
r = r * 2
else:
l = addition
addition = l + ((r - l) // 2)

return addition
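
For reference, a hedged sketch of the stopping tolerance used in the loop above (made-up goal, not part of the diff):

# The search stops once the simulated average effective batch size is within
# max(10, 2% of goal) samples of the requested goal.
goal = 3840
tolerance = max(10, round(goal * 0.02))
print(tolerance)  # 77 -> any avg_ebs between 3763 and 3917 ends the search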


def find_packing_max_batch_len_and_grad_accum(
num_gpus, avg_sample_len, effective_batch_size, max_batch_len_per_gpu
num_gpus,
avg_sample_len,
effective_batch_size,
max_batch_len_per_gpu,
is_padding,
dataset,
pad_id,
seed,
):
"""
Calculate the minimum gradient accumulation steps required and the corresponding maximum batch length.
@@ -58,12 +198,27 @@ def find_packing_max_batch_len_and_grad_accum(
without exceeding the per-GPU limit, and the second element is the minimum number of gradient
accumulation steps required to maintain the effective batch size.
"""

packing_max_batch_len = max_batch_len_per_gpu + 1
grad_accum = 0
while packing_max_batch_len > max_batch_len_per_gpu:
grad_accum += 1
total_micro_batch = effective_batch_size / grad_accum
packing_max_batch_len = int(avg_sample_len * total_micro_batch / num_gpus)
total_micro_batch = (effective_batch_size / grad_accum) / num_gpus
if is_padding:
addition = find_padding_max_batch_len_addition(
avg_sample_len,
effective_batch_size,
dataset,
num_gpus,
grad_accum,
pad_id,
max_batch_len_per_gpu,
seed,
)
else:
addition = 0
packing_max_batch_len = int((avg_sample_len + addition) * total_micro_batch)

return packing_max_batch_len, grad_accum
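
A hedged worked example of the loop above (made-up numbers, not part of the diff). With is_padding=False the extra arguments are unused, so None placeholders are enough here.

packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum(
    num_gpus=8,
    avg_sample_len=900,
    effective_batch_size=3840,
    max_batch_len_per_gpu=60000,
    is_padding=False,        # dataset/pad_id/seed are only consulted when padding
    dataset=None,
    pad_id=None,
    seed=None,
)
# grad_accum=7 would give int(900 * (3840 / 7) / 8) = 61714 > 60000, so the loop keeps
# going; grad_accum=8 gives int(900 * (3840 / 8) / 8) = 54000 <= 60000 and stops.
print(packing_max_batch_len, grad_accum)  # 54000 8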


91 changes: 1 addition & 90 deletions src/instructlab/training/token_dataset.py
@@ -6,11 +6,10 @@
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch
import torch.nn.functional as F

# First Party
from instructlab.training.multipack_sampler import MultipackDistributedBatchSampler
from instructlab.training.utils import log_rank_0
from instructlab.training.utils import log_rank_0, make_collate_fn


class TokenDataset(Dataset):
@@ -66,94 +65,6 @@ def get_lengths(self):
return np.array([len(self.input_ids[0])] * len(self.input_ids))


def make_collate_fn(pad_token_id, is_granite=False, max_batch_len=60000):
rank = int(os.environ["RANK"])
if is_granite:

def pad_collate_fn(batch):
lens = np.array([len(item["input_ids"]) for item in batch])

cumsum_lens = np.cumsum(lens)
valid_up_to = int((cumsum_lens < max_batch_len).sum())
total_len = cumsum_lens[valid_up_to - 1]

batch = batch[:valid_up_to]
input_ids = [x["input_ids"].tolist() for x in batch]
labels = [x["labels"].tolist() for x in batch]
num_loss_counted_tokens = sum(
[(x["labels"] != -100).sum().item() for x in batch]
)

print(
f"\033[96m total length: {total_len} dropped: {cumsum_lens[-1] - total_len} "
f"num samples {len(batch)} - rank: {rank} "
f"max len: {lens.max()} min len: {lens.min()} avg len: {lens.mean()} "
f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m"
)

return {
"input_ids": input_ids,
"labels": labels,
"num_loss_counted_tokens": num_loss_counted_tokens,
}

else:

def pad_collate_fn(batch):
lens = np.array([len(item["input_ids"]) for item in batch])
max_len = max(lens)

input_ids = torch.stack(
[
F.pad(
item["input_ids"],
(max_len - len(item["input_ids"]), 0),
mode="constant",
value=pad_token_id,
)
for item in batch
]
)
labels = torch.stack(
[
F.pad(
item["labels"],
(max_len - len(item["labels"]), 0),
mode="constant",
value=-100,
)
for item in batch
]
)
num_loss_counted_tokens = (labels != -100).sum()

attention_mask = torch.stack(
[
F.pad(
item["attention_mask"],
(max_len - len(item["attention_mask"]), 0),
mode="constant",
value=0,
)
for item in batch
]
)
print(
f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()} - rank: {rank} "
f"max len: {max_len} min len: {min(lens)} avg len: {lens.mean()} "
f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m"
)

return {
"input_ids": input_ids,
"labels": labels,
"num_loss_counted_tokens": num_loss_counted_tokens,
"attention_mask": attention_mask,
}

return pad_collate_fn
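
The non-granite collator above (relocated to instructlab.training.utils by this commit) left-pads every sample in a batch to the batch's longest length, which is exactly the overhead the new padding-aware search budgets for. A hedged sketch with the same made-up lengths as earlier (not part of the diff):

lens = [200, 180, 150, 120]                # sample lengths in one batch
max_len = max(lens)                        # every sample is padded up to 200 tokens
total_tokens = max_len * len(lens)         # 800 tokens actually materialized
padding_tokens = total_tokens - sum(lens)  # 150 of them are pure padding
print(padding_tokens / len(lens))          # 37.5 padding tokens per sample on average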


def setup_dataset(
data_path: str,
mock: bool = False,
