Skip to content

Commit

Permalink
Add option to not remove preceding silence in AbsTokenizer (#110)
Browse files Browse the repository at this point in the history
* fix reqs

* add support for preceding silence
  • Loading branch information
loubbrad authored May 14, 2024
1 parent 30f9985 commit 4e4e99e
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions aria/tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,20 @@ def tokenize(self, midi_dict: MidiDict, **kwargs):
required. For instance, in fine-tuning tokenizer you may want to insert
additional tokens. The default behavior is to call tokenize_midi_dict.
"""
return self._tokenize_midi_dict(midi_dict)
return self._tokenize_midi_dict(midi_dict, **kwargs)

def _detokenize_midi_dict(self, tokenized_seq: list):
"""Abstract method for de-tokenizing a sequence of tokens into a
MidiDict Object."""
raise NotImplementedError

def detokenize(self, tokenized_seq: list):
def detokenize(self, tokenized_seq: list, **kwargs):
"""Detokenizes a MidiDict object.
This function should be overridden if additional are required during
detokenization. The default behavior is to call detokenize_midi_dict.
"""
return self._detokenize_midi_dict(tokenized_seq)
return self._detokenize_midi_dict(tokenized_seq, **kwargs)

def export_data_aug(cls):
"""Abstract method for exporting a list of all data augmentation
Expand Down Expand Up @@ -411,7 +411,9 @@ def truncate_by_time(self, tokenized_seq: list, trunc_time_ms: int):

return tokenized_seq

def _tokenize_midi_dict(self, midi_dict: MidiDict):
def _tokenize_midi_dict(
self, midi_dict: MidiDict, remove_preceding_silence: bool = True
):
ticks_per_beat = midi_dict.ticks_per_beat
midi_dict.remove_instruments(self.config["ignore_instruments"])

Expand Down Expand Up @@ -450,9 +452,13 @@ def _tokenize_midi_dict(self, midi_dict: MidiDict):
prefix.insert(0, ("prefix", "genre", genre))
random.shuffle(prefix)

# NOTE: Any preceding silence is removed implicitly
tokenized_seq = []
initial_onset_tick = midi_dict.note_msgs[0]["data"]["start"]

if remove_preceding_silence is False:
initial_onset_tick = 0
else:
initial_onset_tick = midi_dict.note_msgs[0]["data"]["start"]

curr_time_since_onset = 0
for _, msg in enumerate(midi_dict.note_msgs):
# Extract msg data
Expand Down Expand Up @@ -543,12 +549,14 @@ def _detokenize_midi_dict(self, tokenized_seq: list):

# Add non-drum instrument_msgs, breaks at first note token
channel_idx = 0
curr_tick = 0
for idx, tok in enumerate(tokenized_seq):
if channel_idx == 9: # Skip channel reserved for drums
channel_idx += 1

if tok in self.special_tokens:
# Skip special tokens
if tok == self.time_tok:
curr_tick += self.abs_time_step
continue
elif (
tok[0] == "prefix"
Expand Down Expand Up @@ -590,7 +598,6 @@ def _detokenize_midi_dict(self, tokenized_seq: list):

# Note messages
note_msgs = []
curr_tick = 0
for tok_1, tok_2, tok_3 in zip(
tokenized_seq[start:],
tokenized_seq[start + 1 :],
Expand Down

0 comments on commit 4e4e99e

Please sign in to comment.