Commit

fix format
honglu2875 committed Feb 4, 2024
1 parent 26178bf commit 03aeac2
Showing 2 changed files with 46 additions and 24 deletions.
40 changes: 23 additions & 17 deletions aria/model/model_neox.py
@@ -9,7 +9,6 @@


 class GPTNeoXAria(TransformerLM):
-
     """A wrapper for GPTNeoXForCausalLM."""

     def __init__(self, model_config: ModelConfig, use_cache: bool = False):
@@ -39,19 +38,26 @@ def __init__(self, model_config: ModelConfig, use_cache: bool = False):
         self.model.gradient_checkpointing_enable()

     def forward(
-        self,
-        src: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None,
-        past_kv: Optional[list[KVCache]] = None,
-    ):
-        if past_kv is None:
-            output = self.model(src, attention_mask=attn_mask)
-        else:
-            bs = src.size(0)
-            hf_past_kv = tuple((kv.k_cache[:bs, :kv.next_pos], kv.v_cache[:bs, :kv.next_pos]) for i, kv in enumerate(past_kv))
-            output: CausalLMOutputWithPast = self.model(src, attention_mask=attn_mask, past_key_values=hf_past_kv)
-            if output.past_key_values is not None:
-                for i, kv in enumerate(past_kv):
-                    kv.update(output.past_key_values[i][0][:, kv.next_pos:],
-                              output.past_key_values[i][1][:, kv.next_pos:])
-            return output.logits
+        self,
+        src: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        past_kv: Optional[list[KVCache]] = None,
+    ):
+        if past_kv is None:
+            output = self.model(src, attention_mask=attn_mask)
+        else:
+            bs = src.size(0)
+            hf_past_kv = tuple(
+                (kv.k_cache[:bs, : kv.next_pos], kv.v_cache[:bs, : kv.next_pos])
+                for i, kv in enumerate(past_kv)
+            )
+            output: CausalLMOutputWithPast = self.model(
+                src, attention_mask=attn_mask, past_key_values=hf_past_kv
+            )
+            if output.past_key_values is not None:
+                for i, kv in enumerate(past_kv):
+                    kv.update(
+                        output.past_key_values[i][0][:, kv.next_pos :],
+                        output.past_key_values[i][1][:, kv.next_pos :],
+                    )
+            return output.logits
30 changes: 23 additions & 7 deletions aria/train.py
@@ -391,7 +391,12 @@ def make_checkpoint(_accelerator, _epoch: int, _step: int):

     # This is all slightly messy as train_loop and val_loop make use of the
     # variables in the wider scope. Perhaps refactor this at some point.
-    def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0, grad_acc: int = 1):
+    def train_loop(
+        dataloader: DataLoader,
+        _epoch: int,
+        _resume_step: int = 0,
+        grad_acc: int = 1,
+    ):
         avg_train_loss = 0
         trailing_loss = 0
         loss_buffer = []
@@ -404,7 +409,10 @@ def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0, grad_acc: int = 1):
         lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"])

         model.train()
-        pbar = tqdm(total=len(dataloader) // grad_acc + _resume_step, initial=_resume_step)
+        pbar = tqdm(
+            total=len(dataloader) // grad_acc + _resume_step,
+            initial=_resume_step,
+        )
         for __step, batch in enumerate(dataloader):
             step = __step + _resume_step + 1
             src, tgt = batch  # (b_sz, s_len), (b_sz, s_len, v_sz)
@@ -420,7 +428,7 @@ def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0, grad_acc: int = 1):
             avg_train_loss = rolling_average(
                 avg_train_loss, loss.item(), __step
             )
-
+
             if step % grad_acc == 0:
                 # Logging
                 pbar.update(1)
@@ -543,7 +551,9 @@ def val_loop(dataloader, _epoch: int):

     for epoch in range(start_epoch, epochs + start_epoch):
         train_dataloader.dataset.init_epoch(epoch)
-        avg_train_loss = train_loop(dataloader=train_dataloader, _epoch=epoch, grad_acc=grad_acc)
+        avg_train_loss = train_loop(
+            dataloader=train_dataloader, _epoch=epoch, grad_acc=grad_acc
+        )
         avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=epoch)
         if accelerator.is_main_process:
             epoch_writer.writerow([epoch, avg_train_loss, avg_val_loss])
@@ -935,7 +945,9 @@ def parse_resume_args():
     argp.add_argument(
         "-use_neox", help="use neox model", action="store_true", default=False
     )
-    argp.add_argument("-grad_acc", help="gradient accumulation steps.", type=int, default=1)
+    argp.add_argument(
+        "-grad_acc", help="gradient accumulation steps.", type=int, default=1
+    )

     return argp.parse_args(sys.argv[2:])

@@ -955,7 +967,9 @@ def parse_pretrain_args():
     argp.add_argument(
         "-use_neox", help="use neox model", action="store_true", default=False
     )
-    argp.add_argument("-grad_acc", help="gradient accumulation steps.", type=int, default=1)
+    argp.add_argument(
+        "-grad_acc", help="gradient accumulation steps.", type=int, default=1
+    )

     return argp.parse_args(sys.argv[2:])

@@ -976,7 +990,9 @@ def parse_finetune_args():
     argp.add_argument(
         "-use_neox", help="use neox model", action="store_true", default=False
     )
-    argp.add_argument("-grad_acc", help="gradient accumulation steps.", type=int, default=1)
+    argp.add_argument(
+        "-grad_acc", help="gradient accumulation steps.", type=int, default=1
+    )

     return argp.parse_args(sys.argv[2:])


0 comments on commit 03aeac2
