-
Notifications
You must be signed in to change notification settings - Fork 1
/
optimize.py
61 lines (49 loc) · 1.28 KB
/
optimize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from pathlib import Path
import typer
from src.optimize_utils import _profile, _fuse, _quantize, _prune
from src.utils import setup_logging
# Typer application object; the functions below register themselves on it
# as sub-commands through the @app.command() decorator.
app = typer.Typer()
@app.command()
def profile(
    exp_path: Path,
    checkpoint: int = 3,
    iterations: int = 100,
    precision: str = "int8",
    prune_amount: float = 0.3,
    device: str = "cpu",
):
    """Profile model latency given an input yaml file."""
    # Typer shows the docstring above as this command's --help text, so it
    # is kept to the one-line summary; further notes live in comments.
    #
    # This command is a thin CLI shim: every parameter is handed to
    # src.optimize_utils._profile positionally, in signature order, and
    # whatever that helper returns is passed straight back to the caller.
    #
    # NOTE(review): checkpoint defaults to 3 here but to 0 in the fuse,
    # quantize and prune commands -- confirm that this is intentional.
    result = _profile(
        exp_path,
        checkpoint,
        iterations,
        precision,
        prune_amount,
        device,
    )
    return result
@app.command()
def fuse(
    exp_path: Path,
    checkpoint: int = 0,
    device: str = "cpu",
):
    """Convert model to torchscript for jit."""
    # Typer shows the docstring above as this command's --help text, so it
    # is kept to the one-line summary.
    #
    # Thin CLI shim: the three parameters are forwarded positionally, in
    # signature order, to src.optimize_utils._fuse and its return value is
    # handed back unchanged.
    result = _fuse(
        exp_path,
        checkpoint,
        device,
    )
    return result
@app.command()
def quantize(
    exp_path: Path,
    checkpoint: int = 0,
    precision: str = "int8",
    device: str = "cpu",
):
    """Post-training quantization of model with various precisions."""
    # Typer shows the docstring above as this command's --help text, so it
    # is kept to the one-line summary.
    #
    # Thin CLI shim: parameters are forwarded positionally, in signature
    # order, to src.optimize_utils._quantize and its return value is handed
    # back unchanged. No validation of `precision` happens at this layer.
    result = _quantize(
        exp_path,
        checkpoint,
        precision,
        device,
    )
    return result
@app.command()
def prune(
    exp_path: Path,
    checkpoint: int = 0,
    prune_amount: float = 0.3,
    device: str = "cpu",
):
    """Prune model connections to sparse representation."""
    # Typer shows the docstring above as this command's --help text, so it
    # is kept to the one-line summary.
    #
    # Thin CLI shim: parameters are forwarded positionally, in signature
    # order, to src.optimize_utils._prune and its return value is handed
    # back unchanged. No range check on `prune_amount` happens at this layer.
    result = _prune(
        exp_path,
        checkpoint,
        prune_amount,
        device,
    )
    return result
if __name__ == "__main__":
    # Script entry point only: importing this module does not run anything.
    # setup_logging() is called first so it has run before Typer parses the
    # command line and dispatches to one of the registered commands.
    setup_logging()
    app()