Skip to content

Commit

Permalink
Fix dim_tok in sampling (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad authored Dec 15, 2023
1 parent d7f583d commit a7e305a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
18 changes: 14 additions & 4 deletions aria/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
from aria.tokenizer import Tokenizer


# TODO:
# - Truncate if end token seen
# - Fix the issue with onset tokens being <U> (5000ms?)
# - Fix the issue with dim tok being inserted at the wrong time


def _get_cfg_coeff(cfg_gamma, cfg_mode, cur_pos, start_pos, total_len):
if cfg_mode is None:
return cfg_gamma
Expand Down Expand Up @@ -243,10 +249,9 @@ def greedy_sample(
# Insert dim tokens
if force_end and cur_pos >= total_len - 130:
for _idx in range(tokens.size(0)):
if (
dim_tok_inserted[_idx] is False
and tokenizer.id_to_tok[next_token[_idx].item()][0] != "dur"
):
if dim_tok_inserted[_idx] is False and tokenizer.id_to_tok[
next_token[_idx].item()
][0] not in ("dur", "onset"):
next_token[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok]

# Update dim_tok_inserted
Expand All @@ -267,6 +272,11 @@ def greedy_sample(
pass
decoded.append(tokenizer.decode(seq))

for idx, seq in enumerate(decoded):
if tokenizer.eos_tok in seq:
eos_idx = seq.index(tokenizer.eos_tok)
decoded[idx] = seq[:eos_idx]

return decoded


Expand Down
3 changes: 3 additions & 0 deletions aria/tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,9 @@ def _quantize_time(_n: int):
curr_tgt_time_tok_cnt = tgt_time // abs_time_step
curr_tgt_onset = _quantize_time(tgt_time % abs_time_step)

if curr_tgt_onset == abs_time_step:
curr_tgt_onset -= time_step

for _ in range(
curr_tgt_time_tok_cnt - prev_tgt_time_tok_cnt
):
Expand Down

0 comments on commit a7e305a

Please sign in to comment.