diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7361d6d..82e77c2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,7 +42,7 @@ jobs: - name: Install dependencies run: | pip install -e . - pip install pytest==6.1.1 torchvision scikit-learn gorilla transformers torchtext matplotlib captum + pip install pytest==7.2.1 torchvision scikit-learn gorilla transformers torchtext matplotlib captum - name: Install pytorch lightning run: | diff --git a/mlflow_torchserve/__init__.py b/mlflow_torchserve/__init__.py index 81894c4..7a5491f 100644 --- a/mlflow_torchserve/__init__.py +++ b/mlflow_torchserve/__init__.py @@ -22,9 +22,9 @@ _DEFAULT_TORCHSERVE_LOCAL_MANAGEMENT_PORT = "8081" -class CustomPredictionsResponse(PredictionsResponse): +class TorchServePredictionsResponse(PredictionsResponse): def __init__(self, resp): - super(CustomPredictionsResponse, self).__init__(self) + super(TorchServePredictionsResponse, self).__init__(self) self.resp = resp def to_json(self, path=None): @@ -299,7 +299,7 @@ def predict(self, deployment_name, df): raise ValueError("Unable to parse input json string: {}".format(e)) resp = requests.post(url, data) - cust_resp = CustomPredictionsResponse(resp.text) + cust_resp = TorchServePredictionsResponse(resp.text) if resp.status_code != 200: raise Exception( "Unable to infer the results for the name %s. "