From 03aeac2fcc1ed8f3541b24614c0eb9468c56c1a3 Mon Sep 17 00:00:00 2001
From: Honglu Fan
Date: Sun, 4 Feb 2024 19:33:45 +0000
Subject: [PATCH] fix format

---
 aria/model/model_neox.py | 40 +++++++++++++++++++++++-----------------
 aria/train.py            | 30 +++++++++++++++++++++++-------
 2 files changed, 46 insertions(+), 24 deletions(-)

diff --git a/aria/model/model_neox.py b/aria/model/model_neox.py
index 2890e73..fd409ae 100644
--- a/aria/model/model_neox.py
+++ b/aria/model/model_neox.py
@@ -9,7 +9,6 @@
 
 
 class GPTNeoXAria(TransformerLM):
-
     """A wrapper for GPTNeoXForCausalLM."""
 
     def __init__(self, model_config: ModelConfig, use_cache: bool = False):
@@ -39,19 +38,26 @@ def __init__(self, model_config: ModelConfig, use_cache: bool = False):
             self.model.gradient_checkpointing_enable()
 
     def forward(
-        self,
-        src: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None,
-        past_kv: Optional[list[KVCache]] = None,
-    ):
-        if past_kv is None:
-            output = self.model(src, attention_mask=attn_mask)
-        else:
-            bs = src.size(0)
-            hf_past_kv = tuple((kv.k_cache[:bs, :kv.next_pos], kv.v_cache[:bs, :kv.next_pos]) for i, kv in enumerate(past_kv))
-            output: CausalLMOutputWithPast = self.model(src, attention_mask=attn_mask, past_key_values=hf_past_kv)
-            if output.past_key_values is not None:
-                for i, kv in enumerate(past_kv):
-                    kv.update(output.past_key_values[i][0][:, kv.next_pos:],
-                              output.past_key_values[i][1][:, kv.next_pos:])
-        return output.logits
+        self,
+        src: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        past_kv: Optional[list[KVCache]] = None,
+    ):
+        if past_kv is None:
+            output = self.model(src, attention_mask=attn_mask)
+        else:
+            bs = src.size(0)
+            hf_past_kv = tuple(
+                (kv.k_cache[:bs, : kv.next_pos], kv.v_cache[:bs, : kv.next_pos])
+                for i, kv in enumerate(past_kv)
+            )
+            output: CausalLMOutputWithPast = self.model(
+                src, attention_mask=attn_mask, past_key_values=hf_past_kv
+            )
+            if output.past_key_values is not None:
+                for i, kv in enumerate(past_kv):
+                    kv.update(
+                        output.past_key_values[i][0][:, kv.next_pos :],
+                        output.past_key_values[i][1][:, kv.next_pos :],
+                    )
+        return output.logits
diff --git a/aria/train.py b/aria/train.py
index 5f761f0..5daae82 100644
--- a/aria/train.py
+++ b/aria/train.py
@@ -391,7 +391,12 @@ def make_checkpoint(_accelerator, _epoch: int, _step: int):
 
     # This is all slightly messy as train_loop and val_loop make use of the
    # variables in the wider scope. Perhaps refactor this at some point.
-    def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0, grad_acc: int = 1):
+    def train_loop(
+        dataloader: DataLoader,
+        _epoch: int,
+        _resume_step: int = 0,
+        grad_acc: int = 1,
+    ):
         avg_train_loss = 0
         trailing_loss = 0
         loss_buffer = []
@@ -404,7 +409,10 @@ def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0, grad_
         lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"])
 
         model.train()
-        pbar = tqdm(total=len(dataloader) // grad_acc + _resume_step, initial=_resume_step)
+        pbar = tqdm(
+            total=len(dataloader) // grad_acc + _resume_step,
+            initial=_resume_step,
+        )
         for __step, batch in enumerate(dataloader):
             step = __step + _resume_step + 1
             src, tgt = batch  # (b_sz, s_len), (b_sz, s_len, v_sz)
@@ -420,7 +428,7 @@ def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0, grad_
             avg_train_loss = rolling_average(
                 avg_train_loss, loss.item(), __step
             )
-            
+
             if step % grad_acc == 0:
                 # Logging
                 pbar.update(1)
@@ -543,7 +551,9 @@ def val_loop(dataloader, _epoch: int):
 
     for epoch in range(start_epoch, epochs + start_epoch):
         train_dataloader.dataset.init_epoch(epoch)
-        avg_train_loss = train_loop(dataloader=train_dataloader, _epoch=epoch, grad_acc=grad_acc)
+        avg_train_loss = train_loop(
+            dataloader=train_dataloader, _epoch=epoch, grad_acc=grad_acc
+        )
         avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=epoch)
         if accelerator.is_main_process:
             epoch_writer.writerow([epoch, avg_train_loss, avg_val_loss])
@@ -935,7 +945,9 @@ def parse_resume_args():
     argp.add_argument(
         "-use_neox", help="use neox model", action="store_true", default=False
     )
-    argp.add_argument("-grad_acc", help="gradient accumulation steps.", type=int, default=1)
+    argp.add_argument(
+        "-grad_acc", help="gradient accumulation steps.", type=int, default=1
+    )
 
     return argp.parse_args(sys.argv[2:])
 
@@ -955,7 +967,9 @@ def parse_pretrain_args():
     argp.add_argument(
         "-use_neox", help="use neox model", action="store_true", default=False
     )
-    argp.add_argument("-grad_acc", help="gradient accumulation steps.", type=int, default=1)
+    argp.add_argument(
+        "-grad_acc", help="gradient accumulation steps.", type=int, default=1
+    )
 
     return argp.parse_args(sys.argv[2:])
 
@@ -976,7 +990,9 @@ def parse_finetune_args():
     argp.add_argument(
         "-use_neox", help="use neox model", action="store_true", default=False
     )
-    argp.add_argument("-grad_acc", help="gradient accumulation steps.", type=int, default=1)
+    argp.add_argument(
+        "-grad_acc", help="gradient accumulation steps.", type=int, default=1
+    )
 
     return argp.parse_args(sys.argv[2:])
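
Note on the model_neox.py hunk: it only reflows the forward() body, but the cache bridging it performs is hard to read in diff form. The standalone sketch below mimics that logic with a toy stand-in for KVCache. The names k_cache, v_cache, next_pos, and update() come from the hunk above; the pre-allocated buffer sizes, the (batch, seq, hidden) layout, and the ToyKVCache constructor are assumptions made only for this sketch, not the repository's actual implementation.

# Standalone illustration of the cache bridging done in the patched forward().
# ToyKVCache is a hypothetical stand-in; only the field/method names are taken
# from the hunk above, the shapes and layout are illustrative assumptions.
import torch

MAX_BATCH, MAX_LEN, HIDDEN = 4, 64, 32


class ToyKVCache:
    def __init__(self):
        self.k_cache = torch.zeros(MAX_BATCH, MAX_LEN, HIDDEN)
        self.v_cache = torch.zeros(MAX_BATCH, MAX_LEN, HIDDEN)
        self.next_pos = 0

    def update(self, new_k, new_v):
        # Append the new entries after the already-filled prefix.
        bs, n_new = new_k.shape[0], new_k.shape[1]
        self.k_cache[:bs, self.next_pos : self.next_pos + n_new] = new_k
        self.v_cache[:bs, self.next_pos : self.next_pos + n_new] = new_v
        self.next_pos += n_new


kv = ToyKVCache()
bs = 2

# Prefill: pretend the first forward pass produced keys/values for 5 prompt tokens.
kv.update(torch.randn(bs, 5, HIDDEN), torch.randn(bs, 5, HIDDEN))

# Build the HF-style (key, value) pair the same way the patched forward() does.
hf_past_kv = (kv.k_cache[:bs, : kv.next_pos], kv.v_cache[:bs, : kv.next_pos])
assert hf_past_kv[0].shape[1] == 5

# HF returns the full cache (old + new); only the tail past next_pos is written back.
full_k = torch.cat([hf_past_kv[0], torch.randn(bs, 1, HIDDEN)], dim=1)
full_v = torch.cat([hf_past_kv[1], torch.randn(bs, 1, HIDDEN)], dim=1)
kv.update(full_k[:, kv.next_pos :], full_v[:, kv.next_pos :])
assert kv.next_pos == 6

In the real module, hf_past_kv holds one such (key, value) pair per transformer layer, and output.past_key_values plays the role of full_k / full_v here.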
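
Note on the train.py hunks: they are line-wrapping only; grad_acc already existed and is merely threaded through train_loop, the tqdm total, and the three argument parsers. The `if step % grad_acc == 0` gate visible in the diff controls logging and the progress bar; the backward/optimizer-step gating lives outside the shown context. The sketch below is therefore the generic gradient-accumulation pattern, written under that assumption rather than copied from aria/train.py.

# Generic gradient-accumulation loop: scale the loss by 1 / grad_acc and step
# the optimizer once every grad_acc micro-batches. This is a sketch of the
# usual pattern, not the repository's train_loop.
def train_loop_sketch(model, dataloader, optimizer, loss_fn, grad_acc: int = 1):
    model.train()
    optimizer.zero_grad()
    for step, (src, tgt) in enumerate(dataloader, start=1):
        logits = model(src)
        # Scale so that grad_acc micro-batches accumulate one effective gradient.
        loss = loss_fn(logits, tgt) / grad_acc
        loss.backward()
        if step % grad_acc == 0:
            optimizer.step()
            optimizer.zero_grad()

Under this pattern there is one optimizer update per grad_acc micro-batches, which is why the reformatted tqdm call above sizes the progress bar with len(dataloader) // grad_acc.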