-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* new tensor APIs * container namedtuples * tflite * simplified APIs stabilization init * test cases fix * todo fixes, LGTM.com fixes, type annotations * more tests * supporting np dtypes
- Loading branch information
Showing
11 changed files
with
170 additions
and
228 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,26 @@ | ||
from redisai import Client, Tensor, \ | ||
BlobTensor, DType, Device, Backend | ||
import numpy as np | ||
from redisai import Client, DType, Device, Backend | ||
import ml2rt | ||
|
||
client = Client() | ||
client.tensorset('x', Tensor(DType.float, [2], [2, 3])) | ||
client.tensorset('x', [2, 3], dtype=DType.float) | ||
t = client.tensorget('x') | ||
print(t.value) | ||
|
||
model = ml2rt.load_model('test/testdata/graph.pb') | ||
client.tensorset('a', Tensor.scalar(DType.float, 2, 3)) | ||
client.tensorset('b', Tensor.scalar(DType.float, 12, 10)) | ||
tensor1 = np.array([2, 3], dtype=np.float) | ||
client.tensorset('a', tensor1) | ||
client.tensorset('b', (12, 10), dtype=np.float) | ||
client.modelset('m', Backend.tf, | ||
Device.cpu, | ||
input=['a', 'b'], | ||
output='mul', | ||
inputs=['a', 'b'], | ||
outputs='mul', | ||
data=model) | ||
client.modelrun('m', ['a', 'b'], ['mul']) | ||
print(client.tensorget('mul').value) | ||
print(client.tensorget('mul')) | ||
|
||
# Try with a script | ||
script = ml2rt.load_script('test/testdata/script.txt') | ||
client.scriptset('ket', Device.cpu, script) | ||
client.scriptrun('ket', 'bar', input=['a', 'b'], output='c') | ||
client.scriptrun('ket', 'bar', inputs=['a', 'b'], outputs='c') | ||
|
||
b1 = client.tensorget('c', as_type=BlobTensor) | ||
b2 = client.tensorget('c', as_type=BlobTensor) | ||
|
||
client.tensorset('d', BlobTensor(DType.float, b1.shape, b1, b2)) | ||
|
||
tnp = b1.to_numpy() | ||
client.tensorset('e', tnp) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
from .version import __version__ | ||
from .client import Client | ||
from .tensor import Tensor, BlobTensor | ||
from .constants import DType, Device, Backend |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ class Backend(Enum): | |
tf = 'TF' | ||
torch = 'TORCH' | ||
onnx = 'ONNX' | ||
tflite = 'TFLITE' | ||
|
||
|
||
class DType(Enum): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from collections import namedtuple | ||
|
||
Tensor = namedtuple('Tensor', field_names=['value', 'shape', 'dtype', 'argname']) | ||
Script = namedtuple('Script', field_names=['script', 'device']) | ||
Model = namedtuple('Model', field_names=['data', 'device', 'backend']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import Union, ByteString, Sequence | ||
from .utils import convert_to_num | ||
from .constants import DType | ||
from .containers import Tensor | ||
try: | ||
import numpy as np | ||
except (ImportError, ModuleNotFoundError): | ||
np = None | ||
|
||
|
||
def from_numpy(tensor: np.ndarray) -> Tensor: | ||
""" Convert the numpy input from user to `Tensor` """ | ||
dtype = DType.__members__[str(tensor.dtype)] | ||
shape = tensor.shape | ||
blob = bytes(tensor.data) | ||
return Tensor(blob, shape, dtype, 'BLOB') | ||
|
||
|
||
def from_sequence(tensor: Sequence, shape: Union[list, tuple], dtype: DType) -> Tensor: | ||
""" Convert the `list`/`tuple` input from user to `Tensor` """ | ||
return Tensor(tensor, shape, dtype, 'VALUES') | ||
|
||
|
||
def to_numpy(value: ByteString, shape: Union[list, tuple], dtype: DType) -> np.ndarray: | ||
""" Convert `BLOB` result from RedisAI to `np.ndarray` """ | ||
dtype = DType.__members__[dtype.lower()].value | ||
mm = { | ||
'FLOAT': 'float32', | ||
'DOUBLE': 'float64' | ||
} | ||
if dtype in mm: | ||
dtype = mm[dtype] | ||
else: | ||
dtype = dtype.lower() | ||
a = np.frombuffer(value, dtype=dtype) | ||
return a.reshape(shape) | ||
|
||
|
||
def to_sequence(value: list, shape: list, dtype: DType) -> Tensor: | ||
""" Convert `VALUES` result from RedisAI to `Tensor` """ | ||
dtype = DType.__members__[dtype.lower()] | ||
convert_to_num(dtype, value) | ||
return Tensor(value, tuple(shape), dtype, 'VALUES') |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
|
||
#!/usr/bin/env python | ||
from setuptools import setup, find_packages | ||
|
||
|
Oops, something went wrong.