-
Notifications
You must be signed in to change notification settings - Fork 91
/
Copy path test_toksequence.py
66 lines (50 loc) · 1.83 KB
/
test_toksequence.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Test methods."""
from collections.abc import Callable
from pathlib import Path
import pytest
from miditok import TSD, TokenizerConfig, TokSequence
from .utils_tests import MIDI_PATHS_MULTITRACK
def test_tokseq_concat():
    """Check that ``+`` joins two TokSequences attribute by attribute, in order."""
    first_ids = list(range(10))
    second_ids = list(range(10, 20))
    # Token strings are just the doubled ids; their only role is to be distinct.
    first_tokens = [str(2 * i) for i in first_ids]
    second_tokens = [str(2 * i) for i in second_ids]
    first_bytes = "".join(first_tokens)
    second_bytes = "".join(second_tokens)

    left = TokSequence(ids=first_ids, tokens=first_tokens, bytes=first_bytes)
    right = TokSequence(ids=second_ids, tokens=second_tokens, bytes=second_bytes)
    combined = left + right

    assert combined.ids == first_ids + second_ids
    assert combined.tokens == first_tokens + second_tokens
    assert combined.bytes == first_bytes + second_bytes
def test_tokseq_slice_and_concat():
    """Check that slicing a TokSequence in two and re-adding the halves is lossless."""
    all_ids = list(range(20))
    all_tokens = [str(2 * i) for i in all_ids]
    all_bytes = "".join(all_tokens)
    full_seq = TokSequence(ids=all_ids, tokens=all_tokens, bytes=all_bytes)

    split_at = 10
    head = full_seq[:split_at]
    tail = full_seq[split_at:]

    # Every attribute of each part must equal the same slice of the source data.
    # ``all_bytes`` is a plain str, so the expected ``bytes`` values here are
    # character slices, not per-token slices.
    parts = (
        (head, slice(None, split_at)),
        (tail, slice(split_at, None)),
    )
    for part, window in parts:
        assert part.ids == all_ids[window]
        assert part.tokens == all_tokens[window]
        assert part.bytes == all_bytes[window]

    assert full_seq == head + tail
def _concat_tokseqs(seqs: list) -> "TokSequence":
    """Concatenate sequences in order with ``+=`` and return the result.

    :param seqs: list of sequences, e.g. as returned by a ``split_per_*`` method.
        The list itself is not mutated.
    :return: the concatenation of all the sequences of ``seqs``.
    """
    # An empty split would otherwise surface as a bare IndexError.
    assert seqs, "splitting the sequence returned no subsequence"
    concat_seq = seqs[0]
    for seq in seqs[1:]:
        concat_seq += seq
    return concat_seq


@pytest.mark.parametrize("file_path", MIDI_PATHS_MULTITRACK, ids=lambda p: p.name)
def test_split_tokseq_per_bars_beats(file_path: Path, tokenization: Callable = TSD):
    """Check that splitting per bars and per beats is lossless.

    The file is tokenized, split with ``split_per_bars`` and ``split_per_beats``,
    and the concatenation of the parts must equal the original sequence.

    :param file_path: path of the MIDI file to tokenize.
    :param tokenization: tokenizer class to instantiate. Because it has a default
        value, pytest does not resolve it as a fixture.
    """
    # NOTE(review): use_programs=True presumably makes the tokenizer return a
    # single TokSequence for multitrack files (split_* is called on it) — confirm.
    tokenizer = tokenization(TokenizerConfig(use_programs=True))
    tokseq = tokenizer(file_path)

    # Split per bars
    assert _concat_tokseqs(tokseq.split_per_bars()) == tokseq
    # Split per beats
    assert _concat_tokseqs(tokseq.split_per_beats()) == tokseq