Skip to content

Commit

Permalink
Add tick_to_ms function to MidiDict (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad authored Jul 18, 2024
1 parent fc50c44 commit 48a2efa
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 4 deletions.
4 changes: 2 additions & 2 deletions aria/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ def get_seqs(
)
midi_dict_iter = [_ for _ in midi_dict_iter]

with Pool(num_proc // 2 if num_proc > 1 else 1) as pool:
with Pool(16) as pool:
results = pool.imap(
functools.partial(_get_seqs, _tokenizer=tokenizer), midi_dict_iter
)
Expand Down Expand Up @@ -1184,7 +1184,7 @@ def _build_epoch(_save_path, _midi_dataset):

_idx += 1
if _idx % 250 == 0:
logger.info(f"finished processing {_idx}")
logger.info(f"Finished processing {_idx}")

logger = setup_logger()
assert max_seq_len > 0, "max_seq_len must be greater than 0"
Expand Down
10 changes: 9 additions & 1 deletion aria/data/midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,14 @@ def calculate_hash(self):
json.dumps(msg_dict_to_hash, sort_keys=True).encode()
).hexdigest()

def tick_to_ms(self, tick: int):
return get_duration_ms(
start_tick=0,
end_tick=tick,
tempo_msgs=self.tempo_msgs,
ticks_per_beat=self.ticks_per_beat,
)

def _build_pedal_intervals(self):
"""Returns pedal-on intervals for each channel."""
self.pedal_msgs.sort(key=lambda msg: msg["tick"])
Expand Down Expand Up @@ -726,7 +734,7 @@ def meta_maestro_json(
mid: mido.MidiFile, msg_data: dict, composer_names: list, form_names: list
):
if os.path.isfile("maestro.json") is False:
print("maestro.json not found")
print("MAESTRO metadata function enabled but ./maestro.json not found.")
return {}

file_name = pathlib.Path(mid.filename).name
Expand Down
5 changes: 5 additions & 0 deletions aria/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ def greedy_sample(
),
)

if tokenizer.name == "separated_abs":
logits[:, tokenizer.tok_to_id[tokenizer.inst_start_tok]] = float(
"-inf"
)

if temperature > 0.0:
probs = torch.softmax(logits / temperature, dim=-1)
next_token_ids = sample_top_p(probs, top_p).flatten()
Expand Down
2 changes: 1 addition & 1 deletion config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
}
},
"maestro_json": {
"run": true,
"run": false,
"args": {
"composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"],
"form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"]
Expand Down

0 comments on commit 48a2efa

Please sign in to comment.