lint
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Jun 28, 2024
1 parent ec4274d commit 5f314e2
Showing 2 changed files with 12 additions and 14 deletions.
9 changes: 3 additions & 6 deletions src/instructlab/training/main_ds.py
@@ -346,9 +346,9 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
 
     for epoch in range(args.num_epochs):
         torch.distributed.barrier()
-        if args.sampler in ('multipack'):
+        if args.sampler in ("multipack"):
             train_loader.batch_sampler.set_epoch(epoch)
-        elif args.sampler in ('distributed'):
+        elif args.sampler in ("distributed"):
             train_loader.sampler.set_epoch(epoch)
         else:
             raise NotADirectoryError
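For readers unfamiliar with the pattern in this hunk: set_epoch is called at the top of every epoch so the shuffle is re-seeded per epoch and stays consistent across ranks; the multipack branch targets train_loader.batch_sampler because MultipackDistributedBatchSampler is a batch sampler, while the distributed branch targets the plain sampler. A minimal, self-contained sketch of the idea, not taken from this repository (num_replicas=1 and rank=0 are passed only so the snippet runs without initializing a process group):

# Minimal sketch of the per-epoch set_epoch pattern (illustrative only).
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(100))  # placeholder dataset
# num_replicas/rank are hard-coded so this runs without torch.distributed init.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for epoch in range(3):
    # Re-seeds the shuffle from the epoch number; without this call,
    # every epoch iterates the data in the same order.
    sampler.set_epoch(epoch)
    for batch in loader:
        pass  # training step would go here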
@@ -686,10 +686,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
         type=str,
         default="multipack",
         help="The batch sampler type to use.",
-        choices=[
-            "multipack",
-            "distributed"
-        ],
+        choices=["multipack", "distributed"],
     )
     parser.add_argument("--num_warmup_steps", type=int, default=1000)
     # parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
17 changes: 9 additions & 8 deletions src/instructlab/training/token_dataset.py
@@ -86,7 +86,7 @@ def setup_dataloader(
     max_batch_len=60000,
     packing_max_batch_len=60000,
     samples_per_gpu=None,
-    sampler='multipack',
+    sampler="multipack",
     seed=47,
 ) -> DataLoader:
     collate_fn = make_collate_fn(
@@ -96,7 +96,7 @@ def setup_dataloader(
     world_size = int(os.environ["WORLD_SIZE"])
 
     lengths = dataset.get_lengths()
-    if sampler == 'multipack':
+    if sampler == "multipack":
         sampler = MultipackDistributedBatchSampler(
             batch_max_length=packing_max_batch_len,
             lengths=lengths,
@@ -105,16 +105,17 @@ def setup_dataloader(
             seed=seed,
             padding=not is_granite,
         )
-        sampler = {'batch_sampler': sampler}
-    elif sampler == 'distributed':
+        sampler = {"batch_sampler": sampler}
+    elif sampler == "distributed":
+        # Third Party
         from torch.utils.data import DistributedSampler
+
         sampler = (
-            DistributedSampler(dataset) if
-            torch.distributed.is_initialized() else None
+            DistributedSampler(dataset) if torch.distributed.is_initialized() else None
         )
         sampler = {
-            'sampler': sampler,
-            'batch_size': samples_per_gpu,
+            "sampler": sampler,
+            "batch_size": samples_per_gpu,
         }
     else:
         raise NotImplementedError
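The rest of setup_dataloader is outside this diff, but the two branches above build mutually exclusive keyword-argument dicts ({"batch_sampler": ...} for multipack, {"sampler": ..., "batch_size": ...} for distributed) so that a single DataLoader call can accept either. A hedged sketch of how such a dict would typically be unpacked; the helper name, num_workers value, and the call site are assumptions for illustration, not code from the repository:

# Sketch only: unpacking a sampler kwargs dict into one DataLoader call.
from torch.utils.data import DataLoader

def build_loader(dataset, sampler_kwargs, collate_fn):  # hypothetical helper
    # DataLoader treats batch_sampler as mutually exclusive with
    # batch_size/sampler/shuffle, which is why the two branches above
    # produce dicts with non-overlapping keys.
    return DataLoader(
        dataset,
        **sampler_kwargs,
        collate_fn=collate_fn,
        num_workers=8,  # assumed value, not taken from the diff
    )

# multipack branch: the batch sampler emits whole batches itself
# loader = build_loader(dataset, {"batch_sampler": multipack_sampler}, collate_fn)

# distributed branch: a plain sampler plus an explicit per-GPU batch size
# loader = build_loader(dataset, {"sampler": dist_sampler, "batch_size": samples_per_gpu}, collate_fn)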
