# dlrm_handler.py (forked from pytorch/serve)
"""
Handler for Torchrec DLRM based recommendation system
"""
import json
import logging
import os
from abc import ABC
import torch
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from ts.torch_handler.base_handler import BaseHandler
logger = logging.getLogger(__name__)


class TorchRecDLRMHandler(BaseHandler, ABC):
    """
    Handler for the TorchRec DLRM example
    """

    def initialize(self, context):
        """The initialize function loads the model.pt file and initializes the model object.
        This version creates and initializes the model on CPU first and transfers it to the
        GPU in a second step to prevent GPU OOM errors.

        Args:
            context (context): It is a JSON Object containing information
                pertaining to the model artifact parameters.

        Raises:
            RuntimeError: Raised when the model definition file (model.py) is not specified.
        """
        properties = context.system_properties

        # Create the model on CPU first to prevent GPU OOM errors
        self.device = "cpu"

        self.manifest = context.manifest
        model_dir = properties.get("model_dir")
        model_pt_path = None
        if "serializedFile" in self.manifest["model"]:
            serialized_file = self.manifest["model"]["serializedFile"]
            model_pt_path = os.path.join(model_dir, serialized_file)

        # Model definition file
        model_file = self.manifest["model"].get("modelFile", "")
        if not model_file:
            raise RuntimeError("model.py not specified")

        logger.debug("Loading eager model")
        self.model = self._load_pickled_model(model_dir, model_file, model_pt_path)

        # With the model built on CPU, resolve the target device and move the model over
        self.map_location = (
            "cuda"
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu"
        )
        self.device = torch.device(
            self.map_location + ":" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else self.map_location
        )
        self.model.to(self.device)
        self.model.eval()

        logger.debug("Model file %s loaded successfully", model_pt_path)
        self.initialized = True
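
    # For reference, the "serializedFile" and "modelFile" manifest entries read
    # above are produced by torch-model-archiver. A hedged sketch of an archive
    # command that would yield such a manifest (the file names here are
    # illustrative assumptions, not the example's actual artifacts):
    #
    #   torch-model-archiver --model-name dlrm \
    #       --version 1.0 \
    #       --model-file model.py \
    #       --serialized-file dlrm.pt \
    #       --handler dlrm_handler.py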

    def preprocess(self, data):
        """
        The input to the DLRM model is twofold: a dense part and a sparse part.
        The sparse part consists of a list of ids per feature, where each entry can
        hold zero, one, or multiple ids. Because of this variable number of elements,
        the sparse part is represented by the KeyedJaggedTensor class provided by TorchRec.

        Args:
            data (list): A batch of request rows, each carrying a JSON payload
                under "data" or "body".

        Returns:
            Tuple of:
                (Tensor): Dense features
                (KeyedJaggedTensor): Sparse features
        """
        float_features, id_list_features_lengths, id_list_features_values = [], [], []

        for row in data:
            input = row.get("data") or row.get("body")

            if not isinstance(input, dict):
                input = json.loads(input)

            # This is the dense feature part
            assert "float_features" in input

            # The sparse input consists of a length vector and the values.
            # The length vector contains the number of elements which are part of the
            # same entry in the linear list provided as input.
            assert "id_list_features.lengths" in input
            assert "id_list_features.values" in input

            float_features.append(input["float_features"])
            id_list_features_lengths.extend(input["id_list_features.lengths"])
            id_list_features_values.append(input["id_list_features.values"])

        # Reformat the values input for KeyedJaggedTensor: transpose the
        # [batch, num_values] matrix so the values are grouped per feature key
        id_list_features_values = torch.FloatTensor(id_list_features_values)
        id_list_features_values = torch.transpose(id_list_features_values, 0, 1)
        id_list_features_values = [value for value in id_list_features_values]

        # Dense and sparse features for the DLRM model; embedding lookups
        # require integral ids, so the concatenated values are cast to long
        dense_features = torch.FloatTensor(float_features)
        sparse_features = KeyedJaggedTensor(
            keys=DEFAULT_CAT_NAMES,
            lengths=torch.LongTensor(id_list_features_lengths),
            values=torch.cat(id_list_features_values).long(),
        )

        return dense_features, sparse_features
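
    # Illustrative request shape (an assumption inferred from the parsing
    # above, not taken from the example's docs): with the Criteo-style
    # DEFAULT_CAT_NAMES keys and one id per key, a single row could look like
    #
    #   {
    #       "float_features": [0.5, 1.2, ...],
    #       "id_list_features.lengths": [1, 1, ...],  # one length per key
    #       "id_list_features.values": [17, 23, ...]  # one id per key
    #   }
    #
    # Under that one-id-per-key assumption, transposing the [batch, num_keys]
    # value matrix groups the ids per feature key, which is the ordering
    # KeyedJaggedTensor expects for its flat values vector.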

    def inference(self, data):
        """
        The inference call moves the elements of the tuple onto the device and calls the model.

        Args:
            data (tuple): A tuple of dense and sparse features, as returned by
                preprocess, matching the model's input signature.

        Returns:
            (Torch Tensor): The predicted response from the model is returned
                in this function.
        """
        with torch.no_grad():
            data = map(lambda x: x.to(self.device), data)
            results = self.model(*data)
        return results
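
    # The model returns one prediction per input row (typically shape
    # [batch_size, 1] for the DLRM example); postprocess below flattens this
    # into one "score" entry per row.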

    def postprocess(self, data):
        """
        The postprocess function converts the prediction response into a
        TorchServe compatible format.

        Args:
            data (Torch Tensor): The data parameter comes from the prediction output

        Returns:
            (list): The response, containing a single score per input entry.
        """
        result = []
        for item in data:
            res = {}
            res["score"] = item.squeeze().float().tolist()
            result.append(res)
        return result
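

if __name__ == "__main__":
    # Minimal local sketch of the preprocess step, assuming Criteo-style
    # inputs (13 dense features, one id per categorical key). This is for
    # quick experimentation only; in production TorchServe instantiates and
    # drives the handler itself.
    handler = TorchRecDLRMHandler()
    sample = {
        "float_features": [0.0] * 13,
        "id_list_features.lengths": [1] * len(DEFAULT_CAT_NAMES),
        "id_list_features.values": [0] * len(DEFAULT_CAT_NAMES),
    }
    dense, sparse = handler.preprocess([{"data": sample}])
    print(dense.shape)    # torch.Size([1, 13])
    print(sparse.keys())  # the DEFAULT_CAT_NAMES keys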