diff --git a/runtimes/mlflow/mlserver_mlflow/runtime.py b/runtimes/mlflow/mlserver_mlflow/runtime.py
index 896d2bf6c..37bf0a039 100644
--- a/runtimes/mlflow/mlserver_mlflow/runtime.py
+++ b/runtimes/mlflow/mlserver_mlflow/runtime.py
@@ -10,7 +10,8 @@
     CONTENT_TYPE_CSV,
     CONTENT_TYPE_JSON,
     parse_csv_input,
-    infer_and_parse_json_input,
+    _split_data_and_params,
+    infer_and_parse_data,
     predictions_to_json,
 )
 
@@ -124,8 +125,10 @@ async def invocations(
         if mime_type == CONTENT_TYPE_CSV:
             csv_input = StringIO(raw_body)
             data = parse_csv_input(csv_input=csv_input, schema=self._input_schema)
+            inference_params = None
         elif mime_type == CONTENT_TYPE_JSON:
-            data = infer_and_parse_json_input(raw_body, self._input_schema)
+            raw_data, inference_params = _split_data_and_params(raw_body)
+            data = infer_and_parse_data(raw_data, self._input_schema)
         else:
             err_message = (
                 "This predictor only supports the following content types, "
@@ -134,7 +137,7 @@
             raise InferenceError(err_message)
 
         try:
-            raw_predictions = self._model.predict(data)
+            raw_predictions = self._model.predict(data, params=inference_params)
         except MlflowException as e:
             raise InferenceError(e.message)
         except Exception:
diff --git a/runtimes/mlflow/tests/conftest.py b/runtimes/mlflow/tests/conftest.py
index 3c4aa3748..71a713a37 100644
--- a/runtimes/mlflow/tests/conftest.py
+++ b/runtimes/mlflow/tests/conftest.py
@@ -41,9 +41,14 @@ def dataset() -> tuple:
 
 
 @pytest.fixture
-def model_signature(dataset: tuple) -> ModelSignature:
+def default_inference_params() -> dict:
+    return {"foo_param": "foo_value"}
+
+
+@pytest.fixture
+def model_signature(dataset: tuple, default_inference_params: dict) -> ModelSignature:
     X, y = dataset
-    signature = infer_signature(X, y)
+    signature = infer_signature(X, model_output=y, params=default_inference_params)
     return signature
 
 
diff --git a/runtimes/mlflow/tests/test_runtime.py b/runtimes/mlflow/tests/test_runtime.py
index 7090720b5..28f1bb1a2 100644
--- a/runtimes/mlflow/tests/test_runtime.py
+++ b/runtimes/mlflow/tests/test_runtime.py
@@ -1,4 +1,5 @@
 import pytest
+from unittest import mock
 
 import numpy as np
 import pandas as pd
@@ -14,6 +15,7 @@
 )
 from mlflow.pyfunc import PyFuncModel
 from mlflow.models.signature import ModelSignature
+from mlflow.pyfunc.scoring_server import CONTENT_TYPE_CSV, CONTENT_TYPE_JSON
 
 from mlserver_mlflow import MLflowRuntime
 from mlserver_mlflow.codecs import TensorDictCodec
@@ -188,3 +190,70 @@ async def test_metadata(runtime: MLflowRuntime, model_signature: ModelSignature)
 
     assert metadata.parameters is not None
     assert metadata.parameters.content_type == PandasCodec.ContentType
+
+
+@pytest.mark.parametrize(
+    "input, expected",
+    [
+        # works with params:
+        (
+            ['{"instances": [1, 2, 3], "params": {"foo": "bar"}}', CONTENT_TYPE_JSON],
+            {"data": {"foo": [1, 2, 3]}, "params": {"foo": "bar"}},
+        ),
+        (
+            [
+                '{"inputs": [1, 2, 3], "params": {"foo": "bar"}}',
+                CONTENT_TYPE_JSON,
+            ],
+            {"data": {"foo": [1, 2, 3]}, "params": {"foo": "bar"}},
+        ),
+        (
+            [
+                '{"inputs": {"foo": [1, 2, 3]}, "params": {"foo": "bar"}}',
+                CONTENT_TYPE_JSON,
+            ],
+            {"data": {"foo": [1, 2, 3]}, "params": {"foo": "bar"}},
+        ),
+        (
+            [
+                '{"dataframe_split": {'
+                '"columns": ["foo"], '
+                '"data": [1, 2, 3]}, '
+                '"params": {"foo": "bar"}}',
+                CONTENT_TYPE_JSON,
+            ],
+            {"data": {"foo": [1, 2, 3]}, "params": {"foo": "bar"}},
+        ),
+        (
+            [
+                '{"dataframe_records": ['
+                '{"foo": 1}, {"foo": 2}, {"foo": 3}], '
+                '"params": {"foo": "bar"}}',
+                CONTENT_TYPE_JSON,
+            ],
+            {"data": {"foo": [1, 2, 3]}, "params": {"foo": "bar"}},
+        ),
+        (
+            ["foo\n1\n2\n3\n", CONTENT_TYPE_CSV],
+            {"data": {"foo": [1, 2, 3]}, "params": None},
+        ),
+        # works without params:
+        (
+            ['{"instances": [1, 2, 3]}', CONTENT_TYPE_JSON],
+            {"data": {"foo": [1, 2, 3]}, "params": None},
+        ),
+    ],
+)
+async def test_invocation_with_params(
+    runtime: MLflowRuntime,
+    input: list,
+    expected: dict,
+):
+    with mock.patch.object(
+        runtime._model, "predict", return_value=[1, 2, 3]
+    ) as predict_mock:
+        await runtime.invocations(*input)
+        np.testing.assert_array_equal(
+            predict_mock.call_args[0][0].get("foo"), expected["data"]["foo"]
+        )
+        assert predict_mock.call_args.kwargs["params"] == expected["params"]