Merge pull request #80 from tnc-ca-geo/sagemaker-serverless-deployment
Sagemaker serverless deployment with torchserve
Showing 82 changed files with 12,072 additions and 2 deletions.
New file: setup instructions for api/megadetectorv5 (+37 lines)
# Setup Instructions

## Download weights and torchscript model

```
cd animl-ml/api/megadetectorv5
aws s3 sync s3://animl-model-zoo/megadetectorv5/ models/megadetectorv5/
```

## Export yolov5 weights as a torchscript model

The export image size must match the size used in `mdv5_handler.py` for good performance.

```
python yolov5/export.py --weights models/megadetectorv5/md_v5a.0.0.pt --img 640 640 --batch 1
```

This creates `models/megadetectorv5/md_v5a.0.0.torchscript`.
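As a quick sanity check, the exported file can be loaded and run on a dummy input (a sketch, not part of the repo; the path and the tuple-shaped output are assumptions based on the export step above and the handler below):

```python
# Load the exported TorchScript model and push a dummy 1x3x640x640 batch
# through it, matching --img 640 640 --batch 1 above.
import torch

model = torch.jit.load("models/megadetectorv5/md_v5a.0.0.torchscript")
model.eval()
dummy = torch.zeros(1, 3, 640, 640)
with torch.no_grad():
    out = model(dummy)
# mdv5_handler.py treats the output as indexable, with raw predictions first
pred = out[0] if isinstance(out, (list, tuple)) else out
print(pred.shape)  # expect (1, num_candidate_boxes, 5 + num_classes)
```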
## Run the model archiver

```
torch-model-archiver --model-name mdv5 --version 1.0.0 --serialized-file models/megadetectorv5/md_v5a.0.0.torchscript --extra-files index_to_name.json --handler mdv5_handler.py
mkdir -p model_store
mv mdv5.mar model_store/megadetectorv5-yolov5-1-batch-640-640.mar
```

The `.mar` file is what is served by TorchServe.
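A `.mar` archive uses the zip format, so its contents can be listed to confirm that the weights, handler, and label map were bundled (a minimal sketch, assuming the paths above):

```python
# List the files packaged into the model archive.
import zipfile

with zipfile.ZipFile("model_store/megadetectorv5-yolov5-1-batch-640-640.mar") as mar:
    for name in mar.namelist():
        print(name)  # expect the torchscript file, mdv5_handler.py, index_to_name.json, and a manifest
```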
## Serve the torchscript model with torchserve

```
bash docker_mdv5.sh
```
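Once the container is up, TorchServe exposes a health-check endpoint on the inference port; a minimal probe (assuming the `requests` package is installed) looks like:

```python
# Ping TorchServe's inference API to verify the server is up.
import requests

resp = requests.get("http://127.0.0.1:8080/ping")
print(resp.json())  # a healthy server reports {"status": "Healthy"}
```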
## Request a prediction

Predictions are returned in normalized coordinates with a category label and confidence score per detection.

```
curl http://127.0.0.1:8080/predictions/mdv5 -T ../../input/sample-img-fox.jpg
```
New file: docker_mdv5.sh (+3 lines)

```
# dockerized version of the above; ran into a dependency issue with termcolor
# https://github.com/pytorch/serve/tree/master/docker
docker run --rm -p 8080:8080 -p 8081:8081 -p 8082:8082 -p 7070:7070 -p 7071:7071 -v "$(pwd)":/app -it pytorch/torchserve:latest bash /app/serve_megadetect.sh
```
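For reference, these published ports correspond to TorchServe's defaults: 8080 for the inference REST API, 8081 for management, 8082 for metrics, and 7070/7071 for the gRPC inference and management endpoints.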
New file: index_to_name.json (+1 line)

```
{ "1": "animal", "2": "person", "3": "vehicle" }
```
New file: mdv5_handler.py (+206 lines)

```python
"""Custom TorchServe model handler for YOLOv5 models. | ||
""" | ||
from ts.torch_handler.base_handler import BaseHandler | ||
import numpy as np | ||
import base64 | ||
import torch | ||
import torchvision.transforms as tf | ||
import torchvision | ||
import io | ||
from PIL import Image | ||
|
||
|
||
class ModelHandler(BaseHandler): | ||
""" | ||
A custom model handler implementation. | ||
""" | ||
|
||
img_size = 640 | ||
"""Image size (px). Images will be resized to this resolution before inference. | ||
""" | ||
|
||
def __init__(self): | ||
# call superclass initializer | ||
super().__init__() | ||
|
||
def preprocess(self, data): | ||
"""Converts input images to float tensors. | ||
Args: | ||
data (List): Input data from the request in the form of a list of image tensors. | ||
Returns: | ||
Tensor: single Tensor of shape [BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE] | ||
""" | ||
images = [] | ||
|
||
transform = tf.Compose([ | ||
tf.ToTensor(), | ||
tf.Resize((self.img_size, self.img_size)) | ||
]) | ||
|
||
# load images | ||
# taken from https://github.com/pytorch/serve/blob/master/ts/torch_handler/vision_handler.py | ||
|
||
# handle if images are given in base64, etc. | ||
for row in data: | ||
# Compat layer: normally the envelope should just return the data | ||
# directly, but older versions of Torchserve didn't have envelope. | ||
image = row.get("data") or row.get("body") | ||
if isinstance(image, str): | ||
# if the image is a string of bytesarray. | ||
image = base64.b64decode(image) | ||
|
||
# If the image is sent as bytesarray | ||
if isinstance(image, (bytearray, bytes)): | ||
image = Image.open(io.BytesIO(image)) | ||
else: | ||
# if the image is a list | ||
image = torch.FloatTensor(image) | ||
|
||
# force convert to tensor | ||
# and resize to [img_size, img_size] | ||
image = transform(image) | ||
|
||
images.append(image) | ||
|
||
# convert list of equal-size tensors to single stacked tensor | ||
# has shape BATCH_SIZE x 3 x IMG_SIZE x IMG_SIZE | ||
images_tensor = torch.stack(images).to(self.device) | ||
|
||
return images_tensor | ||
|
||
    def postprocess(self, inference_output):
        # perform NMS (non-maximum suppression) on model outputs
        pred = non_max_suppression(inference_output[0], conf_thres=0.1, iou_thres=0.45)

        # initialize an empty list of detections for each image
        detections = [[] for _ in range(len(pred))]

        for i, image_detections in enumerate(pred):  # axis 0: for each image
            for det in image_detections:  # axis 1: for each detection
                # x1, y1, x2, y2 in normalized image coordinates (i.e. 0.0-1.0)
                xyxy = det[:4] / self.img_size
                # confidence value
                conf = det[4].item()
                # index of the predicted class
                class_idx = int(det[5].item())
                # look up the label of the predicted class;
                # if missing, fall back to the class index
                label = self.mapping.get(str(class_idx), class_idx)

                detections[i].append({
                    "x1": xyxy[0].item(),
                    "y1": xyxy[1].item(),
                    "x2": xyxy[2].item(),
                    "y2": xyxy[3].item(),
                    "confidence": conf,
                    "class": label
                })

        # one list of detection dicts per input image
        return detections


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None,
                        agnostic=False, multi_label=False, labels=(), max_det=1000):
    """Runs Non-Maximum Suppression (NMS) on inference results.

    Returns:
        list of detections, one (n, 6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after (unused in this trimmed copy)
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain, process the next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes: sort by confidence, keep top max_nms
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # class offsets
        # boxes (offset by class so NMS runs per class), scores
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = torchvision.ops.box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]

    return output


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y
```
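To see how the pieces compose, here is a small self-contained run on fabricated predictions (values invented for illustration; assumes `mdv5_handler.py` is importable, which requires the `torchserve` Python package for its `ts` import):

```python
# Run two overlapping fabricated boxes through non_max_suppression, which
# uses xywh2xyxy internally; NMS drops the lower-confidence duplicate.
import torch
from mdv5_handler import non_max_suppression

raw = torch.zeros(1, 2, 8)  # 1 image, 2 candidate boxes, 5 + 3 class columns
raw[0, 0] = torch.tensor([320.0, 320.0, 128.0, 128.0, 0.9, 0.8, 0.1, 0.1])
raw[0, 1] = torch.tensor([330.0, 318.0, 130.0, 126.0, 0.6, 0.7, 0.2, 0.1])

kept = non_max_suppression(raw, conf_thres=0.1, iou_thres=0.45)
print(kept[0])  # one surviving row: [x1, y1, x2, y2, conf, cls] in pixels
```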
Submodule added: api/megadetectorv5/sagemaker-serverless-endpoint-with-torchserve @ 56a48c (+1 line)
New file: serve_megadetect.sh (+11 lines)

```
# https://github.com/pytorch/serve/tree/master/examples/object_detector/fast-rcnn
# https://github.com/pytorch/serve/blob/master/examples/README.md#creating-mar-file-for-torchscript-mode-model
# https://github.com/pytorch/serve
# setup https://github.com/pytorch/serve

#mkdir model_store

# torch-model-archiver --model-name mdv5 --version 1.0.0 --serialized-file ../models/megadetectorv5/md_v5a.0.0.torchscript --extra-files index_to_name.json --handler ../api/megadetectorv5/mdv5_handler.py
# mv mdv5.mar model_store/megadetectorv5-yolov5-1-batch-2048-2048.mar
torchserve --start --model-store /app/model_store --no-config-snapshots --models mdv5=/app/megadetectorv5-yolov5-1-batch-640-640.mar
#curl http://127.0.0.1:8080/predictions/mdv5 -T ../input/sample-img-fox.jpg
```
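After the script starts TorchServe, the management API on port 8081 can confirm the model registered (a sketch, assuming `requests`):

```python
# List registered models; the script above registers the archive as "mdv5".
import requests

print(requests.get("http://127.0.0.1:8081/models").json())
```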
New file: .gitattributes (+2 lines)

```
# this drops notebooks from GitHub language stats
*.ipynb linguist-vendored
```