Skip to content

Commit

Permalink
Add resolve_pedal method to MidiDict (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad authored Feb 22, 2024
1 parent a8989a5 commit 56b66ea
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 0 deletions.
80 changes: 80 additions & 0 deletions aria/data/midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import logging
import pathlib
import copy
import mido

from collections import defaultdict
Expand Down Expand Up @@ -188,6 +189,85 @@ def calculate_hash(self):
json.dumps(msg_dict_to_hash, sort_keys=True).encode()
).hexdigest()

def _build_pedal_intervals(self):
"""Returns pedal-on intervals for each channel."""
self.pedal_msgs.sort(key=lambda msg: msg["tick"])
channel_to_pedal_intervals = defaultdict(list)
pedal_status = {}

for pedal_msg in self.pedal_msgs:
tick = pedal_msg["tick"]
channel = pedal_msg["channel"]
data = pedal_msg["data"]

if data == 1 and pedal_status.get(channel, None) is None:
pedal_status[channel] = tick
elif data == 0 and pedal_status.get(channel, None) is not None:
# Close pedal interval
_start_tick = pedal_status[channel]
_end_tick = tick
channel_to_pedal_intervals[channel].append(
[_start_tick, _end_tick]
)
del pedal_status[channel]

# Close all unclosed pedals at end of track
final_tick = self.note_msgs[-1]["data"]["end"]
for channel, start_tick in pedal_status.items():
channel_to_pedal_intervals[channel].append([start_tick, final_tick])

return channel_to_pedal_intervals

def _resolve_overlaps(self):
"""Resolves overlaps between notes on the same channel, this will only
change anything if pedal intervals have been resolved."""

# Organize notes by channel and pitch
note_msgs_c = defaultdict(lambda: defaultdict(list))
for msg in self.note_msgs:
_channel = msg["channel"]
_pitch = msg["data"]["pitch"]
note_msgs_c[_channel][_pitch].append(msg)

# We can modify notes by reference as they are dictionaries
for channel, msgs_by_pitch in note_msgs_c.items():
for pitch, msgs in msgs_by_pitch.items():
msgs.sort(
key=lambda msg: (msg["data"]["start"], msg["data"]["end"])
)
prev_off_tick = -1
for idx, msg in enumerate(msgs):
on_tick = msg["data"]["start"]
off_tick = msg["data"]["end"]
if prev_off_tick > on_tick:
# Adjust end of previous (idx - 1) msg to remove overlap
msgs[idx - 1]["data"]["end"] = on_tick
prev_off_tick = off_tick

return self

def resolve_pedal(self):
"""Remove pedal_msgs by extending note offsets and resolving overlaps"""
# Organize note messages by channel
note_msgs_c = defaultdict(list)
for msg in self.note_msgs:
_channel = msg["channel"]
note_msgs_c[_channel].append(msg)

# We can modify notes by reference as they are dictionaries
channel_to_pedal_intervals = self._build_pedal_intervals()
for channel, msgs in note_msgs_c.items():
for msg in msgs:
note_end_tick = msg["data"]["end"]
for pedal_interval in channel_to_pedal_intervals[channel]:
pedal_start, pedal_end = pedal_interval
if pedal_start < note_end_tick < pedal_end:
msg["data"]["end"] = pedal_end
break

self.pedal_msgs = []
return self._resolve_overlaps()

# TODO:
# - Add remove drums (aka remove channel 9) pre-processing
# - Add similar method for removing specific programs
Expand Down
6 changes: 6 additions & 0 deletions aria/tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
from aria.config import load_config
from aria.data.midi import get_duration_ms

# TODO:
# - Add a warning when tokenizing degenerate MIDI files. e.g. if we have two
# pianos with overlapping notes, this could potentially cause an issue for
# some tokenizers. I'm not sure if this will cause an issue AbsTokenizer
# however it might with some others such as the AmtTokenizer.


class Tokenizer:
"""Abstract Tokenizer class for tokenizing midi_dict objects.
Expand Down
9 changes: 9 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ def get_short_seq():
]


class TestMidiDict(unittest.TestCase):
def test_resolve_pedal(self):
midi_dict = MidiDict.from_midi("tests/test_data/maestro.mid")
midi_dict.resolve_pedal()
self.assertListEqual(midi_dict.pedal_msgs, [])
mid = midi_dict.to_midi()
mid.save("tests/test_results/maestro_npedal.mid")


class TestMidiDataset(unittest.TestCase):
def test_build(self):
dataset = datasets.MidiDataset.build(
Expand Down

0 comments on commit 56b66ea

Please sign in to comment.