-
Notifications
You must be signed in to change notification settings - Fork 91
/
Copy pathtest_attribute_controls.py
150 lines (132 loc) · 4.62 KB
/
test_attribute_controls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""Test methods."""
from __future__ import annotations
from random import seed
from typing import TYPE_CHECKING, Any, Literal
import miditok
import pytest
from miditok.attribute_controls import create_random_ac_indexes
from symusic import Score
from .utils_tests import (
BARS_RANDOM_RATIO_RANGE,
MIDI_PATHS_ALL,
MIDI_PATHS_ONE_TRACK,
SEED,
TRACKS_RANDOM_RATIO_RANGE,
check_control_tokens_are_well_inserted,
)
if TYPE_CHECKING:
from collections.abc import Sequence
from pathlib import Path
# Tokenization schemes under test; each name is resolved with getattr(miditok, name).
TOKENIZATIONS = ["REMI", "TSD", "MMM"]
# Arguments forwarded to miditok.TokenizerConfig. The "ac_*" flags enable every
# attribute-control type so the tests below exercise all of them at once.
TOKENIZER_PARAMS = {
    "pitch_range": (21, 109),
    "beat_res": {(0, 4): 8, (4, 12): 4},
    "num_velocities": 32,
    "special_tokens": ["PAD", "BOS", "EOS", "MASK"],
    "use_chords": True,
    "use_rests": False,
    "use_tempos": True,
    "use_time_signatures": True,
    "use_programs": False,
    "num_tempos": 32,  # number of tempo bins
    "tempo_range": (40, 250),  # (min, max)
    "base_tokenizer": "REMI",  # presumably only used by MMM — verify in miditok
    "ac_polyphony_track": True,
    "ac_polyphony_bar": True,
    "ac_pitch_class_bar": True,
    "ac_note_density_track": True,
    "ac_note_density_bar": True,
    "ac_note_duration_bar": True,
    "ac_note_duration_track": True,
    "ac_repetition_track": True,
}
# Base vocabulary size targeted when training tokenizers in these tests.
VOCAB_SIZE = 2000
# Extra tokens added on top of VOCAB_SIZE when calling tokenizer.train below.
NUM_ADDITIONAL_TOKENS_SECOND_TRAINING = 400
# Per-word character budgets for WordPiece training, depending on the id-split
# granularity ("bar" words are longer than "beat" words).
WORDPIECE_MAX_INPUT_CHARS_PER_WORD_BAR = 500  # higher than default MidiTok values
WORDPIECE_MAX_INPUT_CHARS_PER_WORD_BEAT = 150
@pytest.mark.parametrize("file_path", MIDI_PATHS_ALL, ids=lambda path: path.name)
@pytest.mark.parametrize("tokenization", TOKENIZATIONS)
@pytest.mark.parametrize(
    "random_tracks_idx",
    [False, True],
    ids=lambda r: "rand_tracks" if r else "all_tracks",
)
@pytest.mark.parametrize(
    "random_bars_idx", [False, True], ids=lambda r: "rand_bars" if r else "all_bars"
)
def test_attribute_controls_computation(
    file_path: Path,
    tokenization: str,
    random_tracks_idx: bool,
    random_bars_idx: bool,
    tokenizer_params: dict[str, Any] | None = None,
) -> None:
    """Encode a file with attribute controls and check they are well inserted."""
    params = TOKENIZER_PARAMS if tokenizer_params is None else tokenizer_params
    tokenizer: miditok.MusicTokenizer = getattr(miditok, tokenization)(
        tokenizer_config=miditok.TokenizerConfig(**params)
    )
    score = tokenizer.preprocess_score(Score(file_path))

    # Build the attribute-control indexes, optionally restricting them to a
    # random subset of tracks and/or bars ((0, 1) ratio range vs. all of them).
    seed(SEED)
    ac_indexes = create_random_ac_indexes(
        score,
        tokenizer.attribute_controls,
        (0, 1) if random_tracks_idx else 1,
        (0, 1) if random_bars_idx else 1,
    )

    # Tokenize with the controls injected; the score was preprocessed above so
    # the encoder must not preprocess it again.
    tokens = tokenizer.encode(
        score, no_preprocess_score=True, attribute_controls_indexes=ac_indexes
    )

    # No insertion error should be reported.
    errors = check_control_tokens_are_well_inserted(
        tokenizer, score, tokens, ac_indexes
    )
    assert not errors
@pytest.mark.parametrize("tokenization", TOKENIZATIONS)
@pytest.mark.parametrize("model", ["BPE"])
@pytest.mark.parametrize(
    "encode_ids_split",
    ["no", "bar", "beat"],
    ids=lambda s: f"{s}_split",
)
def test_tokenizer_training_and_encoding_decoding(
    tokenization: str,
    model: Literal["BPE", "Unigram", "WordPiece"],
    encode_ids_split: Literal["bar", "beat", "no"],
    files_paths: Sequence[Path] = MIDI_PATHS_ONE_TRACK,
    vocab_size: int = VOCAB_SIZE,
) -> None:
    r"""
    Train a tokenizer to make sure the training iterator works with attribute controls.

    :param tokenization: name of the tokenizer class to train.
    :param model: training model ("BPE", "Unigram" or "WordPiece").
    :param encode_ids_split: type of token ids split before encoding/training.
    :param files_paths: list of paths of music files to use for the tests.
    :param vocab_size: base vocabulary size targeted by the training.
    """
    if encode_ids_split == "no" and model == "WordPiece":
        pytest.skip(f"Skipping training with {model} and {encode_ids_split} split")

    # Creates tokenizers. Build the config from a *copy* of the shared params:
    # assigning into the module-level TOKENIZER_PARAMS would leak
    # "encode_ids_split" into every test running afterwards in the session.
    tokenizer_params = {**TOKENIZER_PARAMS, "encode_ids_split": encode_ids_split}
    tokenizer: miditok.MusicTokenizer = getattr(miditok, tokenization)(
        tokenizer_config=miditok.TokenizerConfig(**tokenizer_params)
    )

    training_kwargs = {}
    if model == "WordPiece":
        # Bar-level splits yield longer "words" than beat-level ones, hence the
        # larger per-word character budget.
        training_kwargs["max_input_chars_per_word"] = (
            WORDPIECE_MAX_INPUT_CHARS_PER_WORD_BAR
            if encode_ids_split == "bar"
            else WORDPIECE_MAX_INPUT_CHARS_PER_WORD_BEAT
        )

    # Train the tokenizer; the iterator injects random attribute controls
    # (track- and bar-level ratio ranges) into the training sequences.
    training_iterator = miditok.TokTrainingIterator(
        tokenizer, files_paths, TRACKS_RANDOM_RATIO_RANGE, BARS_RANDOM_RATIO_RANGE
    )
    tokenizer.train(
        vocab_size=vocab_size + NUM_ADDITIONAL_TOKENS_SECOND_TRAINING,
        model=model,
        iterator=training_iterator,
        **training_kwargs,
    )