From 56b66eacfef72081ea0eaca8184fab4792b54d19 Mon Sep 17 00:00:00 2001
From: Louis
Date: Thu, 22 Feb 2024 22:41:54 +0000
Subject: [PATCH] Add resolve_pedal method to MidiDict (#104)

---
Review notes (ignored when the patch is applied):
- Line structure of this patch was restored; hunk line counts are unchanged.
- _build_pedal_intervals now closes unclosed pedals at the latest note end
  (0 when there are no notes) instead of the end of the last entry in
  note_msgs, so an empty note list no longer raises IndexError.
- Wording of the TODO comment in tokenizer.py was corrected.
- The blob hashes on the "index" lines are those of the original patch.

 aria/data/midi.py           | 80 +++++++++++++++++++++++++++++++++++++
 aria/tokenizer/tokenizer.py |  6 +++
 tests/test_data.py          |  9 +++++
 3 files changed, 95 insertions(+)

diff --git a/aria/data/midi.py b/aria/data/midi.py
index 6bd2a9b..ef4c7b6 100644
--- a/aria/data/midi.py
+++ b/aria/data/midi.py
@@ -6,6 +6,7 @@
 import os
 import logging
 import pathlib
+import copy
 import mido
 
 from collections import defaultdict
@@ -188,6 +189,85 @@ def calculate_hash(self):
             json.dumps(msg_dict_to_hash, sort_keys=True).encode()
         ).hexdigest()
 
+    def _build_pedal_intervals(self):
+        """Returns pedal-on intervals for each channel."""
+        self.pedal_msgs.sort(key=lambda msg: msg["tick"])
+        channel_to_pedal_intervals = defaultdict(list)
+        pedal_status = {}
+
+        for pedal_msg in self.pedal_msgs:
+            tick = pedal_msg["tick"]
+            channel = pedal_msg["channel"]
+            data = pedal_msg["data"]
+
+            if data == 1 and pedal_status.get(channel, None) is None:
+                pedal_status[channel] = tick
+            elif data == 0 and pedal_status.get(channel, None) is not None:
+                # Close pedal interval
+                _start_tick = pedal_status[channel]
+                _end_tick = tick
+                channel_to_pedal_intervals[channel].append(
+                    [_start_tick, _end_tick]
+                )
+                del pedal_status[channel]
+
+        # Close all unclosed pedals at end of track
+        final_tick = max((m["data"]["end"] for m in self.note_msgs), default=0)
+        for channel, start_tick in pedal_status.items():
+            channel_to_pedal_intervals[channel].append([start_tick, final_tick])
+
+        return channel_to_pedal_intervals
+
+    def _resolve_overlaps(self):
+        """Resolves overlaps between notes on the same channel, this will only
+        change anything if pedal intervals have been resolved."""
+
+        # Organize notes by channel and pitch
+        note_msgs_c = defaultdict(lambda: defaultdict(list))
+        for msg in self.note_msgs:
+            _channel = msg["channel"]
+            _pitch = msg["data"]["pitch"]
+            note_msgs_c[_channel][_pitch].append(msg)
+
+        # We can modify notes by reference as they are dictionaries
+        for channel, msgs_by_pitch in note_msgs_c.items():
+            for pitch, msgs in msgs_by_pitch.items():
+                msgs.sort(
+                    key=lambda msg: (msg["data"]["start"], msg["data"]["end"])
+                )
+                prev_off_tick = -1
+                for idx, msg in enumerate(msgs):
+                    on_tick = msg["data"]["start"]
+                    off_tick = msg["data"]["end"]
+                    if prev_off_tick > on_tick:
+                        # Adjust end of previous (idx - 1) msg to remove overlap
+                        msgs[idx - 1]["data"]["end"] = on_tick
+                    prev_off_tick = off_tick
+
+        return self
+
+    def resolve_pedal(self):
+        """Remove pedal_msgs by extending note offsets and resolving overlaps"""
+        # Organize note messages by channel
+        note_msgs_c = defaultdict(list)
+        for msg in self.note_msgs:
+            _channel = msg["channel"]
+            note_msgs_c[_channel].append(msg)
+
+        # We can modify notes by reference as they are dictionaries
+        channel_to_pedal_intervals = self._build_pedal_intervals()
+        for channel, msgs in note_msgs_c.items():
+            for msg in msgs:
+                note_end_tick = msg["data"]["end"]
+                for pedal_interval in channel_to_pedal_intervals[channel]:
+                    pedal_start, pedal_end = pedal_interval
+                    if pedal_start < note_end_tick < pedal_end:
+                        msg["data"]["end"] = pedal_end
+                        break
+
+        self.pedal_msgs = []
+        return self._resolve_overlaps()
+
     # TODO:
     # - Add remove drums (aka remove channel 9) pre-processing
     # - Add similar method for removing specific programs
diff --git a/aria/tokenizer/tokenizer.py b/aria/tokenizer/tokenizer.py
index 71a57e9..6fe9000 100644
--- a/aria/tokenizer/tokenizer.py
+++ b/aria/tokenizer/tokenizer.py
@@ -14,6 +14,12 @@
 from aria.config import load_config
 from aria.data.midi import get_duration_ms
 
+# TODO:
+# - Add a warning when tokenizing degenerate MIDI files. e.g. if we have two
+# pianos with overlapping notes, this could potentially cause an issue for
+# some tokenizers. I'm not sure if this will cause an issue for AbsTokenizer,
+# however it might with some others such as the AmtTokenizer.
+
 
 class Tokenizer:
     """Abstract Tokenizer class for tokenizing midi_dict objects.
diff --git a/tests/test_data.py b/tests/test_data.py
index 7ab9f62..b8b059e 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -45,6 +45,15 @@ def get_short_seq():
     ]
 
 
+class TestMidiDict(unittest.TestCase):
+    def test_resolve_pedal(self):
+        midi_dict = MidiDict.from_midi("tests/test_data/maestro.mid")
+        midi_dict.resolve_pedal()
+        self.assertListEqual(midi_dict.pedal_msgs, [])
+        mid = midi_dict.to_midi()
+        mid.save("tests/test_results/maestro_npedal.mid")
+
+
 class TestMidiDataset(unittest.TestCase):
     def test_build(self):
         dataset = datasets.MidiDataset.build(