add batch sampling
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Jun 28, 2024
1 parent 0d88f30 commit 10970f2
Showing 2 changed files with 27 additions and 9 deletions.
15 changes: 14 additions & 1 deletion src/instructlab/training/main_ds.py
@@ -346,7 +346,9 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):

     for epoch in range(args.num_epochs):
         torch.distributed.barrier()
-        train_loader.batch_sampler.set_epoch(epoch)
+        # the default sampler does not implement set_epoch, so guard the call
+        if hasattr(train_loader.batch_sampler, 'set_epoch'):
+            train_loader.batch_sampler.set_epoch(epoch)

         if local_rank == 0:
             inner_pb = tqdm(range(len(train_loader)), desc=f"Epoch {epoch}")
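The guard matters because only some samplers reshuffle per epoch: distributed samplers expose a set_epoch hook that reseeds their permutation, while the plain DataLoader default has nothing to reseed. A minimal sketch of the pattern, using PyTorch's built-in DistributedSampler as a stand-in for MultipackDistributedBatchSampler (the single-process num_replicas/rank values are assumptions for the demo):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(100))
    # Explicit num_replicas/rank so no process group is needed for the demo.
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
    loader = DataLoader(dataset, sampler=sampler, batch_size=8)

    for epoch in range(3):
        # Reseed the per-epoch shuffle, but only if the sampler supports it.
        if hasattr(loader.sampler, "set_epoch"):
            loader.sampler.set_epoch(epoch)
        for batch in loader:
            pass  # training step goes here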
@@ -510,6 +512,7 @@ def main(args):
         is_granite=args.is_granite,
         max_batch_len=args.max_batch_len,
         packing_max_batch_len=packing_max_batch_len,
+        batch_sampler=args.batch_sampler,
         seed=args.seed,
     )

@@ -674,6 +677,16 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
             "constant_with_warmup",
         ],
     )
+    parser.add_argument(
+        "--batch_sampler",
+        type=str,
+        default="multipack",
+        help="The batch sampler type to use.",
+        choices=[
+            "multipack",
+            "default"
+        ],
+    )
     parser.add_argument("--num_warmup_steps", type=int, default=1000)
     # parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--save_samples", type=int)
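With the argument wired in, selecting the sampler becomes a CLI choice, and anything outside multipack/default is rejected by argparse itself. A quick standalone check of the new flag (a sketch; only the new argument is reproduced, the rest of the parser is elided):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch_sampler",
        type=str,
        default="multipack",
        help="The batch sampler type to use.",
        choices=["multipack", "default"],
    )

    print(parser.parse_args([]).batch_sampler)  # -> multipack
    print(parser.parse_args(["--batch_sampler", "default"]).batch_sampler)  # -> default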
21 changes: 13 additions & 8 deletions src/instructlab/training/token_dataset.py
@@ -85,6 +85,7 @@ def setup_dataloader(
     is_granite=False,
     max_batch_len=60000,
     packing_max_batch_len=60000,
+    batch_sampler='multipack',
     seed=47,
 ) -> DataLoader:
     collate_fn = make_collate_fn(
@@ -94,14 +95,18 @@
     world_size = int(os.environ["WORLD_SIZE"])

     lengths = dataset.get_lengths()
-    sampler = MultipackDistributedBatchSampler(
-        batch_max_length=packing_max_batch_len,
-        lengths=lengths,
-        num_replicas=world_size,
-        rank=rank,
-        seed=seed,
-        padding=not is_granite,
-    )
+    if batch_sampler == 'multipack':
+        sampler = MultipackDistributedBatchSampler(
+            batch_max_length=packing_max_batch_len,
+            lengths=lengths,
+            num_replicas=world_size,
+            rank=rank,
+            seed=seed,
+            padding=not is_granite,
+        )
+    elif batch_sampler == 'default':
+        # no custom batch sampler; DataLoader falls back to its own batching
+        sampler = None

     dataloader = DataLoader(
         dataset,
         batch_sampler=sampler,
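The net effect on DataLoader construction: under 'multipack' the custom batch sampler fully controls how indices are grouped into batches, while under 'default' the sampler is None and DataLoader falls back to its ordinary batching. Note that batch_size and batch_sampler are mutually exclusive DataLoader arguments, so the None branch inherits batch_size=1 unless the elided remainder of the call sets one. A self-contained sketch of the same branching, with a generic BatchSampler standing in for MultipackDistributedBatchSampler:

    import torch
    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

    def build_loader(dataset, batch_sampler="multipack", batch_size=8):
        if batch_sampler == "multipack":
            # Stand-in for MultipackDistributedBatchSampler: any object that
            # yields lists of dataset indices works as a batch_sampler.
            sampler = BatchSampler(
                SequentialSampler(dataset), batch_size=batch_size, drop_last=False
            )
            return DataLoader(dataset, batch_sampler=sampler)
        elif batch_sampler == "default":
            # batch_sampler=None is the DataLoader default; batch_size applies instead.
            return DataLoader(dataset, batch_size=batch_size)
        raise ValueError(f"unknown batch_sampler: {batch_sampler!r}")

    ds = TensorDataset(torch.arange(32))
    print(len(build_loader(ds, "multipack")))  # 4 batches of 8
    print(len(build_loader(ds, "default")))    # 4 batches of 8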
