From 48a2efaf8eb590042cc195e085e790e19bef2245 Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 18 Jul 2024 14:22:19 +0100 Subject: [PATCH] Add tick_to_ms function to MidiDict (#113) --- aria/data/datasets.py | 4 ++-- aria/data/midi.py | 10 +++++++++- aria/sample.py | 5 +++++ config/config.json | 2 +- 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/aria/data/datasets.py b/aria/data/datasets.py index b8699cd..05fc95f 100644 --- a/aria/data/datasets.py +++ b/aria/data/datasets.py @@ -668,7 +668,7 @@ def get_seqs( ) midi_dict_iter = [_ for _ in midi_dict_iter] - with Pool(num_proc // 2 if num_proc > 1 else 1) as pool: + with Pool(16) as pool: results = pool.imap( functools.partial(_get_seqs, _tokenizer=tokenizer), midi_dict_iter ) @@ -1184,7 +1184,7 @@ def _build_epoch(_save_path, _midi_dataset): _idx += 1 if _idx % 250 == 0: - logger.info(f"finished processing {_idx}") + logger.info(f"Finished processing {_idx}") logger = setup_logger() assert max_seq_len > 0, "max_seq_len must be greater than 0" diff --git a/aria/data/midi.py b/aria/data/midi.py index 3c36bec..5ba92c3 100644 --- a/aria/data/midi.py +++ b/aria/data/midi.py @@ -191,6 +191,14 @@ def calculate_hash(self): json.dumps(msg_dict_to_hash, sort_keys=True).encode() ).hexdigest() + def tick_to_ms(self, tick: int): + return get_duration_ms( + start_tick=0, + end_tick=tick, + tempo_msgs=self.tempo_msgs, + ticks_per_beat=self.ticks_per_beat, + ) + def _build_pedal_intervals(self): """Returns pedal-on intervals for each channel.""" self.pedal_msgs.sort(key=lambda msg: msg["tick"]) @@ -726,7 +734,7 @@ def meta_maestro_json( mid: mido.MidiFile, msg_data: dict, composer_names: list, form_names: list ): if os.path.isfile("maestro.json") is False: - print("maestro.json not found") + print("MAESTRO metadata function enabled but ./maestro.json not found.") return {} file_name = pathlib.Path(mid.filename).name diff --git a/aria/sample.py b/aria/sample.py index 819581d..f39217b 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -142,6 +142,11 @@ def greedy_sample( ), ) + if tokenizer.name == "separated_abs": + logits[:, tokenizer.tok_to_id[tokenizer.inst_start_tok]] = float( + "-inf" + ) + if temperature > 0.0: probs = torch.softmax(logits / temperature, dim=-1) next_token_ids = sample_top_p(probs, top_p).flatten() diff --git a/config/config.json b/config/config.json index f8c4943..6905796 100644 --- a/config/config.json +++ b/config/config.json @@ -78,7 +78,7 @@ } }, "maestro_json": { - "run": true, + "run": false, "args": { "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"], "form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"]