Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tick_to_ms function to MidiDict #113

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aria/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ def get_seqs(
)
midi_dict_iter = [_ for _ in midi_dict_iter]

with Pool(num_proc // 2 if num_proc > 1 else 1) as pool:
with Pool(16) as pool:
results = pool.imap(
functools.partial(_get_seqs, _tokenizer=tokenizer), midi_dict_iter
)
Expand Down Expand Up @@ -1184,7 +1184,7 @@ def _build_epoch(_save_path, _midi_dataset):

_idx += 1
if _idx % 250 == 0:
logger.info(f"finished processing {_idx}")
logger.info(f"Finished processing {_idx}")

logger = setup_logger()
assert max_seq_len > 0, "max_seq_len must be greater than 0"
Expand Down
10 changes: 9 additions & 1 deletion aria/data/midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,14 @@ def calculate_hash(self):
json.dumps(msg_dict_to_hash, sort_keys=True).encode()
).hexdigest()

def tick_to_ms(self, tick: int):
return get_duration_ms(
start_tick=0,
end_tick=tick,
tempo_msgs=self.tempo_msgs,
ticks_per_beat=self.ticks_per_beat,
)

def _build_pedal_intervals(self):
"""Returns pedal-on intervals for each channel."""
self.pedal_msgs.sort(key=lambda msg: msg["tick"])
Expand Down Expand Up @@ -726,7 +734,7 @@ def meta_maestro_json(
mid: mido.MidiFile, msg_data: dict, composer_names: list, form_names: list
):
if os.path.isfile("maestro.json") is False:
print("maestro.json not found")
print("MAESTRO metadata function enabled but ./maestro.json not found.")
return {}

file_name = pathlib.Path(mid.filename).name
Expand Down
5 changes: 5 additions & 0 deletions aria/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ def greedy_sample(
),
)

if tokenizer.name == "separated_abs":
logits[:, tokenizer.tok_to_id[tokenizer.inst_start_tok]] = float(
"-inf"
)

if temperature > 0.0:
probs = torch.softmax(logits / temperature, dim=-1)
next_token_ids = sample_top_p(probs, top_p).flatten()
Expand Down
2 changes: 1 addition & 1 deletion config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
}
},
"maestro_json": {
"run": true,
"run": false,
"args": {
"composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"],
"form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"]
Expand Down
Loading