-
Notifications
You must be signed in to change notification settings - Fork 91
/
Copy pathtest_saving_loading_config.py
120 lines (100 loc) · 3.87 KB
/
test_saving_loading_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Tests for the saving/loading methods of tokenizers."""
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING, Any
import miditok
import pytest
from .utils_tests import ALL_TOKENIZATIONS, MAX_BAR_EMBEDDING, MIDI_PATHS_MULTITRACK
if TYPE_CHECKING:
from pathlib import Path
# Keyword arguments unpacked into ``miditok.TokenizerConfig`` by the tokenizer
# save/load test in this module.
ADDITIONAL_TOKENS_TEST = {
    # --- Feature switches -------------------------------------------------
    "use_chords": False,  # False to speed up tests
    "use_rests": True,
    "use_tempos": True,
    "use_time_signatures": True,
    "use_programs": False,
    # --- Rest and tempo settings -------------------------------------------
    "beat_res_rest": {(0, 16): 4},
    "num_tempos": 32,
    "tempo_range": (40, 250),
    # --- Base tokenizer (this module also sets this key for "MMM") ---------
    "base_tokenizer": "TSD",
    # --- Microtiming (this module also sets these keys for "PerTok") -------
    "use_microtiming": True,
    "ticks_per_quarter": 480,
    "max_microtiming_shift": 0.25,
    "num_microtiming_bins": 110,
}
# ``(tokenization name, TokenizerConfig kwargs)`` pairs used to parametrize the
# multitrack token save/load test in this module.
TOK_PARAMS_MULTITRACK = []
# Tokenizations that receive a second parameter set in which
# ``one_token_stream_for_programs`` is switched off.
tokenizations_non_one_stream = [
    "TSD",
    "REMI",
    "MIDILike",
    "Structured",
    "CPWord",
    "Octuple",
]
for tokenization_ in ALL_TOKENIZATIONS:
    # Baseline set: every tokenization is exercised with programs enabled.
    params_ = {"use_programs": True}
    if tokenization_ == "MMM":
        params_["base_tokenizer"] = "TSD"
    elif tokenization_ in ("Octuple", "MuMIDI"):
        params_["max_bar_embedding"] = MAX_BAR_EMBEDDING
    elif tokenization_ == "PerTok":
        params_.update(
            {
                "use_microtiming": True,
                "ticks_per_quarter": 220,
                "max_microtiming_shift": 0.25,
                "num_microtiming_bins": 110,
            }
        )
    TOK_PARAMS_MULTITRACK.append((tokenization_, params_))

    if tokenization_ not in tokenizations_non_one_stream:
        continue

    # Second set: same parameters, but without a single token stream for all
    # programs. Copied so the baseline set registered above is left untouched.
    params_tmp = deepcopy(params_)
    params_tmp["one_token_stream_for_programs"] = False
    # Tempos are turned off for Octuple in this variant, as tempos are carried
    # by note tokens.
    if tokenization_ == "Octuple":
        params_tmp["use_tempos"] = False
    TOK_PARAMS_MULTITRACK.append((tokenization_, params_tmp))
@pytest.mark.parametrize("tokenization", ALL_TOKENIZATIONS)
def test_saving_loading_tokenizer_config(tokenization: str, tmp_path: Path):
    """
    Check that a default ``TokenizerConfig`` survives a JSON round trip.

    The configuration is written to a JSON file, read back, and compared with
    the original. One attribute of the original is then changed to make sure
    the comparison actually detects a difference.

    :param tokenization: name of the tokenization; here it is only used to
        build the name of the JSON file.
    :param tmp_path: temporary directory in which the JSON file is written.
    """
    json_path = tmp_path / f"tok_conf_{tokenization}.json"

    saved_config = miditok.TokenizerConfig()
    saved_config.save_to_json(json_path)
    reloaded_config = miditok.TokenizerConfig.load_from_json(json_path)
    assert saved_config == reloaded_config

    # Sanity check of the comparison itself: once one side is modified, the
    # two configurations must no longer be considered equal.
    saved_config.pitch_range = (0, 777)
    assert saved_config != reloaded_config
@pytest.mark.parametrize("tokenization", ALL_TOKENIZATIONS)
def test_saving_loading_tokenizer(tokenization: str, tmp_path: Path):
    r"""
    Check that a tokenizer reloaded from its saved parameters is identical.

    A tokenizer is built from ``ADDITIONAL_TOKENS_TEST``, saved to a file, and a
    second tokenizer of the same class is created from that file. Both must
    compare equal. For Octuple, the vocabulary of the first tokenizer is then
    altered to make sure the comparison detects the difference.

    :param tokenization: name of the tokenizer class to fetch from ``miditok``.
    :param tmp_path: temporary directory in which the tokenizer is saved.
    """
    config = miditok.TokenizerConfig(**ADDITIONAL_TOKENS_TEST)
    tokenizer_class = getattr(miditok, tokenization)
    params_path = tmp_path / f"{tokenization}.txt"

    tokenizer: miditok.MusicTokenizer = tokenizer_class(tokenizer_config=config)
    tokenizer.save(params_path)
    tokenizer2: miditok.MusicTokenizer = tokenizer_class(params=params_path)
    assert tokenizer == tokenizer2

    # The inequality check is only performed for Octuple.
    if tokenization != "Octuple":
        return
    # Overwrite one entry of the first vocabulary so both tokenizers differ.
    tokenizer.vocab[0]["PAD_None"] = 8
    assert tokenizer != tokenizer2
@pytest.mark.parametrize("file_path", MIDI_PATHS_MULTITRACK[:3], ids=lambda p: p.name)
@pytest.mark.parametrize("tok_params_set", TOK_PARAMS_MULTITRACK)
def test_multitrack_midi_to_tokens_to_midi(
    file_path: Path,
    tok_params_set: tuple[str, dict[str, Any]],
    tmp_path: Path,
):
    """
    Check that tokens saved to a file are loaded back unchanged.

    A tokenizer is created from the given parameter set, the file is tokenized,
    the tokens are written to a JSON file and read back, and the loaded tokens
    are compared with the ones that were saved.

    :param file_path: path of the multitrack file to tokenize.
    :param tok_params_set: pair made of the tokenizer class name and the keyword
        arguments passed to ``miditok.TokenizerConfig``.
    :param tmp_path: temporary directory in which the tokens are saved.
    """
    # Build the tokenizer described by the parameter set.
    tokenization, config_kwargs = tok_params_set
    config = miditok.TokenizerConfig(**config_kwargs)
    tokenizer: miditok.MusicTokenizer = getattr(miditok, tokenization)(
        tokenizer_config=config
    )

    # Tokenize, then round-trip the tokens through a file.
    tokens_path = tmp_path / "tokens.json"
    tokens = tokenizer(file_path)
    tokenizer.save_tokens(tokens, tokens_path)
    tokens_loaded = tokenizer.load_tokens(tokens_path)

    # What comes back from the file must equal what was written.
    assert tokens == tokens_loaded