From 00d9c85b95cf8d1b19bea8c0152cc86d46669aa0 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 22 Dec 2023 13:54:42 +0000 Subject: [PATCH] Add instrument check to detokenize (#89) --- aria/data/datasets.py | 2 +- aria/sample.py | 6 +----- aria/tokenizer/tokenizer.py | 33 +++++++++++++++++++-------------- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/aria/data/datasets.py b/aria/data/datasets.py index 508aca0..b9da67c 100644 --- a/aria/data/datasets.py +++ b/aria/data/datasets.py @@ -833,7 +833,7 @@ def _build(_midi_dataset): _build(_midi_dataset=midi_dataset) logger.info( - f"Finished building, saved PretrainingDataset to {save_path}" + f"Finished building, saved Finetuning to {save_path}" ) return cls(file_path=save_path, tokenizer=tokenizer) diff --git a/aria/sample.py b/aria/sample.py index 8f88c46..7e40fd7 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -16,11 +16,7 @@ from aria.tokenizer import Tokenizer -# TODO: -# - Truncate if end token seen -# - Fix the issue with onset tokens being (5000ms?) -# - Fix the issue with dim tok being inserted at the wrong time - +# TODO: Add which instruments were detected in the prompt def _get_cfg_coeff(cfg_gamma, cfg_mode, cur_pos, start_pos, total_len): if cfg_mode is None: diff --git a/aria/tokenizer/tokenizer.py b/aria/tokenizer/tokenizer.py index 82ee515..5fecc66 100644 --- a/aria/tokenizer/tokenizer.py +++ b/aria/tokenizer/tokenizer.py @@ -564,22 +564,27 @@ def detokenize_midi_dict(self, tokenized_seq: list): _start_tick = curr_tick + tok_2[1] _end_tick = _start_tick + self.time_step _pitch = tok_1[1] - _channel = instrument_to_channel["drum"] + _channel = instrument_to_channel.get(tok_1[0], None) _velocity = self.config["drum_velocity"] - note_msgs.append( - { - "type": "note", - "data": { - "pitch": _pitch, - "start": _start_tick, - "end": _end_tick, - "velocity": _velocity, - }, - "tick": _start_tick, - "channel": _channel, - } - ) + if _channel is None: + logging.warning( + "Tried to decode note message for unexpected instrument" + ) + else: + note_msgs.append( + { + "type": "note", + "data": { + "pitch": _pitch, + "start": _start_tick, + "end": _end_tick, + "velocity": _velocity, + }, + "tick": _start_tick, + "channel": _channel, + } + ) elif ( _tok_type_1 in self.instruments_nd