Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove tritonclient dependency #20

Merged
merged 2 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 126 additions & 1 deletion python/kserve/kserve/protocol/infer_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,142 @@

from typing import Optional, List, Dict

import struct
import numpy
import numpy as np
import pandas as pd
from tritonclient.utils import raise_error, serialize_byte_tensor

from ..constants.constants import GRPC_CONTENT_DATATYPE_MAPPINGS
from ..errors import InvalidInput
from ..protocol.grpc.grpc_predict_v2_pb2 import ModelInferRequest, InferTensorContents, ModelInferResponse
from ..utils.numpy_codec import to_np_dtype, from_np_dtype


def raise_error(msg):
    """
    Raise an InferenceServerException carrying the given message.

    The exception is raised with ``from None`` so that no implicit
    exception context is chained onto it.
    """
    error = InferenceServerException(msg=msg)
    raise error from None


def serialize_byte_tensor(input_tensor):
    """
    Serialize a bytes tensor into one length-prefixed byte string.

    Every element is emitted as a 4-byte little-endian unsigned length
    followed by the element's bytes; elements are concatenated in row-major
    order. The input should use dtype ``np.object_``. Arrays of dtype
    ``np.bytes_`` are accepted as well, but numpy removes trailing zero bytes
    from their elements, so that dtype should be avoided.

    Parameters
    ----------
    input_tensor : np.ndarray
        The bytes tensor to serialize.

    Returns
    -------
    serialized_bytes_tensor : np.ndarray
        An ``np.object_`` array wrapping the serialized ``bytes`` object, or
        an empty 1-D ``np.object_`` array when the input has no elements.

    Raises
    ------
    InferenceServerException
        If the tensor dtype is neither ``np.object_`` nor ``np.bytes_``.
    """
    if input_tensor.size == 0:
        return np.empty([0], dtype=np.object_)

    # The dtype does not change while iterating, so decide once up front.
    is_object_tensor = input_tensor.dtype == np.object_
    if not is_object_tensor and input_tensor.dtype.type != np.bytes_:
        raise_error("cannot serialize bytes tensor: invalid datatype")

    chunks = []
    # 'C' order is row-major.
    for obj in np.nditer(input_tensor, flags=["refs_ok"], order="C"):
        item = obj.item()
        # Bytes (including subclasses such as np.bytes_) are passed through
        # untouched: converting them with str() would serialize the textual
        # repr "b'...'" instead of the payload. Anything else stored in an
        # object tensor is stringified and UTF-8 encoded.
        if is_object_tensor and not isinstance(item, bytes):
            item = str(item).encode("utf-8")
        chunks.append(struct.pack("<I", len(item)))
        chunks.append(item)

    # Wrapping a single bytes object yields a 0-d array, which is always
    # C-contiguous, so no extra contiguity fix-up is required.
    return np.asarray(b"".join(chunks), dtype=np.object_)


class InferenceServerException(Exception):
    """Exception indicating non-Success status.

    Parameters
    ----------
    msg : str
        A brief description of error

    status : str
        The error code

    debug_details : str
        The additional details on the error

    """

    def __init__(self, msg, status=None, debug_details=None):
        # Forward the message to Exception so that ``args`` is populated;
        # this keeps repr() informative and lets copy/pickle rebuild the
        # exception from ``args`` (``msg`` is a required argument).
        super().__init__(msg)
        self._msg = msg
        self._status = status
        self._debug_details = debug_details

    def __str__(self):
        # A missing message renders as an empty string rather than "None".
        text = "" if self._msg is None else str(self._msg)
        if self._status is not None:
            # f-string formatting tolerates a non-str status (e.g. an int).
            text = f"[{self._status}] {text}"
        return text

    def message(self):
        """Get the exception message.

        Returns
        -------
        str
            The message associated with this exception, or None if no message.

        """
        return self._msg

    def status(self):
        """Get the status of the exception.

        Returns
        -------
        str
            Returns the status of the exception

        """
        return self._status

    def debug_details(self):
        """Get the detailed information about the exception
        for debugging purposes

        Returns
        -------
        str
            Returns the exception details

        """
        return self._debug_details


class InferInput:
_name: str
_shape: List[int]
Expand Down
1 change: 0 additions & 1 deletion python/kserve/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ prometheus-client = "^0.13.1"
orjson = "^3.8.0"
httpx = "^0.23.0"
timing-asgi = "^0.3.0"
tritonclient = "^2.18.0"
tabulate = "^0.9.0"
pandas = ">=1.3.5"

Expand Down
Loading