From c2e618049e9a2802f2c6465ac1d025c4e6e62dfa Mon Sep 17 00:00:00 2001
From: Louis
Date: Sat, 10 Feb 2024 15:05:17 +0000
Subject: [PATCH] Fix sampling truncation (#100)

* fix truncate bug

* fix formatting
---
 aria/run.py                 | 10 +++++-----
 aria/sample.py              |  2 +-
 aria/tokenizer/tokenizer.py |  4 ++--
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/aria/run.py b/aria/run.py
index 37797f2..df40de7 100644
--- a/aria/run.py
+++ b/aria/run.py
@@ -178,9 +178,6 @@ def sample(args):
     model_config.grad_checkpoint = False
     model = TransformerLM(model_config).to(device)

-    if args.trunc + args.l > model_config.max_seq_len:
-        print("WARNING - required context exceeds max_seq_len")
-
     try:
         model.load_state_dict(model_state)
     except:
@@ -251,9 +248,12 @@ def _quantize(module, key, input_shape):
         f"Instruments: {set([MidiDict.get_program_to_instrument()[msg['data']] for msg in midi_dict.instrument_msgs])}"
     )  # Not working with al.mid ?
     prompt_seq = tokenizer.tokenize(midi_dict=midi_dict)
-    prompt_seq = prompt_seq[:truncate_len]
-    print(prompt_seq[: prompt_seq.index(tokenizer.bos_tok)])
+    prompt_seq = prompt_seq[
+        : prompt_seq.index(tokenizer.bos_tok) + truncate_len + 1
+    ]
     prompts = [prompt_seq for _ in range(num_variations)]
+    if len(prompt_seq) + args.l > model_config.max_seq_len:
+        print("WARNING: required context exceeds max_seq_len supported by model")

     # Sample
     results = greedy_sample(
diff --git a/aria/sample.py b/aria/sample.py
index 8c63762..974d4f5 100644
--- a/aria/sample.py
+++ b/aria/sample.py
@@ -105,7 +105,7 @@ def greedy_sample(
     cfg_mode: str | None = None,
     neg_prompts: List[list] | None = None,
     neg_prompt_len: int | None = None,
-    alpha: float | None = 0.4,
+    alpha: float | None = None,
     force_end=False,
     temperature: float = 0.95,
     top_p: float = 0.95,
diff --git a/aria/tokenizer/tokenizer.py b/aria/tokenizer/tokenizer.py
index a6d1ea6..f687b56 100644
--- a/aria/tokenizer/tokenizer.py
+++ b/aria/tokenizer/tokenizer.py
@@ -63,7 +63,7 @@ def tokenize(self, midi_dict: MidiDict, **kwargs):

         This function should be overridden if additional transformations are
         required. For instance, in fine-tuning tokenizer you may want to insert
-        additional tokens. The default behaviour is to call tokenize_midi_dict.
+        additional tokens. The default behavior is to call tokenize_midi_dict.
         """
         return self._tokenize_midi_dict(midi_dict)

@@ -76,7 +76,7 @@ def detokenize(self, tokenized_seq: list):
         """Detokenizes a MidiDict object.

         This function should be overridden if additional are required during
-        detokenization. The default behaviour is to call detokenize_midi_dict.
+        detokenization. The default behavior is to call detokenize_midi_dict.
         """
         return self._detokenize_midi_dict(tokenized_seq)
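
---
A minimal sketch of what the new slice in aria/run.py changes, using a toy
token list in which "<S>" stands in for tokenizer.bos_tok; the names below
are illustrative only and not from the repo.

    # Toy prompt: two metadata tokens, the bos token, then four note tokens.
    prompt_seq = ["prefix", "instrument", "<S>", "n1", "n2", "n3", "n4"]
    truncate_len = 2

    # Old slice: the metadata prelude counted against truncate_len, so the
    # bos token and all note tokens could be cut from the prompt entirely.
    assert prompt_seq[:truncate_len] == ["prefix", "instrument"]

    # New slice: keep the whole prelude plus the bos token, then exactly
    # truncate_len tokens after it.
    bos_idx = prompt_seq.index("<S>")
    assert prompt_seq[: bos_idx + truncate_len + 1] == [
        "prefix", "instrument", "<S>", "n1", "n2",
    ]

Because the kept prompt is now longer than truncate_len, the max_seq_len
warning is checked against len(prompt_seq) after truncation rather than
against args.trunc before it.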