[Question] Using model.mar with built-in handler script #133

Open
austinmw opened this issue Nov 22, 2022 · 0 comments
austinmw commented Nov 22, 2022

What did you find confusing? Please describe.
Hi, I recently used the TorchServe export utility provided by the MMDetection library. It uses torch-model-archiver to package the following files into a model.mar:

Archive:  model.mar
  bbox_mAP_epoch_1.pth
  config.py
  mmdet_handler.py
  MAR-INF/MANIFEST.json

Is it possible to use this model.mar directly with SageMaker (SM) TorchServe, without having to extract the handler script and reformat it?
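For anyone wanting to check what the archive declares before deploying it, here is a minimal sketch, assuming the .mar was built with the archiver's default format (a plain zip file). The helper name `describe_mar` is hypothetical, and the `handler` key in the manifest is an assumption about what the archiver writes; only `model.serializedFile` is confirmed by the handler code in this issue.

```python
import json
import zipfile


def describe_mar(mar_path):
    """Return the file names in a .mar and the 'model' section of its manifest.

    Assumes the archive is a zip file containing MAR-INF/MANIFEST.json,
    as in the listing above. `describe_mar` is an illustrative helper,
    not part of TorchServe.
    """
    with zipfile.ZipFile(mar_path) as archive:
        names = archive.namelist()
        manifest = json.loads(archive.read("MAR-INF/MANIFEST.json"))
    # The handler script reads manifest['model']['serializedFile'];
    # other keys (such as 'handler') are assumed and read defensively.
    return names, manifest.get("model", {})
```

Calling `describe_mar("model.mar")` and printing `model.get("handler")` and `model.get("serializedFile")` should show which handler and checkpoint the archive expects to load.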

Additional context
I've added mmdet to my requirements.txt, so the dependencies are not a problem.

The mmdet_handler.py script looks like this:

# Copyright (c) OpenMMLab. All rights reserved.
import base64
import os

import mmcv
import numpy as np
import torch
from ts.torch_handler.base_handler import BaseHandler

from mmdet.apis import inference_detector, init_detector
from mmdet.utils import register_all_modules

register_all_modules(True)


class MMdetHandler(BaseHandler):
    threshold = 0.5

    def initialize(self, context):
        properties = context.system_properties
        self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(self.map_location + ':' +
                                   str(properties.get('gpu_id')) if torch.cuda.
                                   is_available() else self.map_location)
        self.manifest = context.manifest

        model_dir = properties.get('model_dir')
        serialized_file = self.manifest['model']['serializedFile']
        checkpoint = os.path.join(model_dir, serialized_file)
        self.config_file = os.path.join(model_dir, 'config.py')

        self.model = init_detector(self.config_file, checkpoint, self.device)
        self.initialized = True

    def preprocess(self, data):
        images = []

        for row in data:
            image = row.get('data') or row.get('body')
            if isinstance(image, str):
                image = base64.b64decode(image)
            image = mmcv.imfrombytes(image)
            images.append(image)

        return images

    def inference(self, data, *args, **kwargs):
        results = inference_detector(self.model, data)
        return results

    def postprocess(self, data):
        # Format output following the example ObjectDetectionHandler format
        output = []
        for data_sample in data:
            pred_instances = data_sample.pred_instances
            bboxes = pred_instances.bboxes.cpu().numpy().astype(
                np.float32).tolist()
            labels = pred_instances.labels.cpu().numpy().astype(
                np.int32).tolist()
            scores = pred_instances.scores.cpu().numpy().astype(
                np.float32).tolist()
            preds = []
            for idx in range(len(labels)):
                cls_score, bbox, cls_label = scores[idx], bboxes[idx], labels[
                    idx]
                if cls_score >= self.threshold:
                    class_name = self.model.dataset_meta['CLASSES'][cls_label]
                    result = dict(
                        class_label=cls_label,
                        class_name=class_name,
                        bbox=bbox,
                        score=cls_score)
                    preds.append(result)
            output.append(preds)
        return output
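
To make the response shape concrete, the thresholding done in `postprocess` can be reproduced without mmdet or torch. This is only a sketch of the same filtering logic on plain Python lists; `filter_predictions` and the sample values are hypothetical and not part of the handler.

```python
def filter_predictions(scores, bboxes, labels, class_names, threshold=0.5):
    """Keep detections whose score is at least `threshold`.

    Mirrors the per-image loop in MMdetHandler.postprocess, but takes
    plain lists instead of a data sample's pred_instances tensors.
    """
    preds = []
    for score, bbox, label in zip(scores, bboxes, labels):
        if score >= threshold:
            preds.append(
                dict(
                    class_label=label,
                    class_name=class_names[label],
                    bbox=bbox,
                    score=score,
                )
            )
    return preds


# Made-up values for one image: two candidate boxes, one below the threshold.
example = filter_predictions(
    scores=[0.9, 0.3],
    bboxes=[[10.0, 20.0, 110.0, 220.0], [0.0, 0.0, 5.0, 5.0]],
    labels=[1, 0],
    class_names=("cat", "dog"),
)
print(example)
```

The handler returns one such list per input image, so a batch of N images yields a list of N lists.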