From ada0673a47857a466e6b0780109897dcde1134f2 Mon Sep 17 00:00:00 2001
From: Louis
Date: Wed, 24 Apr 2024 10:47:08 +0000
Subject: [PATCH] remove debug

---
 amt/model.py | 19 +------------------
 amt/train.py | 16 ----------------
 2 files changed, 1 insertion(+), 34 deletions(-)

diff --git a/amt/model.py b/amt/model.py
index e2fef9f..89b4da0 100644
--- a/amt/model.py
+++ b/amt/model.py
@@ -54,12 +54,6 @@ def __init__(self, n_state: int, n_head: int):
         self.key = nn.Linear(n_state, n_state, bias=False)
         self.value = nn.Linear(n_state, n_state, bias=False)
         self.out = nn.Linear(n_state, n_state, bias=False)
-
-        # self.x_norm = None
-        # self.q_norm = None
-        # self.k_norm = None
-        # self.v_norm = None
-        # self.out_norm = None
 
     def forward(
         self,
@@ -84,11 +78,6 @@ def forward(
         q = q.view(batch_size, target_seq_len, self.n_head, self.d_head)
         k = k.view(batch_size, source_seq_len, self.n_head, self.d_head)
         v = v.view(batch_size, source_seq_len, self.n_head, self.d_head)
-
-        # self.x_norm = torch.norm(x, dim=-1).mean()
-        # self.q_norm = torch.norm(q, dim=-1).mean()
-        # self.k_norm = torch.norm(k, dim=-1).mean()
-        # self.v_norm = torch.norm(v, dim=-1).mean()
 
         # (bz, L, nh, dh) -> (bz, nh, L, dh)
         q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
@@ -104,8 +93,6 @@ def forward(
             value=v,
             is_causal=_is_causal,
         )
-
-        # self.out_norm = torch.norm(wv, dim=-1).mean()
 
         # (bz, nh, L, dh) -> (bz, L, nh, dh) -> (bz, L, d)
         wv = wv.transpose(1, 2)
@@ -222,10 +209,6 @@ def forward(self, x: Tensor, xa: Tensor):
         x = self.ln(x)
         logits = self.output(x)
 
-        # logits = (
-        #     x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-        # ).float()
-
         return logits
 
 
@@ -260,4 +243,4 @@ def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
 
     @property
     def device(self):
-        return next(self.parameters()).device
\ No newline at end of file
+        return next(self.parameters()).device
diff --git a/amt/train.py b/amt/train.py
index 94a4bff..a5d0736 100644
--- a/amt/train.py
+++ b/amt/train.py
@@ -313,22 +313,6 @@ def make_checkpoint(_accelerator, _epoch: int, _step: int):
             f"EPOCH {_epoch}/{epochs + start_epoch}: Saving checkpoint - {checkpoint_dir}"
         )
         _accelerator.save_state(checkpoint_dir)
-
-    def log_activation_norms(_model: AmtEncoderDecoder, _accelerator: accelerate.Accelerator):
-        for idx, block in enumerate(_model.decoder.blocks):
-            x_norm = _accelerator.gather(block.attn.x_norm).mean()
-            q_norm = _accelerator.gather(block.attn.q_norm).mean()
-            k_norm = _accelerator.gather(block.attn.k_norm).mean()
-            v_norm = _accelerator.gather(block.attn.v_norm).mean()
-            out_norm = _accelerator.gather(block.attn.out_norm).mean()
-            logger.debug(f"{idx}.attn - x: {x_norm}, q: {q_norm}, k: {k_norm}, v: {v_norm}, out: {out_norm}")
-
-            x_norm = _accelerator.gather(block.cross_attn.x_norm).mean()
-            q_norm = _accelerator.gather(block.cross_attn.q_norm).mean()
-            k_norm = _accelerator.gather(block.cross_attn.k_norm).mean()
-            v_norm = _accelerator.gather(block.cross_attn.v_norm).mean()
-            out_norm = _accelerator.gather(block.cross_attn.out_norm).mean()
-            logger.debug(f"{idx}.cross_attn - x: {x_norm}, q: {q_norm}, k: {k_norm}, v: {v_norm}, out: {out_norm}")
 
     def get_max_norm(named_parameters):
         max_grad_norm = {"val": 0.0}