JAX array conversion failure in Keras model prediction #19674

Closed
Qazalbash opened this issue May 6, 2024 · 3 comments

I have trained a simple deep MLP model and saved it in the .keras format. I call it for predictions from inside JAX-jitted functions, passing the two inputs stacked with jax.numpy.column_stack. The issue persists despite trying alternatives such as numpy.column_stack and setting JAX_TRACEBACK_FILTERING=off. Note that my Keras backend is configured with KERAS_BACKEND=jax.

File "/media/project/inference/lippl.py", line 114, in exp_rate_integral
    jnp.exp(mass_model.log_prob(m1q) + self.logVT.predict(m1m2).flatten()),
                                       ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gradf/miniforge3/envs/gwkenv/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/gradf/miniforge3/envs/gwkenv/lib/python3.11/site-packages/optree/ops.py", line 594, in tree_map
    return treespec.unflatten(map(func, *flat_args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gradf/miniforge3/envs/gwkenv/lib/python3.11/site-packages/jax/_src/core.py", line 684, in __array__
    raise TracerArrayConversionError(self)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[10000,2].
The error occurred while tracing the function likelihood at /media/project/inference/lippl.py:119 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /media/project/inference/lippl.py:107:29 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)

  operation a:f32[10000] = pjit[
  name=_uniform
  jaxpr={ lambda ; b:key<fry>[] c:i32[] d:i32[]. let
      e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
      f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
      g:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] e
      h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
      i:u32[10000] = random_bits[bit_width=32 shape=(10000,)] b
      j:u32[10000] = shift_right_logical i 9
      k:u32[10000] = or j 1065353216
      l:f32[10000] = bitcast_convert_type[new_dtype=float32] k
      m:f32[10000] = sub l 1.0
      n:f32[1] = sub h g
      o:f32[10000] = mul m n
      p:f32[10000] = add o g
      q:f32[10000] = max g p
    in (q,) }
] r s t
    from line /media/project/inference/lippl.py:107:13 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /media/project/inference/lippl.py:108:29 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)

  operation a:f32[10000] = pjit[
  name=_uniform
  jaxpr={ lambda ; b:key<fry>[] c:i32[] d:i32[]. let
      e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
      f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
      g:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] e
      h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
      i:u32[10000] = random_bits[bit_width=32 shape=(10000,)] b
      j:u32[10000] = shift_right_logical i 9
      k:u32[10000] = or j 1065353216
      l:f32[10000] = bitcast_convert_type[new_dtype=float32] k
      m:f32[10000] = sub l 1.0
      n:f32[1] = sub h g
      o:f32[10000] = mul m n
      p:f32[10000] = add o g
      q:f32[10000] = max g p
    in (q,) }
] r s t
    from line /media/project/inference/lippl.py:108:13 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
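
For reference, a minimal sketch of the calling pattern described above that triggers this error (this is not the original lippl.py; the model path and names are illustrative):

```python
import jax
import jax.numpy as jnp
import keras

# Hypothetical saved model standing in for self.logVT.
logVT = keras.models.load_model("logVT.keras")

@jax.jit
def exp_rate_integral(m1, m2):
    m1m2 = jnp.column_stack((m1, m2))
    # Inside jax.jit, m1m2 is a tracer. predict() converts its inputs to
    # NumPy arrays for its eager prediction loop, which calls __array__()
    # on the tracer and raises jax.errors.TracerArrayConversionError.
    return jnp.exp(logVT.predict(m1m2).flatten())
```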
@fchollet (Member) commented May 6, 2024

jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[10000,2].

This indicates that your model contains an operation that tries to retrieve the eager value of a tensor.

Earlier, I see:

jnp.exp(mass_model.log_prob(m1q) + self.logVT.predict(m1m2).flatten()),

So it sounds like you are calling predict() inside a tracing scope. This is impossible. Perhaps you meant to call self.logVT(m1m2) instead?

See also: https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call
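
A minimal sketch of the suggested change, assuming logVT is a plain Keras model that is stateless at inference time (the model path and names are illustrative):

```python
import jax
import jax.numpy as jnp
import keras

logVT = keras.models.load_model("logVT.keras")  # illustrative path

@jax.jit
def exp_rate_integral(m1, m2):
    m1m2 = jnp.column_stack((m1, m2))
    # __call__ runs the forward pass with backend ops, so with
    # KERAS_BACKEND=jax it operates directly on traced arrays and
    # composes with jax.jit.
    return jnp.exp(logVT(m1m2).flatten())
```

The difference, per the FAQ linked above, is that predict() is an eager, NumPy-based batch-prediction loop, while __call__ is the traceable forward pass.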

@fchollet (Member) commented May 6, 2024

Oh, and if your call method is stateful in any way, you'll need to use stateless_call() instead and manage the state updates manually.
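
A rough sketch of that stateless pattern, assuming the Keras 3 stateless_call() API (model path and helper names are illustrative); the variable values are extracted once outside the traced function and the non-trainable state is threaded through each call:

```python
import jax
import keras

model = keras.models.load_model("logVT.keras")  # illustrative path

# Extract plain array values once, outside the traced function.
trainable_vars = [v.value for v in model.trainable_variables]
non_trainable_vars = [v.value for v in model.non_trainable_variables]

@jax.jit
def predict_stateless(trainable_vars, non_trainable_vars, x):
    # stateless_call() returns the outputs together with the (possibly
    # updated) non-trainable variables; no model state is mutated, so the
    # caller must carry the returned state forward to the next call.
    outputs, non_trainable_vars = model.stateless_call(
        trainable_vars, non_trainable_vars, x, training=False
    )
    return outputs, non_trainable_vars
```

Usage would then look like: outputs, non_trainable_vars = predict_stateless(trainable_vars, non_trainable_vars, m1m2), reusing the returned non_trainable_vars on subsequent calls.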

@Qazalbash (Author)

So it sounds like you are calling predict() inside a tracing scope. This is impossible. Perhaps you meant to call self.logVT(m1m2) instead?

@fchollet This solved the problem. Thanks!
