forked from pytorch/serve
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_onnx.py
71 lines (55 loc) · 2.09 KB
/
test_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import subprocess
import torch
import torch.onnx
class ToyModel(torch.nn.Module):
    """Minimal model for the export test: two stacked ``Linear(1, 1)`` layers."""

    def __init__(self):
        super().__init__()
        # Created in this order (linear1, then linear2), which fixes the
        # parameter/state_dict order and the order of random initialization.
        self.linear1 = torch.nn.Linear(1, 1)
        self.linear2 = torch.nn.Linear(1, 1)

    def forward(self, x):
        """Apply ``linear1`` then ``linear2``; ``x``'s last dimension must be 1."""
        hidden = self.linear1(x)
        return self.linear2(hidden)
# For a custom model you still have to write the converter by hand; as far as I can tell, there is no convenient out-of-the-box option.
def test_convert_to_onnx():
    """Export ``ToyModel`` to ONNX as ``linear.onnx`` in the working directory.

    The packaging step later in this file passes ``linear.onnx`` to
    torch-model-archiver, so it depends on this test having run. Both the
    input and the output get a dynamic batch axis, so the exported graph is
    not pinned to the batch size of the example input.
    """
    import os

    model = ToyModel()
    # Example input used to trace the graph: one sample with one feature,
    # matching ToyModel's Linear(1, 1) input layer.
    dummy_input = torch.randn(1, 1)
    model_path = "linear.onnx"
    # Put the model in inference mode before exporting.
    model.eval()
    torch.onnx.export(
        model,  # model being exported
        dummy_input,  # example input (or a tuple for multiple inputs)
        model_path,  # destination file
        export_params=True,  # store the trained parameter weights inside the model file
        opset_version=10,  # ONNX operator-set (opset) version to target
        do_constant_folding=True,  # pre-compute constant subgraphs at export time
        input_names=["modelInput"],  # the model's input names
        output_names=["modelOutput"],  # the model's output names
        dynamic_axes={
            "modelInput": {0: "batch_size"},  # axis 0 may vary at runtime
            "modelOutput": {0: "batch_size"},
        },
    )
    # The export's only result is the file on disk; fail here, not later
    # during packaging, if it was not produced.
    assert os.path.isfile(model_path), f"ONNX export did not create {model_path}"
def test_model_packaging_and_start():
    """Package ``linear.onnx`` into an ``onnx`` model archive under ``model_store``.

    Despite its name, this test only packages the model; it does not start
    the server. It expects ``linear.onnx`` and ``onnx_handler.py`` in the
    working directory.
    """
    import os

    # Idempotent: a bare `mkdir` fails (and the failure was ignored) when the
    # directory is left over from a previous run.
    os.makedirs("model_store", exist_ok=True)
    # List form: the command is fixed, so no shell is needed.
    subprocess.run(
        [
            "torch-model-archiver",
            "-f",  # overwrite an existing archive with the same model name
            "--model-name", "onnx",
            "--version", "1.0",
            "--serialized-file", "linear.onnx",
            "--export-path", "model_store",
            "--handler", "onnx_handler.py",
        ],
        check=True,
    )
def test_model_start():
    """Start TorchServe serving ``onnx.mar`` from the ``model_store`` directory.

    ``--ncs`` disables TorchServe's config-snapshot feature.

    NOTE(review): ``torchserve --start`` appears to return once the server has
    been launched in the background, so the model may not be ready when the
    next test sends a request. Consider polling a health/model endpoint if
    this is flaky.
    """
    # List form: the command is fixed, so no shell is needed.
    subprocess.run(
        [
            "torchserve",
            "--start",
            "--ncs",
            "--model-store", "model_store",
            "--models", "onnx.mar",
        ],
        check=True,
    )
def test_inference():
    """POST the payload ``1`` to the ``onnx`` model's prediction endpoint.

    ``--fail`` makes curl exit non-zero on an HTTP error status. ``check=True``
    turns any non-zero exit, including a refused connection, into a test
    failure. Without both, this test passes even when the server is down or
    rejects the request.
    """
    subprocess.run(
        [
            "curl",
            "--fail",
            "-X", "POST",
            "http://127.0.0.1:8080/predictions/onnx",
            "--data-binary", "1",
        ],
        check=True,
    )
def test_stop():
    """Stop the running TorchServe instance; fails if ``torchserve --stop`` fails."""
    # List form: the command is fixed, so no shell is needed.
    subprocess.run(["torchserve", "--stop"], check=True)