Skip to content

Commit

Permalink
Copt serialize_byte_tensor functon
Browse files Browse the repository at this point in the history
  • Loading branch information
gibchikafa committed Jan 31, 2024
1 parent bb2b05f commit 999c2af
Showing 1 changed file with 46 additions and 17 deletions.
63 changes: 46 additions & 17 deletions python/kserve/kserve/protocol/infer_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import Optional, List, Dict

import struct
import numpy
import numpy as np
import pandas as pd
Expand All @@ -31,32 +32,60 @@ def raise_error(msg):
raise InferenceServerException(msg=msg) from None


def serialized_byte_size(tensor_value):
def serialize_byte_tensor(input_tensor):
"""
Get the underlying number of bytes for a numpy ndarray.
Serializes a bytes tensor into a flat numpy array of length prepended
bytes. The numpy array should use dtype of np.object. For np.bytes,
numpy will remove trailing zeros at the end of byte sequence and because
of this it should be avoided.
Parameters
----------
tensor_value : numpy.ndarray
Numpy array to calculate the number of bytes for.
input_tensor : np.array
The bytes tensor to serialize.
Returns
-------
int
Number of bytes present in this tensor
"""
serialized_bytes_tensor : np.array
The 1-D numpy array of type uint8 containing the serialized bytes in row-major form.
if tensor_value.dtype != np.object_:
raise_error("The tensor_value dtype must be np.object_")
Raises
------
InferenceServerException
If unable to serialize the given tensor.
"""

if tensor_value.size > 0:
total_bytes = 0
# 'C' order is row-major.
for obj in np.nditer(tensor_value, flags=["refs_ok"], order="C"):
total_bytes += len(obj.item())
return total_bytes
else:
return 0
if input_tensor.size == 0:
return np.empty([0], dtype=np.object_)

# If the input is a tensor of string/bytes objects, then must flatten those into
# a 1-dimensional array containing the 4-byte byte size followed by the
# actual element bytes. All elements are concatenated together in row-major
# order.

if (input_tensor.dtype != np.object_) and (input_tensor.dtype.type != np.bytes_):
raise_error("cannot serialize bytes tensor: invalid datatype")

flattened_ls = []
# 'C' order is row-major.
for obj in np.nditer(input_tensor, flags=["refs_ok"], order="C"):
# If directly passing bytes to BYTES type,
# don't convert it to str as Python will encode the
# bytes which may distort the meaning
if input_tensor.dtype == np.object_:
if type(obj.item()) == bytes:
s = obj.item()
else:
s = str(obj.item()).encode("utf-8")
else:
s = obj.item()
flattened_ls.append(struct.pack("<I", len(s)))
flattened_ls.append(s)
flattened = b"".join(flattened_ls)
flattened_array = np.asarray(flattened, dtype=np.object_)
if not flattened_array.flags["C_CONTIGUOUS"]:
flattened_array = np.ascontiguousarray(flattened_array, dtype=np.object_)
return flattened_array


class InferenceServerException(Exception):
Expand Down

0 comments on commit 999c2af

Please sign in to comment.