From 15a67d7456e2e3d5b0f47cb75f7bc7a948d00413 Mon Sep 17 00:00:00 2001 From: GuyAv46 <47632673+GuyAv46@users.noreply.github.com> Date: Wed, 1 Dec 2021 15:38:09 +0200 Subject: [PATCH] String tensor support (#90) * added support for string tensors, when the data is given by VALUE (and not with BLOB) * updated test.py for setting and getting string tensors by VALUE * added support for tensorset from numpy string array as blob * updated test.py to test tensorset with numpy string array * linting * small fix * Review fixes: Added a comment. Deleted numpy_string2blob and replaced with a single line using join. Deleted utils.recursive_bytetransform_str and sets 'target' to a decode function Co-authored-by: alonre24 --- redisai/command_builder.py | 7 ++++++- redisai/postprocessor.py | 6 +++++- redisai/utils.py | 13 ++++++++++--- test/test.py | 16 ++++++++++++---- tox.ini | 2 +- 5 files changed, 34 insertions(+), 10 deletions(-) diff --git a/redisai/command_builder.py b/redisai/command_builder.py index 5b72dec..d97e910 100644 --- a/redisai/command_builder.py +++ b/redisai/command_builder.py @@ -176,7 +176,12 @@ def tensorset( args = ["AI.TENSORSET", key, dtype, *shape, "BLOB", blob] elif isinstance(tensor, (list, tuple)): try: - dtype = utils.dtype_dict[dtype.lower()] + # Numpy 'str' dtype has many different names regarding maximal length in the tensor and more, + # but they all share the 'num' attribute. This is a way to check if a dtype is a kind of string. + if np.dtype(dtype).num == np.dtype("str").num: + dtype = utils.dtype_dict["str"] + else: + dtype = utils.dtype_dict[dtype.lower()] except KeyError: raise TypeError( f"``{dtype}`` is not supported by RedisAI. 
Currently " diff --git a/redisai/postprocessor.py b/redisai/postprocessor.py index ae93fab..23318d8 100644 --- a/redisai/postprocessor.py +++ b/redisai/postprocessor.py @@ -42,7 +42,11 @@ def tensorget(res, as_numpy, as_numpy_mutable, meta_only): mutable=False, ) else: - target = float if rai_result["dtype"] in ("FLOAT", "DOUBLE") else int + if rai_result["dtype"] == "STRING": + def target(b): + return b.decode() + else: + target = float if rai_result["dtype"] in ("FLOAT", "DOUBLE") else int utils.recursive_bytetransform(rai_result["values"], target) return rai_result diff --git a/redisai/utils.py b/redisai/utils.py index 6fc4347..c0720a3 100644 --- a/redisai/utils.py +++ b/redisai/utils.py @@ -15,6 +15,7 @@ "uint32": "UINT32", "uint64": "UINT64", "bool": "BOOL", + "str": "STRING", } allowed_devices = {"CPU", "GPU"} @@ -24,11 +25,15 @@ def numpy2blob(tensor: np.ndarray) -> tuple: """Convert the numpy input from user to `Tensor`.""" try: - dtype = dtype_dict[str(tensor.dtype)] + if tensor.dtype.num == np.dtype("str").num: + dtype = dtype_dict["str"] + blob = "".join([string + "\0" for string in tensor.flat]) + else: + dtype = dtype_dict[str(tensor.dtype)] + blob = tensor.tobytes() except KeyError: raise TypeError(f"RedisAI doesn't support tensors of type {tensor.dtype}") shape = tensor.shape - blob = bytes(tensor.data) return dtype, shape, blob @@ -38,7 +43,9 @@ def blob2numpy( """Convert `BLOB` result from RedisAI to `np.ndarray`.""" mm = {"FLOAT": "float32", "DOUBLE": "float64"} dtype = mm.get(dtype, dtype.lower()) - if mutable: + if dtype == 'string': + a = np.array(value.decode().split('\0')[:-1], dtype='str') + elif mutable: a = np.fromstring(value, dtype=dtype) else: a = np.frombuffer(value, dtype=dtype) diff --git a/test/test.py b/test/test.py index da9d1e3..86c5b9e 100644 --- a/test/test.py +++ b/test/test.py @@ -117,6 +117,12 @@ def test_set_non_numpy_tensor(self): self.assertEqual([2, 2], result["shape"]) self.assertEqual("BOOL", result["dtype"]) + 
con.tensorset("x", (12, 'a', 'G', 'four'), dtype="str", shape=(2, 2)) + result = con.tensorget("x", as_numpy=False) + self.assertEqual(['12', 'a', 'G', 'four'], result["values"]) + self.assertEqual([2, 2], result["shape"]) + self.assertEqual("STRING", result["dtype"]) + with self.assertRaises(TypeError): con.tensorset("x", (2, 3, 4, 5), dtype="wrongtype", shape=(2, 2)) con.tensorset("x", (2, 3, 4, 5), dtype="int8", shape=(2, 2)) @@ -156,6 +162,12 @@ def test_numpy_tensor(self): self.assertEqual(values.dtype, "bool") self.assertTrue(np.array_equal(values, [True, False])) + input_array = np.array(["a", "bb", "⚓⚓⚓", "d♻d♻"]).reshape((2, 2)) + con.tensorset("x", input_array) + values = con.tensorget("x") + self.assertEqual(values.dtype.num, np.dtype("str").num) + self.assertTrue(np.array_equal(values, [['a', 'bb'], ["⚓⚓⚓", "d♻d♻"]])) + input_array = np.array([2, 3]) con.tensorset("x", input_array) values = con.tensorget("x") @@ -174,10 +186,6 @@ def test_numpy_tensor(self): np.put(ret, 0, 1) self.assertEqual(ret[0], 1) - stringarr = np.array("dummy") - with self.assertRaises(TypeError): - con.tensorset("trying", stringarr) - # AI.MODELSET is deprecated by AI.MODELSTORE. def test_deprecated_modelset(self): model_path = os.path.join(MODEL_DIR, "graph.pb") diff --git a/tox.ini b/tox.ini index 01a4666..69a64d7 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ envlist = linters,tests max-complexity = 10 ignore = E501,C901 srcdir = ./redisai -exclude =.git,.tox,dist,doc,*/__pycache__/* +exclude =.git,.tox,dist,doc,*/__pycache__/*,venv [testenv:tests] whitelist_externals = find