Confusing error when passing float64 arrays to some layers #461

Open
lbenc135 opened this issue Jan 20, 2021 · 3 comments
Labels
feat / ux (User experience, error messages etc.), serialization (Saving and loading models)

Comments

@lbenc135

I'm trying to load a model that I previously saved to disk and run predict on it, but it produces the following error:

...
    predicted = self.model.predict(X)[0]
../../venv/lib/python3.8/site-packages/thinc/model.py:312: in predict
    return self._func(self, X, is_train=False)[0]
../../venv/lib/python3.8/site-packages/thinc/layers/chain.py:54: in forward
    Y, inc_layer_grad = layer(X, is_train=is_train)
../../venv/lib/python3.8/site-packages/thinc/model.py:288: in __call__
    return self._func(self, X, is_train=is_train)
../../venv/lib/python3.8/site-packages/thinc/layers/relu.py:44: in forward
    Y = model.ops.affine(X, W, b)
../../venv/lib/python3.8/site-packages/thinc/backends/ops.py:203: in affine
    Y = self.gemm(X, W, trans2=True)
thinc/backends/numpy_ops.pyx:84: in thinc.backends.numpy_ops.NumpyOps.gemm
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

>   ???
E   ValueError: Buffer dtype mismatch, expected 'const double' but got 'float'

blis/py.pyx:64: ValueError

I don't get the error when I do train -> predict, only when I do train -> to_disk -> from_disk -> predict. I've tried with bytes instead of disk, and the same error appears.

Model:

from thinc.api import chain, Relu, Logistic

model = chain(
    Relu(10),
    Relu(1),
    Logistic()
)
model.from_disk('model.bin')

Input: model.ops.asarray(np.array([[0., 0., 0.5]]))
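For reading the error message: in Cython, 'double' is a C double (NumPy float64) and 'float' is a C float (NumPy float32), so "expected 'const double' but got 'float'" suggests the gemm routine was specialised for a float64 buffer and then received a float32 one (presumably the weights). A quick NumPy check of the input as constructed here shows which side is float64. This sketch assumes model.ops.asarray(...) leaves the dtype unchanged when none is passed, which is consistent with what the thread later finds.

```python
import numpy as np

# The input as constructed in the report, before the model.ops.asarray() wrapper.
X = np.array([[0., 0., 0.5]])
print(X.dtype)  # float64: NumPy's default for Python float literals

# Map the C type names from the error message to NumPy dtypes.
print(np.dtype(np.double))  # float64, i.e. "double" in the error
print(np.dtype(np.single))  # float32, i.e. "float" in the error
```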

lbenc135 changed the title from "Errors after serialization (Buffer dtype mismatch)" to "Errors after serialization (Buffer dtype mismatch) (8.0.0rc4)" Jan 20, 2021
@honnibal
Member

That does look suspicious. Presumably somewhere during serialization or deserialization, the weights are coming back with the wrong dtype. You could check with:

for layer in model.walk():
    for name in layer.param_names:
        param = layer.get_param(name)
        print(name, param.dtype)

Would you expect this to reproduce the issue?

from thinc.api import chain, Relu, Logistic

model = chain(Relu(10, 10), Relu(1, 10), Logistic())
model = model.initialize()
b = model.to_bytes()
model = chain(Relu(10, 10), Relu(1, 10), Logistic()).from_bytes(b)
model.predict(model.ops.alloc2f(2, 10))

It worked for me when I tested it just now.

@lbenc135
Author

I figured it out: np.array([[0., 0., 0.5]]) defaults to dtype float64, which produces the error, while model.ops.alloc2f(...) allocates a float32 array, which works as expected.

Basically I forgot to set the dtype explicitly while running prediction.
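A minimal NumPy-only sketch of the caller-side fix, assuming the model's weights are float32 (consistent with the float32 input from alloc2f working above). Either cast the existing array or create it as float32 in the first place:

```python
import numpy as np

# Input built the same way as in the report: float64 by default.
X = np.array([[0., 0., 0.5]])

# Option 1: cast the existing array (astype returns a float32 copy).
X32 = X.astype("float32")

# Option 2: request float32 when the array is created.
X32_direct = np.array([[0., 0., 0.5]], dtype="float32")

print(X32.dtype, X32_direct.dtype)  # float32 float32

# With the model from the report, prediction would then be called as:
# predicted = model.predict(X32)[0]
```

Casting once at the boundary, before predict, keeps every array that reaches the layers in the same precision as the weights.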

Feel free to close the issue, although a more helpful error message would be appreciated.

@honnibal
Member

honnibal commented Jan 20, 2021

Ah, thanks 👍

I do think our error (or possibly our behaviour) here should be fixed. Our general perspective in our libraries is that if the user's input doesn't match some API constraint, they shouldn't get an arbitrary error from deep in the library internals; we should usually try to raise something specific earlier. So I'll retitle this and leave it open.
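One way to apply that principle, sketched here with a hypothetical `check_dtype_match` helper (not part of Thinc's API), is to compare the input's dtype with the weights' dtype before the array reaches the compiled gemm call, and raise an error that names both dtypes and the fix:

```python
import numpy as np

def check_dtype_match(X, W, layer_name="layer"):
    """Hypothetical early check: raise a descriptive error if the input and
    weight dtypes differ, instead of letting the mismatch surface later as a
    low-level 'Buffer dtype mismatch' from compiled code."""
    if X.dtype != W.dtype:
        raise TypeError(
            f"{layer_name}: input array has dtype {X.dtype}, but the layer's "
            f"weights have dtype {W.dtype}. Convert the input first, e.g. "
            f"X.astype({str(W.dtype)!r})."
        )

# Stand-in for a Relu weight matrix with nO=10, nI=3 (illustrative values only).
W = np.zeros((10, 3), dtype="float32")
X = np.array([[0., 0., 0.5]])  # float64, as in the report

try:
    check_dtype_match(X, W, layer_name="relu")
except TypeError as err:
    print(err)
```

Raising rather than silently casting keeps any precision change visible to the caller; the alternative hinted at by "or possibly behaviour" would be to convert the input automatically instead.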

honnibal changed the title from "Errors after serialization (Buffer dtype mismatch) (8.0.0rc4)" to "Confusing error when passing float64 arrays to some layers" Jan 20, 2021
svlandeg added the feat / ux (User experience, error messages etc.) and serialization (Saving and loading models) labels Jan 21, 2021