Skip to content

Commit

Permalink
Fix sampling truncation (#100)
Browse files Browse the repository at this point in the history
* fix truncate bug

* fix formatting
  • Loading branch information
loubbrad authored Feb 10, 2024
1 parent 7ac416c commit c2e6180
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
10 changes: 5 additions & 5 deletions aria/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,6 @@ def sample(args):
model_config.grad_checkpoint = False
model = TransformerLM(model_config).to(device)

if args.trunc + args.l > model_config.max_seq_len:
print("WARNING - required context exceeds max_seq_len")

try:
model.load_state_dict(model_state)
except:
Expand Down Expand Up @@ -251,9 +248,12 @@ def _quantize(module, key, input_shape):
f"Instruments: {set([MidiDict.get_program_to_instrument()[msg['data']] for msg in midi_dict.instrument_msgs])}"
) # Not working with al.mid ?
prompt_seq = tokenizer.tokenize(midi_dict=midi_dict)
prompt_seq = prompt_seq[:truncate_len]
print(prompt_seq[: prompt_seq.index(tokenizer.bos_tok)])
prompt_seq = prompt_seq[
: prompt_seq.index(tokenizer.bos_tok) + truncate_len + 1
]
prompts = [prompt_seq for _ in range(num_variations)]
if len(prompt_seq) + args.l > model_config.max_seq_len:
print("WARNING Required context exceeds max_seq_len supported by model")

# Sample
results = greedy_sample(
Expand Down
2 changes: 1 addition & 1 deletion aria/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def greedy_sample(
cfg_mode: str | None = None,
neg_prompts: List[list] | None = None,
neg_prompt_len: int | None = None,
alpha: float | None = 0.4,
alpha: float | None = None,
force_end=False,
temperature: float = 0.95,
top_p: float = 0.95,
Expand Down
4 changes: 2 additions & 2 deletions aria/tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def tokenize(self, midi_dict: MidiDict, **kwargs):
This function should be overridden if additional transformations are
required. For instance, in fine-tuning tokenizer you may want to insert
additional tokens. The default behaviour is to call tokenize_midi_dict.
additional tokens. The default behavior is to call tokenize_midi_dict.
"""
return self._tokenize_midi_dict(midi_dict)

Expand All @@ -76,7 +76,7 @@ def detokenize(self, tokenized_seq: list):
"""Detokenizes a MidiDict object.
This function should be overridden if additional are required during
detokenization. The default behaviour is to call detokenize_midi_dict.
detokenization. The default behavior is to call detokenize_midi_dict.
"""
return self._detokenize_midi_dict(tokenized_seq)

Expand Down

0 comments on commit c2e6180

Please sign in to comment.