From 2ee398750845e784814368197ac1aa6ff8358756 Mon Sep 17 00:00:00 2001 From: Aiden010200 <150222139+Aiden010200@users.noreply.github.com> Date: Fri, 27 Dec 2024 15:25:26 +0800 Subject: [PATCH] Upload a ResNet predictor example This example uses aiplatform and torch library to provide a ResNet predictor. --- .../torch/predictor_resnet.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 community-content/vertex_cpr_samples/torch/predictor_resnet.py diff --git a/community-content/vertex_cpr_samples/torch/predictor_resnet.py b/community-content/vertex_cpr_samples/torch/predictor_resnet.py new file mode 100644 index 000000000..282566ac8 --- /dev/null +++ b/community-content/vertex_cpr_samples/torch/predictor_resnet.py @@ -0,0 +1,34 @@ +import os +import torch + +from google.cloud.aiplatform.utils import prediction_utils +from google.cloud.aiplatform.prediction.predictor import Predictor +from torchvision.models import detection, resnet50, ResNet50_Weights +from typing import Dict, List + +class ResNetPredictor(Predictor): + + def __init__(self): + return + + def load(self, artifacts_uri: str) -> None: + prediction_utils.download_model_artifacts(artifacts_uri) + if os.path.exists("model.pth.tar"): + self.model = detection.fasterrcnn_resnet50_fpn(pretrained=True) + stat_dic = torch.load("model.pth.tar") + self.model.load_state_dict(stat_dic['state_dict']) + else: + weights = ResNet50_Weights.DEFAULT + self.model = resnet50(weights=weights) + self.model.eval() + + def preprocess(self, prediction_input: dict) -> torch.Tensor: + instances = prediction_input["instances"] + return torch.Tensor(instances) + + @torch.inference_mode() + def predict(self, instances: torch.Tensor) -> List[str]: + return self._model(instances) + + def postprocess(self, prediction_results: List[str]) -> Dict: + return {"predictions": prediction_results}