
Commit

Addressed review comments.
RobertSamoilescu committed May 20, 2024
1 parent 38b9b30 commit 499e693
Showing 7 changed files with 60 additions and 38 deletions.
21 changes: 13 additions & 8 deletions mlserver/batching/hooks.py
@@ -61,20 +61,17 @@ def not_implemented_warning(
     Decorator to lets users know that adaptive batching is not required on
     method `f`.
     """
-    warning_template = (
-        "Adaptive Batching is enabled for model '{model_name}'"
-        " but not supported for inference streaming. "
-        "Falling back to non-batched inference streaming."
+    model = _get_model(f)
+    logger.warning(
+        f"Adaptive Batching is enabled for model '{model.name}'"
+        " but not supported for inference streaming."
+        " Falling back to non-batched inference streaming."
     )
 
     @wraps(f)
     async def _inner_stream(
         payload: AsyncIterator[InferenceRequest],
     ) -> AsyncIterator[InferenceResponse]:
-        model = _get_model(f)
-        logger.warning(
-            warning_template.format(model_name=model.name, f_name=f.__name__)
-        )
         async for response in f(payload):
             yield response

@@ -88,6 +85,14 @@ async def load_batching(model: MLModel) -> MLModel:
     if model.settings.max_batch_time <= 0:
         return model
 
+    if model.settings.max_batch_size > 1 and model.settings.max_batch_time <= 0:
+        logger.warning(
+            "Setting max_batch_time equal to zero will result"
+            " in batching having no effect, if you intend to "
+            "use batching try setting it to a value > 0 for"
+            " batching to take effect"
+        )
+
     batcher = AdaptiveBatcher(model)
     setattr(model, _AdaptiveBatchingAttr, batcher)

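Editor's note: the change above moves the adaptive-batching warning from inside `_inner_stream` to decoration time, so it is logged once when the streaming method is wrapped rather than on every call. Below is a minimal, standalone sketch of that pattern. It is not MLServer's code: the names `warn_once_not_batched`, `_inner_stream` and `predict_stream` are hypothetical, plain `logging` stands in for MLServer's logger, and the `_get_model` lookup is omitted.

import logging
from functools import wraps
from typing import AsyncIterator, Callable

logger = logging.getLogger(__name__)


def warn_once_not_batched(
    f: Callable[[AsyncIterator[str]], AsyncIterator[str]]
) -> Callable[[AsyncIterator[str]], AsyncIterator[str]]:
    # Warn at decoration time: the message appears once per wrapped method,
    # not once per streamed request.
    logger.warning(
        "Adaptive batching is not supported for streaming method '%s';"
        " falling back to non-batched streaming.",
        f.__name__,
    )

    @wraps(f)
    async def _inner_stream(payload: AsyncIterator[str]) -> AsyncIterator[str]:
        # The wrapper forwards the stream untouched.
        async for response in f(payload):
            yield response

    return _inner_stream


@warn_once_not_batched  # the warning fires here, when the function is wrapped
async def predict_stream(tokens: AsyncIterator[str]) -> AsyncIterator[str]:
    async for token in tokens:
        yield token.upper()
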
4 changes: 2 additions & 2 deletions mlserver/grpc/servicers.py
@@ -82,7 +82,7 @@ async def ModelStreamInfer(
             break
 
         use_raw = InferenceServicer._GetReturnRaw(request)
-        payloads = self._PayloadsDecorator(request, requests_stream, context)
+        payloads = self._PayloadsMetadataGenerator(request, requests_stream, context)
 
         async for result in self._data_plane.infer_stream(
             payloads=payloads, name=request.model_name, version=request.model_version
@@ -92,7 +92,7 @@ async def ModelStreamInfer(
 
         self._SetTrailingMetadata(result, context)
 
-    async def _PayloadsDecorator(
+    async def _PayloadsMetadataGenerator(
         self,
         request: pb.ModelInferRequest,
         requests_stream: AsyncIterator[pb.ModelInferRequest],
39 changes: 25 additions & 14 deletions mlserver/handlers/dataplane.py
@@ -90,7 +90,7 @@ async def infer(
         # need to cache the payload here since it
         # will be modified in the context manager
         if self._response_cache is not None:
-            cache_key = payload.json()
+            cache_key = payload.model_dump_json()
 
         async with self._infer_contextmanager(name, version) as model:
             payload = self._prepare_payload(payload, model)
@@ -101,11 +101,13 @@
             ):
                 cache_value = await self._response_cache.lookup(cache_key)
                 if cache_value != "":
-                    prediction = InferenceResponse.parse_raw(cache_value)
+                    prediction = InferenceResponse.model_validate_json(cache_value)
                 else:
                     prediction = await model.predict(payload)
                     # ignore cache insertion error if any
-                    await self._response_cache.insert(cache_key, prediction.json())
+                    await self._response_cache.insert(
+                        cache_key, prediction.model_dump_json()
+                    )
             else:
                 prediction = await model.predict(payload)

@@ -121,13 +123,15 @@ async def infer_stream(
         version: Optional[str] = None,
     ) -> AsyncIterator[InferenceResponse]:
         # TODO: Implement cache for stream
 
         async with self._infer_contextmanager(name, version) as model:
             # we need to get the first payload to get the ID
             async for payload in payloads:
                 break
 
-            payload = self._prepare_payload(payload, model)
-            payloads_decorated = self._payloads_decorator(payload, payloads, model)
+            payloads_decorated = self._prepare_payloads_generator(
+                payload, payloads, model
+            )
             async for prediction in model.predict_stream(payloads_decorated):
                 prediction.id = payload.id  # Ensure ID matches
                 self._inference_middleware.response_middleware(
@@ -144,17 +148,24 @@ def _prepare_payload(
         self._inference_middleware.request_middleware(payload, model.settings)
         return payload
 
-    async def _payloads_decorator(
+    async def _prepare_payloads_generator(
         self,
-        payload: InferenceRequest,
-        payloads: AsyncIterator[InferenceRequest],
+        first_payload: InferenceRequest,
+        subsequent_payloads: AsyncIterator[InferenceRequest],
         model: MLModel,
     ) -> AsyncIterator[InferenceRequest]:
-
-        payload = self._prepare_payload(payload, model)
-        yield payload
-
-        async for payload in payloads:
+        # yield the first payload after preparing it
+        first_payload = self._prepare_payload(first_payload, model)
+        yield first_payload
+
+        # Yield the rest of the payloads after preparing them
+        # and set the ID to match the first payload. Note that
+        # we don't make any assumptions about how many inputs and
+        # outputs there are. Thus, everything gets the same ID, cause
+        # otherwise we could have one to many, many to one, or many to
+        # many id mappings.
+        async for payload in subsequent_payloads:
+            payload.id = first_payload.id
             payload = self._prepare_payload(payload, model)
             yield payload

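Editor's note: the renamed `_prepare_payloads_generator` yields the first payload and then stamps every subsequent payload with the first payload's ID, so all chunks of a stream share one correlation ID. Below is a standalone sketch of that ID-propagation pattern only. It uses a hypothetical `Payload` dataclass instead of MLServer's `InferenceRequest` and leaves out the request-middleware preparation step.

import asyncio
from dataclasses import dataclass
from typing import AsyncIterator


@dataclass
class Payload:
    id: str
    data: str


async def prepare_payloads(
    first: Payload, rest: AsyncIterator[Payload]
) -> AsyncIterator[Payload]:
    # The first payload passes through as-is; its ID becomes the stream's ID.
    yield first
    # Later payloads are stamped with the first payload's ID, so no assumption
    # is needed about how request chunks map onto response chunks.
    async for payload in rest:
        payload.id = first.id
        yield payload


async def main() -> None:
    async def rest() -> AsyncIterator[Payload]:
        yield Payload(id="b", data="second")
        yield Payload(id="c", data="third")

    async for p in prepare_payloads(Payload(id="a", data="first"), rest()):
        print(p.id, p.data)  # every payload prints ID "a"


asyncio.run(main())
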
5 changes: 2 additions & 3 deletions mlserver/rest/endpoints.py
@@ -111,7 +111,6 @@ async def infer(
     async def infer_stream(
         self,
         raw_request: Request,
-        raw_response: Response,
         payload: InferenceRequest,
         model_name: str,
         model_version: Optional[str] = None,
@@ -120,12 +119,12 @@ async def infer_stream(
         request_headers = dict(raw_request.headers)
         insert_headers(payload, request_headers)
 
-        async def payloads_async_iter(
+        async def payloads_generator(
            payload: InferenceRequest,
        ) -> AsyncIterator[InferenceRequest]:
            yield payload
 
-        payloads = payloads_async_iter(payload)
+        payloads = payloads_generator(payload)
         infer_stream = self._data_plane.infer_stream(
             payloads, model_name, model_version
         )
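Editor's note: on the REST side, `infer_stream` receives a single JSON body, so the handler wraps it in a one-element async generator before handing it to the data plane's streaming interface. A rough sketch of that adapter pattern, using generic stand-ins (`as_stream`, `fake_infer_stream`) rather than MLServer's API:

import asyncio
from typing import AsyncIterator, TypeVar

T = TypeVar("T")


async def as_stream(item: T) -> AsyncIterator[T]:
    # Adapt a single request into the async-iterator shape that a
    # streaming inference entrypoint expects.
    yield item


async def fake_infer_stream(payloads: AsyncIterator[dict]) -> AsyncIterator[dict]:
    # Stand-in for the streaming data plane: emits one chunk per payload.
    async for payload in payloads:
        yield {"echo": payload}


async def main() -> None:
    async for chunk in fake_infer_stream(as_stream({"inputs": [1, 2, 3]})):
        print(chunk)


asyncio.run(main())
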
2 changes: 1 addition & 1 deletion mlserver/rest/responses.py
@@ -44,7 +44,7 @@ def __init__(self, data: BaseModel, *args, **kwargs):
         self.data = data
 
     def encode(self) -> bytes:
-        as_dict = self.data.dict()
+        as_dict = self.data.model_dump()
         return self._pre + _render(as_dict) + self._sep

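Editor's note: `encode()` above serialises the Pydantic model with `model_dump()` and frames it as a Server-Sent Event, where `_pre` and `_sep` are presumably the `data: ` prefix and the blank-line separator. A rough illustration of the framing only, using plain `json` and a hypothetical `encode_sse` helper rather than MLServer's `_render`; the resulting format is what `test_sse_encode` at the bottom of this diff asserts against.

import json


def encode_sse(payload: dict) -> bytes:
    # SSE event framing: "data: <json>\n\n" terminates a single event.
    body = json.dumps(payload, separators=(",", ":"))
    return b"data: " + body.encode("utf-8") + b"\n\n"


print(encode_sse({"id": "req-0", "outputs": []}))
# b'data: {"id":"req-0","outputs":[]}\n\n'
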
25 changes: 16 additions & 9 deletions tests/rest/test_endpoints.py
@@ -151,31 +151,37 @@ async def test_generate(
     if model_version is not None:
         endpoint = f"/v2/models/{model_name}/versions/{model_version}/generate"
 
-    response = await rest_client.post(endpoint, json=generate_request.dict())
+    response = await rest_client.post(endpoint, json=generate_request.model_dump())
     assert response.status_code == 200
 
-    prediction = InferenceResponse.parse_obj(response.json())
+    prediction = InferenceResponse.model_validate(response.json())
     assert len(prediction.outputs) == 1
-    assert prediction.outputs[0].data.__root__ == ["What is the capital of France?"]
+    assert prediction.outputs[0].data == TensorData(
+        root=["What is the capital of France?"]
+    )
 
 
 @pytest.mark.parametrize("settings", [lazy_fixture("settings_stream")])
 @pytest.mark.parametrize("sum_model", [lazy_fixture("text_stream_model")])
+@pytest.mark.parametrize("endpoint", ["generate_stream", "infer_stream"])
 async def test_generate_stream(
     rest_client: AsyncClient,
     generate_request: InferenceRequest,
     text_stream_model: MLModel,
+    endpoint: str,
 ):
-    endpoint = f"/v2/models/{text_stream_model.name}/generate_stream"
-    conn = aconnect_sse(rest_client, "POST", endpoint, json=generate_request.dict())
+    endpoint = f"/v2/models/{text_stream_model.name}/{endpoint}"
+    conn = aconnect_sse(
+        rest_client, "POST", endpoint, json=generate_request.model_dump()
+    )
     ref_text = ["What", " is", " the", " capital", " of", " France?"]
 
     async with conn as stream:
         i = 0
         async for response in stream.aiter_sse():
-            prediction = InferenceResponse.parse_obj(response.json())
+            prediction = InferenceResponse.model_validate(response.json())
             assert len(prediction.outputs) == 1
-            assert prediction.outputs[0].data.__root__ == [ref_text[i]]
+            assert prediction.outputs[0].data == TensorData(root=[ref_text[i]])
             i += 1
 
 
@@ -200,10 +206,11 @@ async def test_infer_headers(
     )
 
 
+@pytest.mark.parametrize("endpoint", ["infer", "generate"])
 async def test_infer_error(
-    rest_client: AsyncClient, inference_request: InferenceRequest
+    rest_client: AsyncClient, inference_request: InferenceRequest, endpoint: str
 ):
-    endpoint = "/v2/models/my-model/versions/v0/infer"
+    endpoint = f"/v2/models/my-model/versions/v0/{endpoint}"
     response = await rest_client.post(endpoint, json=inference_request.model_dump())
 
     assert response.status_code == 404
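Editor's note: most of the test updates in this commit are mechanical Pydantic v1 to v2 renames: `.dict()` to `.model_dump()`, `.json()` to `.model_dump_json()`, `parse_obj` to `model_validate`, `parse_raw` to `model_validate_json`, and `__root__` models becoming `RootModel` with a `root` field. A small round-trip sketch against plain Pydantic v2 (the classes `TensorDataLike` and `ResponseLike` are illustrative stand-ins, not MLServer's own types):

from typing import List

from pydantic import BaseModel, RootModel


class TensorDataLike(RootModel[List[str]]):
    # Stand-in for a __root__-style model ported to Pydantic v2's RootModel.
    pass


class ResponseLike(BaseModel):
    id: str
    data: TensorDataLike


resp = ResponseLike(id="req-0", data=TensorDataLike(root=["hello"]))

as_dict = resp.model_dump()                              # v1: resp.dict()
as_json = resp.model_dump_json()                         # v1: resp.json()
again = ResponseLike.model_validate(as_dict)             # v1: parse_obj
again_json = ResponseLike.model_validate_json(as_json)   # v1: parse_raw

assert again == resp and again_json == resp
assert again.data.root == ["hello"]                      # v1: data.__root__
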
2 changes: 1 addition & 1 deletion tests/rest/test_responses.py
@@ -8,6 +8,6 @@ def test_sse_encode(inference_request: InferenceRequest):
     encoded = sse.encode()
     as_string = encoded.decode("utf-8")
 
-    expected_json = inference_request.json().replace(" ", "")
+    expected_json = inference_request.model_dump_json().replace(" ", "")
     expected = f"data: {expected_json}\n\n"
     assert as_string == expected