From a7e305a7e8b8029ab1aab61a24bb8ba534a1a997 Mon Sep 17 00:00:00 2001
From: Louis
Date: Fri, 15 Dec 2023 14:42:19 +0000
Subject: [PATCH] Fix dim_tok in sampling (#82)

---
 aria/sample.py              | 18 ++++++++++++++----
 aria/tokenizer/tokenizer.py |  3 +++
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/aria/sample.py b/aria/sample.py
index e1fb713..8f88c46 100644
--- a/aria/sample.py
+++ b/aria/sample.py
@@ -16,6 +16,12 @@
 from aria.tokenizer import Tokenizer
 
 
+# TODO:
+# - Truncate if end token seen
+# - Fix the issue with onset tokens being (5000ms?)
+# - Fix the issue with dim tok being inserted at the wrong time
+
+
 def _get_cfg_coeff(cfg_gamma, cfg_mode, cur_pos, start_pos, total_len):
     if cfg_mode is None:
         return cfg_gamma
@@ -243,10 +249,9 @@ def greedy_sample(
         # Insert dim tokens
         if force_end and cur_pos >= total_len - 130:
             for _idx in range(tokens.size(0)):
-                if (
-                    dim_tok_inserted[_idx] is False
-                    and tokenizer.id_to_tok[next_token[_idx].item()][0] != "dur"
-                ):
+                if dim_tok_inserted[_idx] is False and tokenizer.id_to_tok[
+                    next_token[_idx].item()
+                ][0] not in ("dur", "onset"):
                     next_token[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok]
 
         # Update dim_tok_inserted
@@ -267,6 +272,11 @@ def greedy_sample(
             pass
         decoded.append(tokenizer.decode(seq))
 
+    for idx, seq in enumerate(decoded):
+        if tokenizer.eos_tok in seq:
+            eos_idx = seq.index(tokenizer.eos_tok)
+            decoded[idx] = seq[:eos_idx]
+
     return decoded
 
 
diff --git a/aria/tokenizer/tokenizer.py b/aria/tokenizer/tokenizer.py
index 990b264..82ee515 100644
--- a/aria/tokenizer/tokenizer.py
+++ b/aria/tokenizer/tokenizer.py
@@ -817,6 +817,9 @@ def _quantize_time(_n: int):
         curr_tgt_time_tok_cnt = tgt_time // abs_time_step
         curr_tgt_onset = _quantize_time(tgt_time % abs_time_step)
 
+        if curr_tgt_onset == abs_time_step:
+            curr_tgt_onset -= time_step
+
         for _ in range(
             curr_tgt_time_tok_cnt - prev_tgt_time_tok_cnt
         ):
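
Note (not part of the commit): below is a minimal, self-contained sketch of
the EOS-truncation step this patch adds to the end of greedy_sample. The stub
tokenizer, token values, and the helper name truncate_at_eos are hypothetical
illustrations, not the aria API.

    # Sketch only: mirrors the loop appended to greedy_sample. Each decoded
    # sequence is cut at the first EOS token; sequences without an EOS token
    # are left unchanged. StubTokenizer and "<E>" are placeholders.
    class StubTokenizer:
        eos_tok = "<E>"  # placeholder, not the real aria tokenizer's EOS token

    def truncate_at_eos(decoded: list, tokenizer) -> list:
        for idx, seq in enumerate(decoded):
            if tokenizer.eos_tok in seq:
                eos_idx = seq.index(tokenizer.eos_tok)
                decoded[idx] = seq[:eos_idx]
        return decoded

    if __name__ == "__main__":
        seqs = [
            ["piano", ("dur", 100), "<E>", ("dur", 200)],  # cut at "<E>"
            [("dur", 100), ("onset", 50)],  # no EOS token: unchanged
        ]
        print(truncate_at_eos(seqs, StubTokenizer()))
        # [['piano', ('dur', 100)], [('dur', 100), ('onset', 50)]]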