diff --git a/aria/run.py b/aria/run.py
index a19f19b..8723882 100644
--- a/aria/run.py
+++ b/aria/run.py
@@ -20,6 +20,20 @@ def _parse_sample_args():
     argp.add_argument("-m", help="name of model config file")
     argp.add_argument("-c", help="path to model checkpoint")
     argp.add_argument("-p", help="path to midi file")
+    argp.add_argument(
+        "-cfg",
+        help="change cfg value",
+        type=float,
+        required=False,
+        default=1.4,
+    )
+    argp.add_argument(
+        "-temp",
+        help="change temp value",
+        type=float,
+        required=False,
+        default=0.85,
+    )
     argp.add_argument(
         "-var",
         help="number of variations",
@@ -138,6 +152,10 @@ def sample(args):
     model_config = ModelConfig(**load_model_config(model_name))
     model_config.set_vocab_size(tokenizer.vocab_size)
     model = TransformerLM(model_config).to(device)
+
+    if args.trunc + args.l > model_config.max_seq_len:
+        print("WARNING - required context exceeds max_seq_len")
+
     try:
         model.load_state_dict(model_state)
     except:
@@ -213,6 +231,8 @@ def _quantize(module, key, input_shape):
         device=device,
         force_end=force_end,
         max_new_tokens=max_new_tokens,
+        cfg_gamma=args.cfg,
+        temperature=args.temp,
     )
 
     if os.path.isdir("samples") is False:
diff --git a/aria/train.py b/aria/train.py
index 618ae81..2673054 100644
--- a/aria/train.py
+++ b/aria/train.py
@@ -345,6 +345,7 @@ def _train(
     steps_per_checkpoint: int | None = None,
     resume_step: int | None = None,
     resume_epoch: int | None = None,
+    project_dir: str | None = None,
 ):
     def profile_flops(dataloader: DataLoader):
         def _bench():
@@ -431,7 +432,8 @@ def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0):
                     f"trailing_loss={round(trailing_loss, 4)}, "
                     f"average_loss={round(avg_train_loss, 4)}"
                 )
-                loss_writer.writerow([_epoch, step, loss.item()])
+                if accelerator.is_main_process:
+                    loss_writer.writerow([_epoch, step, loss.item()])
                 pbar.set_postfix_str(
                     f"lr={lr_for_print}, "
                     f"loss={round(loss.item(), 4)}, "
@@ -497,16 +499,16 @@ def val_loop(dataloader, _epoch: int):
     TRAILING_LOSS_STEPS = 200
     PAD_ID = train_dataloader.dataset.tokenizer.pad_id
     logger = get_logger(__name__)  # Accelerate logger
-    project_dir = accelerator.project_dir
     loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
     profile_flops(dataloader=train_dataloader)
 
-    loss_csv = open(os.path.join(project_dir, "loss.csv"), "w")
-    loss_writer = csv.writer(loss_csv)
-    loss_writer.writerow(["epoch", "step", "loss"])
-    epoch_csv = open(os.path.join(project_dir, "epoch.csv"), "w")
-    epoch_writer = csv.writer(epoch_csv)
-    epoch_writer.writerow(["epoch", "avg_train_loss", "avg_val_loss"])
+    if accelerator.is_main_process:
+        loss_csv = open(os.path.join(project_dir, "loss.csv"), "w")
+        loss_writer = csv.writer(loss_csv)
+        loss_writer.writerow(["epoch", "step", "loss"])
+        epoch_csv = open(os.path.join(project_dir, "epoch.csv"), "w")
+        epoch_writer = csv.writer(epoch_csv)
+        epoch_writer.writerow(["epoch", "avg_train_loss", "avg_val_loss"])
 
     if resume_epoch is not None:
         start_epoch = resume_epoch + 1
@@ -529,21 +531,26 @@ def val_loop(dataloader, _epoch: int):
             _resume_step=resume_step,
        )
         avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=resume_epoch)
-        epoch_writer.writerow([0, avg_train_loss, avg_val_loss])
-        epoch_csv.flush()
-        make_checkpoint(_accelerator=accelerator, _epoch=start_epoch, _step=0)
+        if accelerator.is_main_process:
+            epoch_writer.writerow([0, avg_train_loss, avg_val_loss])
+            epoch_csv.flush()
+            make_checkpoint(
+                _accelerator=accelerator, _epoch=start_epoch, _step=0
+            )
 
     for epoch in range(start_epoch, epochs + start_epoch):
         train_dataloader.dataset.init_epoch(epoch)
         avg_train_loss = train_loop(dataloader=train_dataloader, _epoch=epoch)
         avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=epoch)
-        epoch_writer.writerow([epoch, avg_train_loss, avg_val_loss])
-        epoch_csv.flush()
-        make_checkpoint(_accelerator=accelerator, _epoch=epoch + 1, _step=0)
+        if accelerator.is_main_process:
+            epoch_writer.writerow([epoch, avg_train_loss, avg_val_loss])
+            epoch_csv.flush()
+            make_checkpoint(_accelerator=accelerator, _epoch=epoch + 1, _step=0)
 
-    loss_csv.close()
-    epoch_csv.close()
     logging.shutdown()
+    if accelerator.is_main_process:
+        loss_csv.close()
+        epoch_csv.close()
 
 
 # NOTE: Any differences observed when resuming training are most likely the
@@ -595,9 +602,12 @@ def resume_train(
 
     # TODO: Add support for verifying the resume_step and epoch, keep these
     # save these variables as part of the state during checkpointing
-    project_dir = setup_project_dir(project_dir)
     accelerator = accelerate.Accelerator(project_dir=project_dir)
-    logger = setup_logger(project_dir)
+    if accelerator.is_main_process:
+        project_dir = setup_project_dir(project_dir)
+        logger = setup_logger(project_dir)
+
+    logger = get_logger(__name__)
     logger.info(f"Using project directory {project_dir} ")
     logger.warning(
         "Please insure that the training config and resume step are set "
@@ -700,6 +710,7 @@ def resume_train(
         steps_per_checkpoint=steps_per_checkpoint,
         resume_step=resume_step,
         resume_epoch=resume_epoch,
+        project_dir=project_dir,
     )
 
 
@@ -745,9 +756,12 @@ def train(
     else:
         raise Exception("Invalid tokenizer name")
 
-    project_dir = setup_project_dir(project_dir)
     accelerator = accelerate.Accelerator(project_dir=project_dir)
-    logger = setup_logger(project_dir)
+    if accelerator.is_main_process:
+        project_dir = setup_project_dir(project_dir)
+        logger = setup_logger(project_dir)
+
+    logger = get_logger(__name__)
     logger.info(f"Using project directory {project_dir}")
     logger.info(
         f"Using training config: "
@@ -840,6 +854,7 @@ def train(
         optimizer=optimizer,
         scheduler=scheduler,
         steps_per_checkpoint=steps_per_checkpoint,
+        project_dir=project_dir,
     )
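
Notes on the `aria/run.py` changes: the new bounds check is plain arithmetic — the truncated prompt (`args.trunc`) plus the requested continuation (`args.l`) must fit in `model_config.max_seq_len`, so e.g. `trunc=512` with `l=768` against a 1024-token model gives 1280 > 1024 and triggers the warning. The diff also threads `cfg_gamma` and `temperature` into the sampling call, but the sampler itself is outside this diff. The sketch below shows the conventional way classifier-free guidance and temperature are applied at each decoding step; the function name, tensor shapes, and blending formula here are illustrative assumptions, not aria's confirmed internals.

```python
import torch

def cfg_next_token(
    cond_logits: torch.Tensor,    # (batch, vocab) logits given the prompt
    uncond_logits: torch.Tensor,  # (batch, vocab) logits given a neutral prompt
    cfg_gamma: float = 1.4,
    temperature: float = 0.85,
) -> torch.Tensor:
    # Classifier-free guidance: extrapolate from the unconditional toward
    # the conditional prediction. gamma=1.0 recovers plain conditional
    # sampling; gamma>1 strengthens the conditioning signal.
    logits = uncond_logits + cfg_gamma * (cond_logits - uncond_logits)
    # Temperature below 1 sharpens the distribution; above 1 flattens it.
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # (batch, 1) token ids
```

With the new flags, an invocation might look like `python aria/run.py sample -m <model_config> -c <checkpoint> -p <file.mid> -cfg 1.4 -temp 0.85` (subcommand layout inferred from `_parse_sample_args`; the defaults match the diff's 1.4 and 0.85).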
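On the `aria/train.py` side, every CSV touch point is now gated behind `accelerator.is_main_process`, so that under a multi-process `accelerate launch` only rank 0 opens and writes `loss.csv` and `epoch.csv`; otherwise each rank would open the same paths and interleave or clobber rows. This is also why `project_dir` becomes an explicit parameter of `_train`: non-main processes no longer run `setup_project_dir`, so the path must be handed in rather than derived. A minimal standalone sketch of the pattern (the directory name and placeholder loss value are illustrative):

```python
import csv
import os

import accelerate

accelerator = accelerate.Accelerator(project_dir="experiment")

# Only rank 0 creates and writes the CSV; other ranks skip file I/O entirely.
if accelerator.is_main_process:
    os.makedirs("experiment", exist_ok=True)
    loss_csv = open(os.path.join("experiment", "loss.csv"), "w")
    loss_writer = csv.writer(loss_csv)
    loss_writer.writerow(["epoch", "step", "loss"])

# ... every rank executes the training step; only rank 0 records it ...
loss_value = 2.5  # placeholder for loss.item()
if accelerator.is_main_process:
    loss_writer.writerow([0, 1, loss_value])
    loss_csv.close()
```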