diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index a0c03e5a..41ee21f5 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -346,7 +346,12 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
     for epoch in range(args.num_epochs):
         torch.distributed.barrier()
-        train_loader.batch_sampler.set_epoch(epoch)
+        if args.sampler == "multipack":
+            train_loader.batch_sampler.set_epoch(epoch)
+        elif args.sampler == "distributed":
+            train_loader.sampler.set_epoch(epoch)
+        else:
+            raise NotImplementedError
         if local_rank == 0:
             inner_pb = tqdm(range(len(train_loader)), desc=f"Epoch {epoch}")
@@ -510,6 +515,8 @@ def main(args):
         is_granite=args.is_granite,
         max_batch_len=args.max_batch_len,
         packing_max_batch_len=packing_max_batch_len,
+        samples_per_gpu=args.samples_per_gpu,
+        sampler=args.sampler,
         seed=args.seed,
     )
@@ -674,6 +681,13 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
             "constant_with_warmup",
         ],
     )
+    parser.add_argument(
+        "--sampler",
+        type=str,
+        default="multipack",
+        help="The batch sampler type to use.",
+        choices=["multipack", "distributed"],
+    )
     parser.add_argument("--num_warmup_steps", type=int, default=1000)
     # parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--save_samples", type=int)
diff --git a/src/instructlab/training/token_dataset.py b/src/instructlab/training/token_dataset.py
index e36b7395..d4acf22d 100644
--- a/src/instructlab/training/token_dataset.py
+++ b/src/instructlab/training/token_dataset.py
@@ -85,6 +85,8 @@ def setup_dataloader(
     is_granite=False,
     max_batch_len=60000,
     packing_max_batch_len=60000,
+    samples_per_gpu=None,
+    sampler="multipack",
     seed=47,
 ) -> DataLoader:
     collate_fn = make_collate_fn(
@@ -94,17 +96,33 @@ def setup_dataloader(
     world_size = int(os.environ["WORLD_SIZE"])
     lengths = dataset.get_lengths()
-    sampler = MultipackDistributedBatchSampler(
-        batch_max_length=packing_max_batch_len,
-        lengths=lengths,
-        num_replicas=world_size,
-        rank=rank,
-        seed=seed,
-        padding=not is_granite,
-    )
+    if sampler == "multipack":
+        sampler = MultipackDistributedBatchSampler(
+            batch_max_length=packing_max_batch_len,
+            lengths=lengths,
+            num_replicas=world_size,
+            rank=rank,
+            seed=seed,
+            padding=not is_granite,
+        )
+        sampler = {"batch_sampler": sampler}
+    elif sampler == "distributed":
+        # Third Party
+        from torch.utils.data import DistributedSampler
+
+        sampler = (
+            DistributedSampler(dataset) if torch.distributed.is_initialized() else None
+        )
+        sampler = {
+            "sampler": sampler,
+            "batch_size": samples_per_gpu,
+        }
+    else:
+        raise NotImplementedError
+
    dataloader = DataLoader(
         dataset,
-        batch_sampler=sampler,
+        **sampler,
         num_workers=num_workers,
         collate_fn=collate_fn,
     )
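
For context, here is a minimal, self-contained sketch (not code from this repo) of the epoch-reshuffling contract the new `distributed` path relies on: PyTorch's `DistributedSampler` only produces a new shuffle order when `set_epoch` is called, and it hangs off `DataLoader.sampler` rather than `DataLoader.batch_sampler`, which is why the training loop now dispatches on `args.sampler`. The toy dataset, the explicit `num_replicas=1, rank=0`, and the batch size are placeholders; in the real code the rank and world size come from the initialized process group and the batch size comes from `args.samples_per_gpu`.

```python
# Sketch only: illustrates DistributedSampler's set_epoch contract, not repo code.
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(16))  # stand-in for the token dataset

# Passing num_replicas/rank explicitly lets this run without init_process_group;
# setup_dataloader() instead relies on the initialized process group.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=47)
loader = DataLoader(dataset, sampler=sampler, batch_size=4)  # batch_size ~ samples_per_gpu

for epoch in range(2):
    # Without this call, every epoch replays the same shuffled order on every rank;
    # this is what train() now does via train_loader.sampler.set_epoch(epoch).
    loader.sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass  # training step would go here
```

By contrast, the multipack path keeps exposing `set_epoch` on `train_loader.batch_sampler`, since `MultipackDistributedBatchSampler` is passed as `batch_sampler=` and `batch_size` must be left unset in that case.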