Adjust output layer logic (#29)
* backup

* remove debug

* reduce time cutoff
loubbrad authored Apr 24, 2024
1 parent 0394a05 commit 3091aa9
Showing 4 changed files with 41 additions and 34 deletions.
9 changes: 5 additions & 4 deletions amt/audio.py
@@ -69,7 +69,7 @@ def __init__(
reduce_ratio: float = 0.01,
detune_ratio: float = 0.1,
detune_max_shift: float = 0.15,
spec_aug_ratio: float = 0.95,
spec_aug_ratio: float = 0.9,
):
super().__init__()
self.tokenizer = AmtTokenizer()
@@ -135,12 +135,13 @@ def __init__(
n_stft=self.config["n_fft"] // 2 + 1,
)
self.spec_aug = torch.nn.Sequential(
torchaudio.transforms.TimeMasking(
time_mask_param=self.time_mask_param,
iid_masks=True,
),
torchaudio.transforms.FrequencyMasking(
freq_mask_param=self.freq_mask_param, iid_masks=True
),
torchaudio.transforms.TimeMasking(
time_mask_param=self.time_mask_param, iid_masks=True
),
)

def get_params(self):
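For context, a minimal sketch of the SpecAugment pipeline after this change: FrequencyMasking is now applied before TimeMasking, and the augmentation probability drops from 0.95 to 0.9. Parameter values and the surrounding wiring are assumptions for illustration, not the repo's exact AudioTransform.

```python
import torch
import torchaudio

# Illustrative values (assumed); the repo derives these from its config.
n_fft = 2048
freq_mask_param = 15
time_mask_param = 100
spec_aug_ratio = 0.9  # fraction of batches that receive SpecAugment

spec_aug = torch.nn.Sequential(
    torchaudio.transforms.FrequencyMasking(
        freq_mask_param=freq_mask_param, iid_masks=True
    ),
    torchaudio.transforms.TimeMasking(
        time_mask_param=time_mask_param, iid_masks=True
    ),
)

# iid_masks=True draws a different mask per example, so the input is 4D:
# (batch, channel, freq, time).
specs = torch.randn(8, 1, n_fft // 2 + 1, 400)
if torch.rand(1).item() < spec_aug_ratio:
    specs = spec_aug(specs)
```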
8 changes: 4 additions & 4 deletions amt/data.py
@@ -88,9 +88,9 @@ def get_wav_mid_segments(
max_pedal_len_ms=15000,
)

# Hardcoded to 2.5s
if _check_onset_threshold(mid_feature, 2500) is False:
print("No note messages after 2.5s - skipping")
# Hardcoded to 5s
if _check_onset_threshold(mid_feature, 5000) is False:
print("No note messages after 5s - skipping")
continue

else:
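The guard above raises the onset cutoff from 2.5s to 5s. `_check_onset_threshold` is not shown in this diff; the sketch below is a hypothetical illustration of what such a check could look like, assuming the tokenized MIDI feature contains ("onset", time_ms) tuples. The real helper in amt/data.py may differ.

```python
def _check_onset_threshold(seq: list, onset: int) -> bool:
    """Hypothetical sketch: return True if any note onset occurs at or
    after `onset` milliseconds, False otherwise (segment gets skipped)."""
    for tok in seq:
        # Assumed token layout: ("onset", time_in_ms).
        if isinstance(tok, tuple) and len(tok) == 2 and tok[0] == "onset":
            if tok[1] >= onset:
                return True
    return False
```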
@@ -149,7 +149,7 @@ def pianoteq_cmd_fn(mid_path: str, wav_path: str):
safe_mid_path = shlex.quote(mid_path)
safe_wav_path = shlex.quote(wav_path)

executable_path = "/home/loubb/pianoteq/x86-64bit/Pianoteq 8 STAGE"
executable_path = "/mnt/ssd-1/aria/pianoteq/x86-64bit/Pianoteq 8 STAGE"
command = f'"{executable_path}" --preset {safe_preset} --midi {safe_mid_path} --wav {safe_wav_path}'

return command
14 changes: 5 additions & 9 deletions amt/model.py
@@ -98,7 +98,7 @@ def forward(
wv = wv.transpose(1, 2)
wv = wv.view(batch_size, target_seq_len, self.n_head * self.d_head)

return self.out(wv), None
return self.out(wv)


class ResidualAttentionBlock(nn.Module):
@@ -129,9 +129,9 @@ def forward(
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
):
x = x + self.attn(self.attn_ln(x), mask=mask)[0]
x = x + self.attn(self.attn_ln(x), mask=mask)
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa)[0]
x = x + self.cross_attn(self.cross_attn_ln(x), xa)
x = x + self.mlp(self.mlp_ln(x))
return x
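The two hunks above follow from MultiHeadAttention.forward now returning a single tensor rather than an (output, qk) tuple, so the residual block drops the `[0]` indexing. A self-contained toy sketch of that wiring (module internals simplified, not the repo's attention implementation):

```python
from typing import Optional

import torch
from torch import Tensor, nn


class ToyAttention(nn.Module):
    """Stand-in attention module that returns a single tensor."""

    def __init__(self, n_state: int):
        super().__init__()
        self.proj = nn.Linear(n_state, n_state)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        return self.proj(x)  # no (output, qk) tuple anymore


class ToyResidualBlock(nn.Module):
    def __init__(self, n_state: int):
        super().__init__()
        self.attn_ln = nn.LayerNorm(n_state)
        self.attn = ToyAttention(n_state)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        # No [0] indexing needed now that attn returns a tensor directly.
        return x + self.attn(self.attn_ln(x), mask=mask)
```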

@@ -188,6 +188,7 @@ def __init__(
]
)
self.ln = nn.LayerNorm(n_state)
self.output = nn.Linear(n_state, n_vocab, bias=False)

mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
@@ -206,9 +207,7 @@ def forward(self, x: Tensor, xa: Tensor):
x = block(x, xa, mask=self.mask)

x = self.ln(x)
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
logits = self.output(x)

return logits

@@ -245,6 +244,3 @@ def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
@property
def device(self):
return next(self.parameters()).device

def get_empty_cache(self):
return {}
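The substantive change behind the commit title: logits were previously computed by reusing the (tied) token-embedding matrix, and are now produced by a dedicated, untied `nn.Linear(n_state, n_vocab, bias=False)` head. A minimal sketch of the two projections with assumed sizes; untying costs an extra n_state × n_vocab parameters but lets the input embedding and output projection train independently.

```python
import torch
import torch.nn as nn

n_state, n_vocab = 512, 4096  # assumed sizes for illustration

token_embedding = nn.Embedding(n_vocab, n_state)
output = nn.Linear(n_state, n_vocab, bias=False)  # new untied head

x = torch.randn(2, 16, n_state)  # (batch, seq, n_state) decoder states

# Old: weight-tied projection through the embedding matrix.
logits_tied = (
    x @ torch.transpose(token_embedding.weight.to(x.dtype), 0, 1)
).float()

# New: dedicated output layer.
logits = output(x)

assert logits_tied.shape == logits.shape == (2, 16, n_vocab)
```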
44 changes: 27 additions & 17 deletions amt/train.py
@@ -2,6 +2,7 @@
import sys
import csv
import random
import traceback
import functools
import argparse
import logging
@@ -24,7 +25,7 @@
from amt.config import load_model_config
from aria.utils import _load_weight

GRADIENT_ACC_STEPS = 32
GRADIENT_ACC_STEPS = 2

# ----- USAGE -----
#
@@ -143,7 +144,7 @@ def _get_optim(
model.parameters(),
lr=lr,
weight_decay=0.1,
betas=(0.9, 0.98),
betas=(0.9, 0.95),
eps=1e-6,
)

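For reference, a sketch of the optimizer configuration after this change (the second beta drops from 0.98 to 0.95); `model` and `lr` here are placeholders, not the repo's values.

```python
import torch

model = torch.nn.Linear(8, 8)  # placeholder module
lr = 3e-4                      # assumed learning rate

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    weight_decay=0.1,
    betas=(0.9, 0.95),
    eps=1e-6,
)
```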
@@ -344,6 +345,7 @@ def train_loop(
lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"])

model.train()
grad_norm = 0.0
for __step, batch in (
pbar := tqdm(
enumerate(dataloader),
@@ -378,8 +380,6 @@
grad_norm = accelerator.clip_grad_norm_(
model.parameters(), 1.0
).item()
else:
grad_norm = 0
optimizer.step()
optimizer.zero_grad()

@@ -398,7 +398,8 @@
pbar.set_postfix_str(
f"lr={lr_for_print}, "
f"loss={round(loss_buffer[-1], 4)}, "
f"trailing={round(trailing_loss, 4)}"
f"trailing={round(trailing_loss, 4)}, "
f"grad_norm={round(grad_norm, 4)}"
)

if scheduler:
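A minimal sketch of the accumulation and clipping pattern these hunks touch, assuming the usual Accelerate setup (`accelerator`, `model`, `optimizer`, and the dataloader prepared elsewhere); `grad_norm` is initialized before the loop so the progress-bar postfix can report it even on steps where no clipping has happened yet. This is not the repo's full train_loop.

```python
import torch
from accelerate import Accelerator

GRADIENT_ACC_STEPS = 2  # lowered from 32 in this commit


def train_steps(model, dataloader, optimizer, loss_fn, accelerator: Accelerator):
    model.train()
    grad_norm = 0.0  # defined up front so it can always be reported
    for step, (mel, tokens, targets) in enumerate(dataloader):
        logits = model(mel, tokens)
        # CrossEntropyLoss expects (batch, vocab, seq) for sequence targets.
        loss = loss_fn(logits.transpose(1, 2), targets) / GRADIENT_ACC_STEPS
        accelerator.backward(loss)

        # Step, clip, and reset only once every GRADIENT_ACC_STEPS micro-batches.
        if (step + 1) % GRADIENT_ACC_STEPS == 0:
            grad_norm = accelerator.clip_grad_norm_(
                model.parameters(), 1.0
            ).item()
            optimizer.step()
            optimizer.zero_grad()

        print(f"loss={loss.item() * GRADIENT_ACC_STEPS:.4f}, grad_norm={grad_norm:.4f}")
```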
@@ -470,6 +471,7 @@ def val_loop(dataloader, _epoch: int, aug: bool):
PAD_ID = train_dataloader.dataset.tokenizer.pad_id
logger = get_logger(__name__) # Accelerate logger
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

logger.info(
f"Model has "
f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} "
@@ -522,19 +524,27 @@ def val_loop(dataloader, _epoch: int, aug: bool):
)

for epoch in range(start_epoch, epochs + start_epoch):
avg_train_loss = train_loop(dataloader=train_dataloader, _epoch=epoch)
avg_val_loss = val_loop(
dataloader=val_dataloader, _epoch=epoch, aug=False
)
avg_val_loss_aug = val_loop(
dataloader=val_dataloader, _epoch=epoch, aug=True
)
if accelerator.is_main_process:
epoch_writer.writerow(
[epoch, avg_train_loss, avg_val_loss, avg_val_loss_aug]
try:
avg_train_loss = train_loop(
dataloader=train_dataloader, _epoch=epoch
)
epoch_csv.flush()
make_checkpoint(_accelerator=accelerator, _epoch=epoch + 1, _step=0)
avg_val_loss = val_loop(
dataloader=val_dataloader, _epoch=epoch, aug=False
)
avg_val_loss_aug = val_loop(
dataloader=val_dataloader, _epoch=epoch, aug=True
)
if accelerator.is_main_process:
epoch_writer.writerow(
[epoch, avg_train_loss, avg_val_loss, avg_val_loss_aug]
)
epoch_csv.flush()
make_checkpoint(
_accelerator=accelerator, _epoch=epoch + 1, _step=0
)
except Exception as e:
logger.debug(traceback.format_exc())
raise e

logging.shutdown()
if accelerator.is_main_process:
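The epoch loop is now wrapped so that any failure is logged with a full traceback before being re-raised; a small sketch of that pattern with stand-in names:

```python
import logging
import traceback

logger = logging.getLogger(__name__)


def run_one_epoch(epoch: int) -> None:
    """Stand-in for the per-epoch train/val/checkpoint work."""
    pass


def run_epochs(start_epoch: int, epochs: int) -> None:
    for epoch in range(start_epoch, epochs + start_epoch):
        try:
            run_one_epoch(epoch)
        except Exception as e:
            # Record the full traceback before propagating the error.
            logger.debug(traceback.format_exc())
            raise e
```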