diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index a0c03e5a..7ea01c64 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -346,7 +346,8 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
 
     for epoch in range(args.num_epochs):
         torch.distributed.barrier()
-        train_loader.batch_sampler.set_epoch(epoch)
+        if hasattr(train_loader.batch_sampler, 'set_epoch'):
+            train_loader.batch_sampler.set_epoch(epoch)
 
         if local_rank == 0:
             inner_pb = tqdm(range(len(train_loader)), desc=f"Epoch {epoch}")
@@ -510,6 +512,7 @@ def main(args):
         is_granite=args.is_granite,
         max_batch_len=args.max_batch_len,
         packing_max_batch_len=packing_max_batch_len,
+        batch_sampler=args.batch_sampler,
         seed=args.seed,
     )
 
@@ -674,6 +677,16 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
             "constant_with_warmup",
         ],
     )
+    parser.add_argument(
+        "--batch_sampler",
+        type=str,
+        default="multipack",
+        help="The batch sampler type to use.",
+        choices=[
+            "multipack",
+            "default",
+        ],
+    )
     parser.add_argument("--num_warmup_steps", type=int, default=1000)
     # parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--save_samples", type=int)
diff --git a/src/instructlab/training/token_dataset.py b/src/instructlab/training/token_dataset.py
index e36b7395..1bbfe3f6 100644
--- a/src/instructlab/training/token_dataset.py
+++ b/src/instructlab/training/token_dataset.py
@@ -85,6 +85,7 @@ def setup_dataloader(
     is_granite=False,
     max_batch_len=60000,
     packing_max_batch_len=60000,
+    batch_sampler='multipack',
     seed=47,
 ) -> DataLoader:
     collate_fn = make_collate_fn(
@@ -94,14 +95,18 @@
     world_size = int(os.environ["WORLD_SIZE"])
 
     lengths = dataset.get_lengths()
-    sampler = MultipackDistributedBatchSampler(
-        batch_max_length=packing_max_batch_len,
-        lengths=lengths,
-        num_replicas=world_size,
-        rank=rank,
-        seed=seed,
-        padding=not is_granite,
-    )
+    if batch_sampler == 'multipack':
+        sampler = MultipackDistributedBatchSampler(
+            batch_max_length=packing_max_batch_len,
+            lengths=lengths,
+            num_replicas=world_size,
+            rank=rank,
+            seed=seed,
+            padding=not is_granite,
+        )
+    elif batch_sampler == 'default':
+        sampler = None
+
     dataloader = DataLoader(
         dataset,
         batch_sampler=sampler,
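
Context for the `sampler = None` branch and the `hasattr` guard in `train()`: the two sampler modes expose different interfaces. The sketch below is illustrative only and assumes stock PyTorch `DataLoader` behavior; the toy `TensorDataset` and the `toy_dataset`/`default_loader` names are made up for the example and are not part of the patch.

```python
# Illustrative sketch, not part of the patch: why train() guards set_epoch()
# with hasattr(). A DataLoader built with batch_sampler=None falls back to a
# plain BatchSampler, which has no set_epoch(); the multipack path's
# MultipackDistributedBatchSampler does expose one.
import torch
from torch.utils.data import DataLoader, TensorDataset

toy_dataset = TensorDataset(torch.arange(10))

# "default" path: sampler is None, so DataLoader builds its own sequential
# sampler and batches (batch_size defaults to 1, no rank-aware sharding).
default_loader = DataLoader(toy_dataset, batch_sampler=None)
print(type(default_loader.batch_sampler).__name__)         # BatchSampler
print(hasattr(default_loader.batch_sampler, "set_epoch"))  # False
```

Unless the remainder of the `DataLoader(...)` call (outside this hunk) supplies `batch_size`, `shuffle`, or a per-rank sampler, the `default` mode therefore loads samples sequentially, one at a time, and identically on every rank, which is worth keeping in mind when choosing between the two modes.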