forked from viam-modules/torch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_local.py
100 lines (86 loc) · 3.57 KB
/
test_local.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import unittest
from src.model.model import TorchModel
from src.model_inspector.inspector import Inspector
from viam.services.mlmodel import Metadata
from torchvision.models.detection import FasterRCNN
from torchvision.models import MobileNet_V2_Weights
import torchvision
import os
from torchvision.models.detection.rpn import AnchorGenerator
class TestInputs(unittest.TestCase):
    """Local tests for ``TorchModel`` loading/inference and ``Inspector`` metadata.

    Serialized model files are located with relative ``examples/...`` paths, so
    they are resolved against the current working directory; run the tests from
    the directory that contains ``examples/``.
    """

    @staticmethod
    def load_resnet_weights():
        """Wrap the ResNet-18 ``.pth`` file in ``examples/resnet_18`` as a ``TorchModel``.

        ``test_error_loading_weights`` expects this to raise ``TypeError``
        (presumably because the file holds weights rather than a full model --
        TODO confirm against ``TorchModel``).
        """
        return TorchModel(
            path_to_serialized_file=os.path.join(
                "examples", "resnet_18", "resnet18-f37072fd.pth"
            )
        )

    @staticmethod
    def load_standalone_resnet():
        """Wrap the serialized ResNet-18 in ``examples/resnet_18_scripted`` as a ``TorchModel``."""
        return TorchModel(
            path_to_serialized_file=os.path.join(
                "examples", "resnet_18_scripted", "resnet-18.pt"
            )
        )

    @staticmethod
    def load_detector_from_torchvision():
        """Build an in-memory Faster R-CNN detector with a MobileNetV2 backbone.

        The backbone uses ``MobileNet_V2_Weights.DEFAULT`` pretrained weights
        (torchvision fetches them if they are not cached). The detector heads
        (``num_classes=2``) are untrained; the tests only compare output shapes.
        No file is involved, so ``path_to_serialized_file`` is ``None``.
        """
        backbone = torchvision.models.mobilenet_v2(
            weights=MobileNet_V2_Weights.DEFAULT
        ).features
        # FasterRCNN reads ``out_channels`` from the backbone; MobileNetV2's
        # feature extractor ends with 1280 channels.
        backbone.out_channels = 1280
        # One feature map ("0"), hence one tuple of sizes and one of ratios.
        anchor_generator = AnchorGenerator(
            sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)
        )
        roi_pooler = torchvision.ops.MultiScaleRoIAlign(
            featmap_names=["0"], output_size=7, sampling_ratio=2
        )
        model = FasterRCNN(
            backbone,
            num_classes=2,
            rpn_anchor_generator=anchor_generator,
            box_roi_pool=roi_pooler,
        )
        # Put the detector in inference mode before wrapping it.
        model.eval()
        return TorchModel(path_to_serialized_file=None, model=model)

    def test_error_loading_weights(self):
        """Wrapping the weights-only ResNet-18 file must raise ``TypeError``."""
        with self.assertRaises(TypeError):
            self.load_resnet_weights()

    def test_resnet_metadata(self):
        """Serialized ResNet-18 outputs must agree with ``Inspector`` metadata."""
        self._assert_outputs_match_metadata(self.load_standalone_resnet())

    def test_detector_metadata(self):
        """Faster R-CNN detector outputs must agree with ``Inspector`` metadata."""
        self._assert_outputs_match_metadata(self.load_detector_from_torchvision())

    def _assert_outputs_match_metadata(self, model):
        """Run ``model`` on a dummy image and check every output against its metadata.

        For each output returned by ``model.infer``, at least one entry in
        ``metadata.output_info`` must have the same name. Every entry with that
        name must agree with the output on the leading dimension.

        Args:
            model: the ``TorchModel`` to run and inspect.
        """
        # A batch of one all-ones 3x300x400 image; the input name is arbitrary.
        x = torch.ones(3, 300, 400).unsqueeze(0)
        outputs = model.infer({"any_input_name_you_want": x.numpy()})
        self.assertIsInstance(outputs, dict)
        # Without this check, an empty result would pass the loop below vacuously.
        self.assertTrue(outputs, "model.infer returned no outputs")

        metadata: Metadata = Inspector(model).find_metadata(label_path="fake_path")

        for output_name, output_value in outputs.items():
            with self.subTest(output=output_name):
                matching_infos = [
                    info
                    for info in metadata.output_info
                    if info.name == output_name
                ]
                self.assertTrue(
                    matching_infos,
                    f"no metadata output_info entry named {output_name!r}",
                )
                for info in matching_infos:
                    # Only the leading dimension is compared. The inferred shape
                    # and the metadata shape are different sequence types (e.g. a
                    # tuple like (1000,) vs a list like [1000]), so comparing the
                    # whole shapes would never be equal.
                    self.assertEqual(output_value.shape[0], info.shape[0])
if __name__ == "__main__":
    # Only when executed directly as a script: discover and run this module's tests.
    unittest.main(module="__main__")