diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 7ea01c64..807c16d3 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -346,9 +346,12 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
 
     for epoch in range(args.num_epochs):
         torch.distributed.barrier()
-        torch.distributed.breakpoint()
-        if hasattr(train_loader.batch_sampler, 'set_epoch'):
+        if args.sampler == 'multipack':
             train_loader.batch_sampler.set_epoch(epoch)
+        elif args.sampler == 'distributed':
+            train_loader.sampler.set_epoch(epoch)
+        else:
+            raise NotImplementedError(f"unknown sampler type: {args.sampler}")
 
         if local_rank == 0:
             inner_pb = tqdm(range(len(train_loader)), desc=f"Epoch {epoch}")
@@ -512,7 +515,8 @@ def main(args):
         is_granite=args.is_granite,
         max_batch_len=args.max_batch_len,
         packing_max_batch_len=packing_max_batch_len,
-        batch_sampler=args.batch_sampler,
+        samples_per_gpu=args.samples_per_gpu,
+        sampler=args.sampler,
         seed=args.seed,
     )
 
@@ -678,13 +682,13 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
         ],
     )
     parser.add_argument(
-        "--batch_sampler",
+        "--sampler",
         type=str,
         default="multipack",
         help="The batch sampler type to use.",
         choices=[
             "multipack",
-            "default"
+            "distributed"
         ],
     )
     parser.add_argument("--num_warmup_steps", type=int, default=1000)
diff --git a/src/instructlab/training/token_dataset.py b/src/instructlab/training/token_dataset.py
index 1bbfe3f6..cab2f169 100644
--- a/src/instructlab/training/token_dataset.py
+++ b/src/instructlab/training/token_dataset.py
@@ -85,7 +85,8 @@ def setup_dataloader(
     is_granite=False,
     max_batch_len=60000,
     packing_max_batch_len=60000,
-    batch_sampler='multipack',
+    samples_per_gpu=None,
+    sampler='multipack',
     seed=47,
 ) -> DataLoader:
     collate_fn = make_collate_fn(
@@ -95,7 +96,7 @@
     world_size = int(os.environ["WORLD_SIZE"])
 
     lengths = dataset.get_lengths()
-    if batch_sampler == 'multipack':
+    if sampler == 'multipack':
         sampler = MultipackDistributedBatchSampler(
             batch_max_length=packing_max_batch_len,
             lengths=lengths,
@@ -104,12 +105,23 @@
             seed=seed,
             padding=not is_granite,
         )
-    elif batch_sampler == 'default':
-        sampler = None
+        sampler_kwargs = {'batch_sampler': sampler}
+    elif sampler == 'distributed':
+        from torch.utils.data import DistributedSampler
+        dist_sampler = (
+            DistributedSampler(dataset)
+            if torch.distributed.is_initialized() else None
+        )
+        sampler_kwargs = {
+            'sampler': dist_sampler,
+            'batch_size': samples_per_gpu,
+        }
+    else:
+        raise NotImplementedError(f"unknown sampler type: {sampler}")
 
     dataloader = DataLoader(
         dataset,
-        batch_sampler=sampler,
+        **sampler_kwargs,
         num_workers=num_workers,
         collate_fn=collate_fn,
     )
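
The heart of the token_dataset.py change is that the two modes feed different DataLoader keyword arguments: multipack is a batch sampler (it yields whole batches of indices, so it goes in batch_sampler and batch_size must not also be set), while DistributedSampler is a per-sample sampler that composes with an explicit per-GPU batch_size (samples_per_gpu in the patch). Below is a minimal, self-contained sketch of that split; torch's stock BatchSampler and a toy TensorDataset stand in for MultipackDistributedBatchSampler and the token dataset, and the variable names are illustrative only, not taken from the patch.

import torch
from torch.utils.data import (
    BatchSampler,
    DataLoader,
    DistributedSampler,
    SequentialSampler,
    TensorDataset,
)

# Toy dataset standing in for the tokenized training set.
dataset = TensorDataset(torch.arange(32))

# "multipack"-style path: the sampler already yields whole batches,
# so it is passed as batch_sampler (no batch_size/shuffle alongside it).
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
multipack_style_kwargs = {"batch_sampler": batch_sampler}

# "distributed"-style path: a per-sample sampler plus an explicit per-GPU
# batch size; falls back to None outside an initialized process group.
dist_sampler = DistributedSampler(dataset) if torch.distributed.is_initialized() else None
distributed_style_kwargs = {"sampler": dist_sampler, "batch_size": 4}

loader = DataLoader(dataset, **multipack_style_kwargs)
print(sum(1 for _ in loader))  # 8 batches of 4 samples each

Building the kwargs as a dict and unpacking them with ** keeps a single DataLoader call site for both modes, which is the same design the patch uses in setup_dataloader.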
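
The main_ds.py change keys the per-epoch set_epoch() call off the new --sampler flag because the epoch hook lives on different loader attributes in the two modes (batch_sampler vs. sampler). The sketch below shows why the distributed path needs it at all: without set_epoch, DistributedSampler reuses the epoch-0 shuffle order every epoch. It also runs outside torchrun; the fallback to None and the epoch count are illustrative, not taken from the patch.

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(32))

# DistributedSampler needs an initialized process group; outside torchrun
# fall back to None (plain sequential sampling), mirroring the patch.
sampler = DistributedSampler(dataset) if torch.distributed.is_initialized() else None
loader = DataLoader(dataset, sampler=sampler, batch_size=4)

for epoch in range(3):  # epoch count chosen only for illustration
    if sampler is not None:
        # Reseeds the per-epoch shuffle so each rank draws a new permutation;
        # skipping this repeats the same ordering every epoch.
        sampler.set_epoch(epoch)
    for batch in loader:
        pass  # forward/backward/optimizer step would go here

The multipack branch does the same thing through train_loader.batch_sampler.set_epoch(epoch), which the removed hasattr guard shows the original code already expected MultipackDistributedBatchSampler to support.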