-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathgenerate.py
80 lines (65 loc) · 3.05 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse
from typing import Dict
from commu.midi_generator.generate_pipeline import MidiGenerationPipeline
from commu.preprocessor.utils import constants
def parse_args() -> Dict[str, argparse.ArgumentParser]:
    """Build the two command-line parsers used by this script.

    Model options and generation-input options live in separate parsers so
    that each ends up in its own namespace; the caller runs both over the
    command line (e.g. with ``parse_known_args``).

    Only the input parser registers ``-h/--help``. The model parser is built
    with ``add_help=False`` because, when it is parsed first, its own help
    action would print just ``--checkpoint_dir`` and exit, hiding every input
    option. The input parser's epilog mentions the model option instead.

    Returns:
        ``{"model_args": <model parser>, "input_args": <input parser>}``.
    """
    # Model arguments
    model_arg_parser = argparse.ArgumentParser(
        description="Model Arguments", add_help=False
    )
    model_arg_parser.add_argument("--checkpoint_dir", type=str)

    # Input arguments
    input_arg_parser = argparse.ArgumentParser(
        description="Input Arguments",
        epilog=(
            "model arguments (handled by a separate parser):\n"
            "  --checkpoint_dir CHECKPOINT_DIR"
        ),
        # Keep the epilog's line breaks instead of letting argparse re-wrap it.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    input_arg_parser.add_argument("--output_dir", type=str, required=True)

    # Input metadata
    input_arg_parser.add_argument("--bpm", type=int)
    input_arg_parser.add_argument(
        "--audio_key", type=str, choices=list(constants.KEY_MAP.keys())
    )
    input_arg_parser.add_argument(
        "--time_signature", type=str, choices=list(constants.TIME_SIG_MAP.keys())
    )
    input_arg_parser.add_argument(
        "--pitch_range", type=str, choices=list(constants.PITCH_RANGE_MAP.keys())
    )
    input_arg_parser.add_argument("--num_measures", type=float)
    input_arg_parser.add_argument(
        "--inst", type=str, choices=list(constants.INST_MAP.keys()),
    )
    input_arg_parser.add_argument(
        "--genre", type=str, default="cinematic", choices=list(constants.GENRE_MAP.keys())
    )
    input_arg_parser.add_argument(
        "--track_role", type=str, choices=list(constants.TRACK_ROLE_MAP.keys())
    )
    input_arg_parser.add_argument(
        "--rhythm", type=str, default="standard", choices=list(constants.RHYTHM_MAP.keys())
    )
    # An explicit metavar stops argparse from spelling out all 127 allowed
    # values in the usage line and in --help.
    input_arg_parser.add_argument(
        "--min_velocity",
        type=int,
        choices=range(1, 128),
        metavar="MIN_VELOCITY",
        help="integer from 1 to 127",
    )
    input_arg_parser.add_argument(
        "--max_velocity",
        type=int,
        choices=range(1, 128),
        metavar="MAX_VELOCITY",
        help="integer from 1 to 127",
    )
    input_arg_parser.add_argument(
        "--chord_progression", type=str, help='Chord progression ex) C-C-E-E-G-G ...'
    )

    # Information needed at inference time
    input_arg_parser.add_argument("--num_generate", type=int)
    input_arg_parser.add_argument("--top_k", type=int, default=32)
    input_arg_parser.add_argument("--temperature", type=float, default=0.95)

    arg_dict = {
        "model_args": model_arg_parser,
        "input_args": input_arg_parser,
    }
    return arg_dict
def main(model_args: argparse.Namespace, input_args: argparse.Namespace):
    """Run one end-to-end generation pass on a fresh ``MidiGenerationPipeline``.

    The steps run strictly in this order:

    1. ``initialize_model`` with ``model_args`` converted to a dict, then
       ``initialize_generation``;
    2. read ``inference_cfg`` from the model-initialisation task, then execute
       that task to obtain the model;
    3. execute the preprocess task on ``input_args`` (as a dict) to get the
       encoded metadata, then read the task's ``input_data``;
    4. configure the inference task with the model, input data and config,
       and execute it on the encoded metadata;
    5. configure the postprocess task with the input data and execute it on
       the generated sequences.

    Args:
        model_args: Namespace produced by the model-argument parser.
        input_args: Namespace produced by the input-argument parser.

    Returns:
        None. The postprocess task's return value is not kept here; presumably
        it writes results somewhere such as ``input_args.output_dir`` — TODO
        confirm against ``MidiGenerationPipeline``.
    """
    pipeline = MidiGenerationPipeline()

    model_options = vars(model_args)
    pipeline.initialize_model(model_options)
    pipeline.initialize_generation()

    # The config is read before execute() runs on the same task, preserving
    # the original access order.
    cfg = pipeline.model_initialize_task.inference_cfg
    generator = pipeline.model_initialize_task.execute()

    input_options = vars(input_args)
    meta_tokens = pipeline.preprocess_task.execute(input_options)
    # input_data is read only after the preprocess task has executed.
    preprocessed = pipeline.preprocess_task.input_data

    pipeline.inference_task(
        model=generator,
        input_data=preprocessed,
        inference_cfg=cfg,
    )
    generated_sequences = pipeline.inference_task.execute(meta_tokens)

    pipeline.postprocess_task(input_data=preprocessed)
    pipeline.postprocess_task.execute(sequences=generated_sequences)
if __name__ == "__main__":
    # Build the parsers once and chain them: the model parser takes the
    # options it knows, and the input parser gets whatever is left. Anything
    # neither parser recognises is reported as an error instead of being
    # silently ignored (a misspelt flag such as ``--temprature`` would
    # otherwise quietly fall back to its default value).
    parsers = parse_args()
    model_args, remaining = parsers["model_args"].parse_known_args()
    input_args, unrecognized = parsers["input_args"].parse_known_args(remaining)
    if unrecognized:
        parsers["input_args"].error(
            "unrecognized arguments: " + " ".join(unrecognized)
        )
    main(model_args, input_args)