Skip to content

Commit

Permalink
Add instrument check to detokenize (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad authored Dec 22, 2023
1 parent 3667e27 commit 00d9c85
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 20 deletions.
2 changes: 1 addition & 1 deletion aria/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ def _build(_midi_dataset):
_build(_midi_dataset=midi_dataset)

logger.info(
f"Finished building, saved PretrainingDataset to {save_path}"
f"Finished building, saved Finetuning to {save_path}"
)

return cls(file_path=save_path, tokenizer=tokenizer)
6 changes: 1 addition & 5 deletions aria/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@
from aria.tokenizer import Tokenizer


# TODO:
# - Truncate if end token seen
# - Fix the issue with onset tokens being <U> (5000ms?)
# - Fix the issue with dim tok being inserted at the wrong time

# TODO: Add which instruments were detected in the prompt

def _get_cfg_coeff(cfg_gamma, cfg_mode, cur_pos, start_pos, total_len):
if cfg_mode is None:
Expand Down
33 changes: 19 additions & 14 deletions aria/tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,22 +564,27 @@ def detokenize_midi_dict(self, tokenized_seq: list):
_start_tick = curr_tick + tok_2[1]
_end_tick = _start_tick + self.time_step
_pitch = tok_1[1]
_channel = instrument_to_channel["drum"]
_channel = instrument_to_channel.get(tok_1[0], None)
_velocity = self.config["drum_velocity"]

note_msgs.append(
{
"type": "note",
"data": {
"pitch": _pitch,
"start": _start_tick,
"end": _end_tick,
"velocity": _velocity,
},
"tick": _start_tick,
"channel": _channel,
}
)
if _channel is None:
logging.warning(
"Tried to decode note message for unexpected instrument"
)
else:
note_msgs.append(
{
"type": "note",
"data": {
"pitch": _pitch,
"start": _start_tick,
"end": _end_tick,
"velocity": _velocity,
},
"tick": _start_tick,
"channel": _channel,
}
)

elif (
_tok_type_1 in self.instruments_nd
Expand Down

0 comments on commit 00d9c85

Please sign in to comment.