Skip to content

Commit

Permalink
String tensor support (#90)
Browse files Browse the repository at this point in the history
* added support for string tensors,
when the data is given by VALUE (and not with BLOB)

* updated test.py
for setting and getting string tensors by VALUE

* added support for tensorset
from numpy string array as blob

* updated test.py
to test tensorset with numpy string array

* linting

* small fix

* Review fixes:

Added a comment.

Deleted numpy_string2blob and replaced with a single line using join.

Deleted utils.recursive_bytetransform_str;
'target' is now set to a decode function instead

Co-authored-by: alonre24 <alonreshef24@gmail.com>
  • Loading branch information
GuyAv46 and alonre24 committed Dec 1, 2021
1 parent 35b1e3f commit 15a67d7
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 10 deletions.
7 changes: 6 additions & 1 deletion redisai/command_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,12 @@ def tensorset(
args = ["AI.TENSORSET", key, dtype, *shape, "BLOB", blob]
elif isinstance(tensor, (list, tuple)):
try:
dtype = utils.dtype_dict[dtype.lower()]
# Numpy 'str' dtypes have many different names (depending on the maximal string length in the tensor and more),
# but they all share the same 'num' attribute. This is a way to check if a dtype is a kind of string.
if np.dtype(dtype).num == np.dtype("str").num:
dtype = utils.dtype_dict["str"]
else:
dtype = utils.dtype_dict[dtype.lower()]
except KeyError:
raise TypeError(
f"``{dtype}`` is not supported by RedisAI. Currently "
Expand Down
6 changes: 5 additions & 1 deletion redisai/postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ def tensorget(res, as_numpy, as_numpy_mutable, meta_only):
mutable=False,
)
else:
target = float if rai_result["dtype"] in ("FLOAT", "DOUBLE") else int
if rai_result["dtype"] == "STRING":
def target(b):
return b.decode()
else:
target = float if rai_result["dtype"] in ("FLOAT", "DOUBLE") else int
utils.recursive_bytetransform(rai_result["values"], target)
return rai_result

Expand Down
13 changes: 10 additions & 3 deletions redisai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"uint32": "UINT32",
"uint64": "UINT64",
"bool": "BOOL",
"str": "STRING",
}

allowed_devices = {"CPU", "GPU"}
Expand All @@ -24,11 +25,15 @@
def numpy2blob(tensor: np.ndarray) -> tuple:
"""Convert the numpy input from user to `Tensor`."""
try:
dtype = dtype_dict[str(tensor.dtype)]
if tensor.dtype.num == np.dtype("str").num:
dtype = dtype_dict["str"]
blob = "".join([string + "\0" for string in tensor.flat])
else:
dtype = dtype_dict[str(tensor.dtype)]
blob = tensor.tobytes()
except KeyError:
raise TypeError(f"RedisAI doesn't support tensors of type {tensor.dtype}")
shape = tensor.shape
blob = bytes(tensor.data)
return dtype, shape, blob


Expand All @@ -38,7 +43,9 @@ def blob2numpy(
"""Convert `BLOB` result from RedisAI to `np.ndarray`."""
mm = {"FLOAT": "float32", "DOUBLE": "float64"}
dtype = mm.get(dtype, dtype.lower())
if mutable:
if dtype == 'string':
a = np.array(value.decode().split('\0')[:-1], dtype='str')
elif mutable:
a = np.fromstring(value, dtype=dtype)
else:
a = np.frombuffer(value, dtype=dtype)
Expand Down
16 changes: 12 additions & 4 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ def test_set_non_numpy_tensor(self):
self.assertEqual([2, 2], result["shape"])
self.assertEqual("BOOL", result["dtype"])

con.tensorset("x", (12, 'a', 'G', 'four'), dtype="str", shape=(2, 2))
result = con.tensorget("x", as_numpy=False)
self.assertEqual(['12', 'a', 'G', 'four'], result["values"])
self.assertEqual([2, 2], result["shape"])
self.assertEqual("STRING", result["dtype"])

with self.assertRaises(TypeError):
con.tensorset("x", (2, 3, 4, 5), dtype="wrongtype", shape=(2, 2))
con.tensorset("x", (2, 3, 4, 5), dtype="int8", shape=(2, 2))
Expand Down Expand Up @@ -156,6 +162,12 @@ def test_numpy_tensor(self):
self.assertEqual(values.dtype, "bool")
self.assertTrue(np.array_equal(values, [True, False]))

input_array = np.array(["a", "bb", "⚓⚓⚓", "d♻d♻"]).reshape((2, 2))
con.tensorset("x", input_array)
values = con.tensorget("x")
self.assertEqual(values.dtype.num, np.dtype("str").num)
self.assertTrue(np.array_equal(values, [['a', 'bb'], ["⚓⚓⚓", "d♻d♻"]]))

input_array = np.array([2, 3])
con.tensorset("x", input_array)
values = con.tensorget("x")
Expand All @@ -174,10 +186,6 @@ def test_numpy_tensor(self):
np.put(ret, 0, 1)
self.assertEqual(ret[0], 1)

stringarr = np.array("dummy")
with self.assertRaises(TypeError):
con.tensorset("trying", stringarr)

# AI.MODELSET is deprecated by AI.MODELSTORE.
def test_deprecated_modelset(self):
model_path = os.path.join(MODEL_DIR, "graph.pb")
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ envlist = linters,tests
max-complexity = 10
ignore = E501,C901
srcdir = ./redisai
exclude =.git,.tox,dist,doc,*/__pycache__/*
exclude =.git,.tox,dist,doc,*/__pycache__/*,venv

[testenv:tests]
whitelist_externals = find
Expand Down

0 comments on commit 15a67d7

Please sign in to comment.