
Commit

remove debug
loubbrad committed Apr 24, 2024
1 parent c22624e commit ada0673
Showing 2 changed files with 1 addition and 34 deletions.
amt/model.py: 19 changes (1 addition, 18 deletions)
@@ -54,12 +54,6 @@ def __init__(self, n_state: int, n_head: int):
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state, bias=False)
self.out = nn.Linear(n_state, n_state, bias=False)

# self.x_norm = None
# self.q_norm = None
# self.k_norm = None
# self.v_norm = None
# self.out_norm = None

def forward(
self,
@@ -84,11 +78,6 @@ def forward(
q = q.view(batch_size, target_seq_len, self.n_head, self.d_head)
k = k.view(batch_size, source_seq_len, self.n_head, self.d_head)
v = v.view(batch_size, source_seq_len, self.n_head, self.d_head)

# self.x_norm = torch.norm(x, dim=-1).mean()
# self.q_norm = torch.norm(q, dim=-1).mean()
# self.k_norm = torch.norm(k, dim=-1).mean()
# self.v_norm = torch.norm(v, dim=-1).mean()

# (bz, L, nh, dh) -> (bz, nh, L, dh)
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
@@ -104,8 +93,6 @@ def forward(
value=v,
is_causal=_is_causal,
)

# self.out_norm = torch.norm(wv, dim=-1).mean()

# (bz, nh, L, dh) -> (bz, L, nh, dh) -> (bz, L, d)
wv = wv.transpose(1, 2)
@@ -222,10 +209,6 @@ def forward(self, x: Tensor, xa: Tensor):
x = self.ln(x)
logits = self.output(x)

# logits = (
# x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
# ).float()

return logits
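
Not part of the diff: the comment block deleted above appears to be an alternative, weight-tied output projection, reusing the decoder's token embedding matrix instead of the dedicated self.output linear layer. A minimal sketch of that tying pattern (sizes are hypothetical, not taken from this repository):

    import torch
    import torch.nn as nn

    n_state, n_vocab = 512, 4096                 # hypothetical sizes
    token_embedding = nn.Embedding(n_vocab, n_state)
    x = torch.randn(2, 16, n_state)              # (batch, seq, n_state)
    # Tied projection: reuse the embedding weights, transposed to (n_state, n_vocab).
    logits = (x @ token_embedding.weight.to(x.dtype).T).float()  # (batch, seq, n_vocab)
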


@@ -260,4 +243,4 @@ def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:

@property
def device(self):
-        return next(self.parameters()).device
+        return next(self.parameters()).device
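
Not part of the diff: the x_norm, q_norm, k_norm, v_norm, and out_norm attributes deleted above were debug instrumentation that stored the mean L2 norm of the attention input, the query/key/value projections, and the attention output on the module, so that the training loop (see amt/train.py below) could read them back after a forward pass. A minimal sketch of that pattern, assuming the tensors from the attention forward shown above; the helper name is hypothetical:

    import torch

    def record_activation_norms(attn_module, x, q, k, v, wv):
        # Mean L2 norm over the feature dimension, averaged over the remaining dims.
        attn_module.x_norm = torch.norm(x, dim=-1).mean()
        attn_module.q_norm = torch.norm(q, dim=-1).mean()
        attn_module.k_norm = torch.norm(k, dim=-1).mean()
        attn_module.v_norm = torch.norm(v, dim=-1).mean()
        attn_module.out_norm = torch.norm(wv, dim=-1).mean()
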
amt/train.py: 16 changes (0 additions, 16 deletions)
@@ -313,22 +313,6 @@ def make_checkpoint(_accelerator, _epoch: int, _step: int):
f"EPOCH {_epoch}/{epochs + start_epoch}: Saving checkpoint - {checkpoint_dir}"
)
_accelerator.save_state(checkpoint_dir)

def log_activation_norms(_model: AmtEncoderDecoder, _accelerator: accelerate.Accelerator):
for idx, block in enumerate(_model.decoder.blocks):
x_norm = _accelerator.gather(block.attn.x_norm).mean()
q_norm = _accelerator.gather(block.attn.q_norm).mean()
k_norm = _accelerator.gather(block.attn.k_norm).mean()
v_norm = _accelerator.gather(block.attn.v_norm).mean()
out_norm = _accelerator.gather(block.attn.out_norm).mean()
logger.debug(f"{idx}.attn - x: {x_norm}, q: {q_norm}, k: {k_norm}, v: {v_norm}, out: {out_norm}")

x_norm = _accelerator.gather(block.cross_attn.x_norm).mean()
q_norm = _accelerator.gather(block.cross_attn.q_norm).mean()
k_norm = _accelerator.gather(block.cross_attn.k_norm).mean()
v_norm = _accelerator.gather(block.cross_attn.v_norm).mean()
out_norm = _accelerator.gather(block.cross_attn.out_norm).mean()
logger.debug(f"{idx}.cross_attn - x: {x_norm}, q: {q_norm}, k: {k_norm}, v: {v_norm}, out: {out_norm}")

def get_max_norm(named_parameters):
max_grad_norm = {"val": 0.0}
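
Not part of the diff: the deleted log_activation_norms helper read those debug attributes from every decoder block and averaged them across processes with accelerator.gather before logging. With the attributes gone, a roughly similar effect can be had without touching the model code, for example with forward hooks. The sketch below is an assumption, not code from this repository, and it omits the cross-process gather that the deleted helper performed; model.decoder.blocks, block.attn, block.cross_attn, and logger follow the names used above:

    import torch

    def attach_norm_logging(model, logger):
        # Log the mean L2 norm of each decoder attention block's output via forward hooks.
        def make_hook(tag):
            def hook(module, inputs, output):
                out = output[0] if isinstance(output, tuple) else output
                logger.debug(f"{tag} - out: {torch.norm(out, dim=-1).mean().item():.3f}")
            return hook

        handles = []
        for idx, block in enumerate(model.decoder.blocks):
            handles.append(block.attn.register_forward_hook(make_hook(f"{idx}.attn")))
            handles.append(block.cross_attn.register_forward_hook(make_hook(f"{idx}.cross_attn")))
        return handles  # call handle.remove() on each to detach the hooks
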
