Skip to content

Commit

Permalink
Add inference params support to MLflow's custom invocation endpoint (#1375)

Browse files Browse the repository at this point in the history

Co-authored-by: Manuel Laventure <[email protected]>
  • Loading branch information
2 people authored and Adrian Gonzalez-Martin committed Sep 7, 2023
1 parent 5aaf329 commit 3cabf33
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 5 deletions.
9 changes: 6 additions & 3 deletions runtimes/mlflow/mlserver_mlflow/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
CONTENT_TYPE_CSV,
CONTENT_TYPE_JSON,
parse_csv_input,
infer_and_parse_json_input,
_split_data_and_params,
infer_and_parse_data,
predictions_to_json,
)

Expand Down Expand Up @@ -124,8 +125,10 @@ async def invocations(
if mime_type == CONTENT_TYPE_CSV:
csv_input = StringIO(raw_body)
data = parse_csv_input(csv_input=csv_input, schema=self._input_schema)
inference_params = None
elif mime_type == CONTENT_TYPE_JSON:
data = infer_and_parse_json_input(raw_body, self._input_schema)
raw_data, inference_params = _split_data_and_params(raw_body)
data = infer_and_parse_data(raw_data, self._input_schema)
else:
err_message = (
"This predictor only supports the following content types, "
Expand All @@ -134,7 +137,7 @@ async def invocations(
raise InferenceError(err_message)

try:
raw_predictions = self._model.predict(data)
raw_predictions = self._model.predict(data, params=inference_params)
except MlflowException as e:
raise InferenceError(e.message)
except Exception:
Expand Down
9 changes: 7 additions & 2 deletions runtimes/mlflow/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,14 @@ def dataset() -> tuple:


@pytest.fixture
def default_inference_params() -> dict:
    """Inference params baked into the test model signature."""
    return dict(foo_param="foo_value")


@pytest.fixture
def model_signature(dataset: tuple, default_inference_params: dict) -> ModelSignature:
    """Signature inferred from the test dataset, carrying default inference params."""
    X, y = dataset
    return infer_signature(X, model_output=y, params=default_inference_params)

Expand Down
69 changes: 69 additions & 0 deletions runtimes/mlflow/tests/test_runtime.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from unittest import mock
import numpy as np
import pandas as pd

Expand All @@ -14,6 +15,7 @@
)
from mlflow.pyfunc import PyFuncModel
from mlflow.models.signature import ModelSignature
from mlflow.pyfunc.scoring_server import CONTENT_TYPE_CSV, CONTENT_TYPE_JSON

from mlserver_mlflow import MLflowRuntime
from mlserver_mlflow.codecs import TensorDictCodec
Expand Down Expand Up @@ -188,3 +190,70 @@ async def test_metadata(runtime: MLflowRuntime, model_signature: ModelSignature)

assert metadata.parameters is not None
assert metadata.parameters.content_type == PandasCodec.ContentType


@pytest.mark.parametrize(
    "input, expected",
    [
        # works with params:
        (
            ['{"instances": [1, 2, 3], "params": {"foo": "bar"}}', CONTENT_TYPE_JSON],
            {"data": {"foo": [1, 2, 3]}, "params": {"foo": "bar"}},
        ),
        (
            ['{"inputs": [1, 2, 3], "params": {"foo": "bar"}}', CONTENT_TYPE_JSON],
            {"data": {"foo": [1, 2, 3]}, "params": {"foo": "bar"}},
        ),
        (
            [
                '{"inputs": {"foo": [1, 2, 3]}, "params": {"foo": "bar"}}',
                CONTENT_TYPE_JSON,
            ],
            {"data": {"foo": [1, 2, 3]}, "params": {"foo": "bar"}},
        ),
        (
            [
                '{"dataframe_split": {"columns": ["foo"], "data": [1, 2, 3]}, "params": {"foo": "bar"}}',
                CONTENT_TYPE_JSON,
            ],
            {"data": {"foo": [1, 2, 3]}, "params": {"foo": "bar"}},
        ),
        (
            [
                '{"dataframe_records": [{"foo": 1}, {"foo": 2}, {"foo": 3}], "params": {"foo": "bar"}}',
                CONTENT_TYPE_JSON,
            ],
            {"data": {"foo": [1, 2, 3]}, "params": {"foo": "bar"}},
        ),
        # CSV input carries no params:
        (
            ["foo\n1\n2\n3\n", CONTENT_TYPE_CSV],
            {"data": {"foo": [1, 2, 3]}, "params": None},
        ),
        # works without params:
        (
            ['{"instances": [1, 2, 3]}', CONTENT_TYPE_JSON],
            {"data": {"foo": [1, 2, 3]}, "params": None},
        ),
    ],
)
async def test_invocation_with_params(
    runtime: MLflowRuntime,
    input: list,
    expected: dict,
):
    """The invocations endpoint forwards parsed data and inference params to predict()."""
    patcher = mock.patch.object(runtime._model, "predict", return_value=[1, 2, 3])
    with patcher as predict_mock:
        await runtime.invocations(*input)
        # First positional arg to predict() is the parsed payload.
        called_data = predict_mock.call_args[0][0]
        np.testing.assert_array_equal(
            called_data.get("foo"), expected["data"]["foo"]
        )
        # Inference params are always passed as a keyword argument.
        assert predict_mock.call_args.kwargs["params"] == expected["params"]

0 comments on commit 3cabf33

Please sign in to comment.