MWE:
import os
os.environ["KERAS_BACKEND"] = "jax"

import keras


class Test(keras.Model):
    def call(self, x):
        raise RuntimeError("Random misspelling deeply nested in the model")


t = Test()
inp = keras.Input(shape=(32, 3))
t(inp)
Running the above example raises the following exception:
TypeError: Exception encountered when calling Test.call().
Shapes must be 1D sequences of concrete values of integer type, got (None, 32, 3).
Arguments received by Test.call():
• args=('<KerasTensor shape=(None, 32, 3), dtype=float32, sparse=None, name=keras_tensor_1>',)
• kwargs=<class 'inspect._empty'>
With the TensorFlow backend, the exception instead reads:
RuntimeError: Exception encountered when calling Test.call().
Could not automatically infer the output shape / dtype of 'test_1' (of type Test). Either the `Test.call()` method is incorrect, or you need to implement the `Test.compute_output_spec() / compute_output_shape()` method. Error encountered:
Random misspelling deeply nested in the model
Arguments received by Test.call():
• args=('<KerasTensor shape=(None, 32, 3), dtype=float32, sparse=None, name=keras_tensor_1>',)
• kwargs=<class 'inspect._empty'>
Note that with TensorFlow, the error message contains the original exception string, whereas under JAX the message misleadingly suggests that there is a problem with a shape. Furthermore, the internal frames of the stack trace are erased (not shown in the example above, to keep the MWE minimal). If this happens deep inside a model, an unsuspecting user may be sent on a many-hours-long hunt for mismatched shapes that turns up nothing useful.
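For illustration, here is a minimal plain-Python sketch of the behaviour the TensorFlow message exhibits: a wrapper around symbolic output inference that keeps the user's original error. It is not Keras code. `infer_output_spec`, `ShapeInferenceError`, and the `fallback_fn` hook are hypothetical names, and the idea that a retry/fallback step is what replaces the `RuntimeError` with a shape error under JAX is an assumption, not something confirmed here.

```python
class ShapeInferenceError(RuntimeError):
    """Hypothetical error raised when symbolic output inference fails."""


def infer_output_spec(call_fn, symbolic_input, fallback_fn=None):
    """Run call_fn symbolically; on failure, surface the ORIGINAL error.

    call_fn and fallback_fn are hypothetical stand-ins for a layer's call()
    and for an alternative tracing strategy; neither is Keras's real API.
    """
    try:
        return call_fn(symbolic_input)
    except Exception as original:
        if fallback_fn is not None:
            try:
                return fallback_fn(symbolic_input)
            except Exception:
                # The fallback's error describes the retry strategy (e.g. a
                # shape complaint), not the bug in the user's code, so it
                # must not replace the original error.
                pass
        # Quote the user's exception text in the message and chain it with
        # `from`, so the inner traceback frames stay visible.
        name = getattr(call_fn, "__name__", call_fn)
        raise ShapeInferenceError(
            f"Could not infer the output spec of {name!r}. "
            f"Error encountered: {type(original).__name__}: {original}"
        ) from original


# Usage: mimic the MWE, where call() fails and a (hypothetical) fallback
# then fails with an unrelated shape error.
def buggy_call(x):
    raise RuntimeError("Random misspelling deeply nested in the model")


def shape_fallback(x):
    raise TypeError("Shapes must be 1D sequences of concrete values of "
                    "integer type, got (None, 32, 3).")


try:
    infer_output_spec(buggy_call, (None, 32, 3), fallback_fn=shape_fallback)
except ShapeInferenceError as err:
    print(err)                          # mentions the RuntimeError text
    print(type(err.__cause__).__name__)  # RuntimeError
```

The design choice is to report the first failure and chain it as `__cause__`, so the printed traceback shows the frames inside `call()` rather than only the wrapper's.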