On JAX, Keras replaces any exception inside call method of keras.Model subclass with misleading error #19675

Closed
burnpanck opened this issue May 6, 2024 · 2 comments
@burnpanck

MWE:

import os

os.environ["KERAS_BACKEND"] = "jax"

import keras

class Test(keras.Model):
    def call(self, x):
        raise RuntimeError("Random misspelling deeply nested in the model")

t = Test()

inp = keras.Input(shape=(32, 3))

t(inp)

Running the above example raises the following exception:

TypeError: Exception encountered when calling Test.call().

Shapes must be 1D sequences of concrete values of integer type, got (None, 32, 3).

Arguments received by Test.call():
  • args=('<KerasTensor shape=(None, 32, 3), dtype=float32, sparse=None, name=keras_tensor_1>',)
  • kwargs=<class 'inspect._empty'>

On the TensorFlow backend, the exception instead reads:

RuntimeError: Exception encountered when calling Test.call().

Could not automatically infer the output shape / dtype of 'test_1' (of type Test). Either the `Test.call()` method is incorrect, or you need to implement the `Test.compute_output_spec() / compute_output_shape()` method. Error encountered:

Random misspelling deeply nested in the model

Arguments received by Test.call():
  • args=('<KerasTensor shape=(None, 32, 3), dtype=float32, sparse=None, name=keras_tensor_1>',)
  • kwargs=<class 'inspect._empty'>

Note that with TensorFlow, the error message contains the original exception string, whereas under JAX the message strongly and misleadingly suggests a problem with a shape. Furthermore, the internal frames of the stack trace are erased (not shown above, to keep the MWE minimal). If this happens deep inside a model, an unsuspecting user may be sent on an hours-long hunt for mismatched shapes that turns up nothing useful.
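To make the difference concrete, here is a minimal pure-Python sketch of the two behaviours. It is not Keras' actual implementation: `ShapeInferenceError`, `symbolic_call_masking` and `symbolic_call_preserving` are hypothetical names invented for illustration, and the sketch only assumes that some wrapper around `call()` catches the user's exception and raises its own.

```python
class ShapeInferenceError(TypeError):
    """Hypothetical stand-in for the shape-related TypeError seen on JAX."""


def failing_call(x):
    # Plays the role of Test.call() from the MWE above.
    raise RuntimeError("Random misspelling deeply nested in the model")


def symbolic_call_masking(fn, x):
    # Masking pattern: every failure is reported as a shape problem and
    # `from None` drops the original exception from the chain.
    try:
        return fn(x)
    except Exception:
        raise ShapeInferenceError(
            "Shapes must be 1D sequences of concrete values of integer type"
        ) from None


def symbolic_call_preserving(fn, x):
    # Preserving pattern: the new message embeds the original text and
    # `from e` keeps the original exception (and its traceback) as the cause.
    try:
        return fn(x)
    except Exception as e:
        raise RuntimeError(
            "Could not automatically infer the output shape. "
            f"Error encountered:\n\n{e}"
        ) from e
```

With the masking variant, neither the message nor `__cause__` points at the real bug; with the preserving variant, both do, which matches what the TensorFlow output above shows.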

@burnpanck
Author

This was with Keras 3.1.1, JAX 0.4.26, and Python 3.12.
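For anyone trying to reproduce this, a small sketch like the following can collect the same version information. It uses only the standard library; the helper name `report_versions` and the default package list are assumptions made for this example.

```python
import sys
from importlib import metadata


def report_versions(packages=("keras", "jax", "tensorflow")):
    """Return a mapping of package name to installed version string."""
    # The interpreter version is taken from the first token of sys.version.
    versions = {"python": sys.version.split()[0]}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Report missing packages instead of failing.
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name}: {version}")
```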

@fchollet
Member

fchollet commented May 6, 2024

Good catch. I fixed it at HEAD.

@fchollet fchollet closed this as completed May 6, 2024