Fix multi_gpu (#81)
* change api

* fix and add warning

* fix

* fix logging and project dir

* fix

* pdir fix

* fix

* fix
loubbrad authored Dec 14, 2023
1 parent 7fc0908 commit d7f583d
Showing 2 changed files with 55 additions and 20 deletions.
aria/run.py: 20 additions, 0 deletions
@@ -20,6 +20,20 @@ def _parse_sample_args():
argp.add_argument("-m", help="name of model config file")
argp.add_argument("-c", help="path to model checkpoint")
argp.add_argument("-p", help="path to midi file")
argp.add_argument(
"-cfg",
help="change cfg value",
type=float,
required=False,
default=1.4,
)
argp.add_argument(
"-temp",
help="change temp value",
type=float,
required=False,
default=0.85,
)
argp.add_argument(
"-var",
help="number of variations",
Expand Down Expand Up @@ -138,6 +152,10 @@ def sample(args):
model_config = ModelConfig(**load_model_config(model_name))
model_config.set_vocab_size(tokenizer.vocab_size)
model = TransformerLM(model_config).to(device)

if args.trunc + args.l > model_config.max_seq_len:
print("WARNING - required context exceeds max_seq_len")

try:
model.load_state_dict(model_state)
except:
@@ -213,6 +231,8 @@ def _quantize(module, key, input_shape):
device=device,
force_end=force_end,
max_new_tokens=max_new_tokens,
cfg_gamma=args.cfg,
temperature=args.temp,
)

if os.path.isdir("samples") is False:
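For readers unfamiliar with the two new flags: -cfg sets a classifier-free guidance weight and -temp a sampling temperature, and the diff forwards them to the sampling call as cfg_gamma and temperature. The sketch below shows the conventional way such parameters act on next-token logits; it is an illustration under that assumption only, not a copy of aria's sampler.

import math

def apply_cfg_and_temperature(cond_logits, uncond_logits, cfg_gamma=1.4, temperature=0.85):
    # Conventional classifier-free guidance: extrapolate from the unconditional
    # logits toward the conditional ones by the guidance weight.
    guided = [u + cfg_gamma * (c - u) for c, u in zip(cond_logits, uncond_logits)]
    # Temperature rescales logits before the softmax; values below 1.0 sharpen
    # the distribution, values above 1.0 flatten it.
    scaled = [g / temperature for g in guided]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

if __name__ == "__main__":
    # Dummy logits; a larger cfg_gamma pushes probability mass toward the
    # tokens favored by the conditional distribution.
    print(apply_cfg_and_temperature([2.0, 1.0, 0.1], [1.5, 1.2, 0.3]))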
aria/train.py: 35 additions, 20 deletions
@@ -345,6 +345,7 @@ def _train(
steps_per_checkpoint: int | None = None,
resume_step: int | None = None,
resume_epoch: int | None = None,
project_dir: str | None = None,
):
def profile_flops(dataloader: DataLoader):
def _bench():
@@ -431,7 +432,8 @@ def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0):
f"trailing_loss={round(trailing_loss, 4)}, "
f"average_loss={round(avg_train_loss, 4)}"
)
loss_writer.writerow([_epoch, step, loss.item()])
if accelerator.is_main_process:
loss_writer.writerow([_epoch, step, loss.item()])
pbar.set_postfix_str(
f"lr={lr_for_print}, "
f"loss={round(loss.item(), 4)}, "
@@ -497,16 +499,16 @@ def val_loop(dataloader, _epoch: int):
TRAILING_LOSS_STEPS = 200
PAD_ID = train_dataloader.dataset.tokenizer.pad_id
logger = get_logger(__name__) # Accelerate logger
project_dir = accelerator.project_dir
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
profile_flops(dataloader=train_dataloader)

loss_csv = open(os.path.join(project_dir, "loss.csv"), "w")
loss_writer = csv.writer(loss_csv)
loss_writer.writerow(["epoch", "step", "loss"])
epoch_csv = open(os.path.join(project_dir, "epoch.csv"), "w")
epoch_writer = csv.writer(epoch_csv)
epoch_writer.writerow(["epoch", "avg_train_loss", "avg_val_loss"])
if accelerator.is_main_process:
loss_csv = open(os.path.join(project_dir, "loss.csv"), "w")
loss_writer = csv.writer(loss_csv)
loss_writer.writerow(["epoch", "step", "loss"])
epoch_csv = open(os.path.join(project_dir, "epoch.csv"), "w")
epoch_writer = csv.writer(epoch_csv)
epoch_writer.writerow(["epoch", "avg_train_loss", "avg_val_loss"])

if resume_epoch is not None:
start_epoch = resume_epoch + 1
@@ -529,21 +531,26 @@ def val_loop(dataloader, _epoch: int):
_resume_step=resume_step,
)
avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=resume_epoch)
epoch_writer.writerow([0, avg_train_loss, avg_val_loss])
epoch_csv.flush()
make_checkpoint(_accelerator=accelerator, _epoch=start_epoch, _step=0)
if accelerator.is_main_process:
epoch_writer.writerow([0, avg_train_loss, avg_val_loss])
epoch_csv.flush()
make_checkpoint(
_accelerator=accelerator, _epoch=start_epoch, _step=0
)

for epoch in range(start_epoch, epochs + start_epoch):
train_dataloader.dataset.init_epoch(epoch)
avg_train_loss = train_loop(dataloader=train_dataloader, _epoch=epoch)
avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=epoch)
epoch_writer.writerow([epoch, avg_train_loss, avg_val_loss])
epoch_csv.flush()
make_checkpoint(_accelerator=accelerator, _epoch=epoch + 1, _step=0)
if accelerator.is_main_process:
epoch_writer.writerow([epoch, avg_train_loss, avg_val_loss])
epoch_csv.flush()
make_checkpoint(_accelerator=accelerator, _epoch=epoch + 1, _step=0)

loss_csv.close()
epoch_csv.close()
logging.shutdown()
if accelerator.is_main_process:
loss_csv.close()
epoch_csv.close()


# NOTE: Any differences observed when resuming training are most likely the
@@ -595,9 +602,12 @@ def resume_train(

# TODO: Add support for verifying the resume_step and epoch, keep these
# save these variables as part of the state during checkpointing
project_dir = setup_project_dir(project_dir)
accelerator = accelerate.Accelerator(project_dir=project_dir)
logger = setup_logger(project_dir)
if accelerator.is_main_process:
project_dir = setup_project_dir(project_dir)
logger = setup_logger(project_dir)

logger = get_logger(__name__)
logger.info(f"Using project directory {project_dir} ")
logger.warning(
"Please insure that the training config and resume step are set "
@@ -700,6 +710,7 @@ def resume_train(
steps_per_checkpoint=steps_per_checkpoint,
resume_step=resume_step,
resume_epoch=resume_epoch,
project_dir=project_dir,
)


@@ -745,9 +756,12 @@ def train(
else:
raise Exception("Invalid tokenizer name")

project_dir = setup_project_dir(project_dir)
accelerator = accelerate.Accelerator(project_dir=project_dir)
logger = setup_logger(project_dir)
if accelerator.is_main_process:
project_dir = setup_project_dir(project_dir)
logger = setup_logger(project_dir)

logger = get_logger(__name__)
logger.info(f"Using project directory {project_dir}")
logger.info(
f"Using training config: "
@@ -840,6 +854,7 @@ def train(
optimizer=optimizer,
scheduler=scheduler,
steps_per_checkpoint=steps_per_checkpoint,
project_dir=project_dir,
)


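The recurring change in this file is guarding file I/O (project directory setup, the file logger, the loss/epoch CSVs, checkpointing) behind accelerator.is_main_process, so that only rank 0 touches disk in a multi-GPU run instead of every rank opening and truncating the same files. A minimal sketch of that pattern with Hugging Face Accelerate follows; the directory setup and CSV layout are simplified stand-ins for aria's own helpers, not the actual training loop.

import csv
import os

import accelerate
from accelerate.logging import get_logger

def main(project_dir: str = "experiments/run0"):
    accelerator = accelerate.Accelerator(project_dir=project_dir)

    # Only rank 0 creates the directory and opens the CSV; otherwise each
    # process would race on the same path and clobber the others' output.
    loss_csv, loss_writer = None, None
    if accelerator.is_main_process:
        os.makedirs(project_dir, exist_ok=True)  # stand-in for setup_project_dir
        loss_csv = open(os.path.join(project_dir, "loss.csv"), "w")
        loss_writer = csv.writer(loss_csv)
        loss_writer.writerow(["epoch", "step", "loss"])

    # The Accelerate logger can be called from every rank; by default it only
    # emits on the main process.
    logger = get_logger(__name__)
    logger.info(f"Using project directory {project_dir}")

    for epoch in range(2):
        for step, loss in enumerate([1.0, 0.9, 0.8]):  # dummy loss values
            if accelerator.is_main_process:
                loss_writer.writerow([epoch, step, loss])

    if accelerator.is_main_process:
        loss_csv.close()

if __name__ == "__main__":
    main()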
