-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
70 lines (53 loc) · 2.36 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import argh
import yaml
from pathlib import Path
from typing import Optional
from src.data_class.context import Context
from src.utils.pytorch import setup_torch
from src.utils.formatting import syntax_print
from src.executable.preprocess import preprocess_data
from src.executable.train import train_model
def get_context(config_path: Optional[str] = None) -> Context:
    '''
    Load a Context from the given YAML config file, or the default Context.

    Args:
        config_path: Path to a `.yaml` config file. When None, a default
            Context is constructed instead.

    Returns:
        A Context built from the config file (or defaults).

    Raises:
        ValueError: If `config_path` does not have a `.yaml` suffix.
    '''
    if config_path is None:
        return Context()
    config = Path(config_path)
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would silently accept a non-yaml path.
    if config.suffix != '.yaml':
        raise ValueError('Expected a .yaml file for config_path')
    return Context(config_path=config)
@argh.arg('-i', '--in_path', default='data.txt', help='Path for data to be preprocessed')
@argh.arg('-o', '--out_path', default='out.tensor', help='Path where the preprocessed output is written')
def preprocess(in_path: str = 'data.txt', out_path: str = "out.tensor"):
    '''
    Preprocess the raw data at `in_path` and write the result to `out_path`.

    Args:
        in_path: Path of the raw input data file.
        out_path: Destination path for the preprocessed output
            (`out.tensor` by default).
    '''
    # Delegates all actual work to the project-level preprocessing routine.
    preprocess_data(in_path, out_path)
@argh.arg('-c', '--config_path', default='configs/small.yaml', help='Path for the config file')
@argh.arg('-s', '--steps', default=0, help='Number of steps to take. 0 = infinite')
@argh.arg('-l', '--load_model', default=False, help='Whether to load an existing model checkpoint')
def train(config_path: Optional[str] = None, steps: int = 0, load_model: bool = False):
    '''
    Trains a model given the config file.

    Args:
        config_path: Path to a .yaml config file. NOTE(review): the CLI
            default is 'configs/small.yaml' (from the decorator) while the
            direct-call default is None (which makes get_context build a
            default Context) — confirm this divergence is intentional.
        steps: Number of training steps to take; 0 means run indefinitely.
        load_model: Whether to resume from an existing model checkpoint.
    '''
    ctx = get_context(config_path)
    # NOTE(review): the literal 0 is presumably a seed or device index —
    # confirm against setup_torch's signature.
    setup_torch(0)
    # Pretty-print the resolved config so every run logs its hyperparameters.
    dump = yaml.dump(ctx.serialize(), indent=4)
    syntax_print(dump, "yaml", title="Config")
    train_model(ctx, steps, load_model)
@argh.arg('-g', '--generated_tokens', type=int, default=20, help='Number of tokens to be generated after prompt')
@argh.arg('-t', '--temp', type=float, default=0.2, help='Temperature of the model.\nlower = consistency\nhigher = "creativity"')
@argh.arg('-c', '--config_path', help='Path for the config file')
def inference(generated_tokens: int = 20, temp: float = 0.2, config_path: Optional[str] = None):
    '''
    Runs inference of input data on desired model.

    Args:
        generated_tokens: Number of tokens to generate after the prompt.
        temp: Sampling temperature (lower = more consistent output).
        config_path: Path to the model's .yaml config file; required.

    Raises:
        ValueError: If no config_path is supplied.
    '''
    # Explicit raise instead of `assert` (asserts vanish under `python -O`).
    if config_path is None:
        raise ValueError("Expected Config file!")
    ctx = get_context(config_path)
    # TODO(review): the actual inference entry point is stubbed out —
    # re-enable once inference_cli is available.
    # inference_cli(ctx, float(temp), int(generated_tokens))
if __name__ == '__main__':
    # Register the three CLI sub-commands and hand control to argh.
    cli = argh.ArghParser()
    cli.add_commands([preprocess, train, inference])
    cli.dispatch()