I have trained a simple deep MLP model and saved it in .keras format. I call the model's predict() method from inside a JAX-jitted function, passing the two inputs combined with jax.numpy.column_stack, and the call fails with the TracerArrayConversionError shown below. I have tried alternatives, including using numpy.column_stack instead and setting JAX_TRACEBACK_FILTERING=off, but the issue persists. My Keras backend is configured with KERAS_BACKEND=jax.
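The setup described above can be reduced to a minimal sketch that should reproduce the same failure without Keras. It assumes, as the traceback below suggests, that predict() converts its input to a concrete NumPy array; predict_like and the log_vt_* functions are hypothetical stand-ins, and the random masses only mimic the uniform draws in exp_rate_integral. Inside jax.jit every argument is a tracer with no concrete value, so anything that needs a real NumPy array (numpy.column_stack, or the conversion inside predict()) raises TracerArrayConversionError, while jax.numpy.column_stack itself traces fine.

```python
import numpy as np
import jax
import jax.numpy as jnp

def predict_like(x):
    # Hypothetical stand-in for model.predict(): it turns its input into a
    # concrete NumPy array before doing any work.
    x = np.asarray(x)
    return x.sum(axis=1)

def log_vt_numpy_stack(m1, m2):
    # numpy.column_stack also needs concrete values, so it fails under jit too
    return np.column_stack([m1, m2])

def log_vt_predict(m1, m2):
    m1m2 = jnp.column_stack([m1, m2])  # fine: stays a traced float32[N, 2]
    return predict_like(m1m2)          # fails under jit: __array__() on a tracer

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
m1 = jax.random.uniform(key1, (10000,), minval=5.0, maxval=50.0)
m2 = jax.random.uniform(key2, (10000,), minval=5.0, maxval=50.0)

# Outside jit the inputs are concrete arrays, so the NumPy conversion works.
eager = log_vt_predict(m1, m2)

# Under jit both variants hit the same error as in the traceback.
errors = {}
for name, fn in [("numpy.column_stack", log_vt_numpy_stack), ("predict", log_vt_predict)]:
    try:
        jax.jit(fn)(m1, m2)
        errors[name] = None
    except jax.errors.TracerArrayConversionError as err:
        errors[name] = type(err).__name__
print(errors)
```

This is why swapping jax.numpy.column_stack for numpy.column_stack cannot help: the problem is the NumPy conversion of a traced value, not the stacking function.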
File "/media/project/inference/lippl.py", line 114, in exp_rate_integral
jnp.exp(mass_model.log_prob(m1q) + self.logVT.predict(m1m2).flatten()),
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gradf/miniforge3/envs/gwkenv/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/gradf/miniforge3/envs/gwkenv/lib/python3.11/site-packages/optree/ops.py", line 594, in tree_map
return treespec.unflatten(map(func, *flat_args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gradf/miniforge3/envs/gwkenv/lib/python3.11/site-packages/jax/_src/core.py", line 684, in __array__
raise TracerArrayConversionError(self)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[10000,2].
The error occurred while tracing the function likelihood at /media/project/inference/lippl.py:119 for jit. This value became a tracer due to JAX operations on these lines:
operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /media/project/inference/lippl.py:107:29 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)
operation a:f32[10000] = pjit[
name=_uniform
jaxpr={ lambda ; b:key<fry>[] c:i32[] d:i32[]. let
e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
g:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] e
h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
i:u32[10000] = random_bits[bit_width=32 shape=(10000,)] b
j:u32[10000] = shift_right_logical i 9
k:u32[10000] = or j 1065353216
l:f32[10000] = bitcast_convert_type[new_dtype=float32] k
m:f32[10000] = sub l 1.0
n:f32[1] = sub h g
o:f32[10000] = mul m n
p:f32[10000] = add o g
q:f32[10000] = max g p
in (q,) }
] r s t
from line /media/project/inference/lippl.py:107:13 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)
operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /media/project/inference/lippl.py:108:29 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)
operation a:f32[10000] = pjit[
name=_uniform
jaxpr={ lambda ; b:key<fry>[] c:i32[] d:i32[]. let
e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
g:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] e
h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
i:u32[10000] = random_bits[bit_width=32 shape=(10000,)] b
j:u32[10000] = shift_right_logical i 9
k:u32[10000] = or j 1065353216
l:f32[10000] = bitcast_convert_type[new_dtype=float32] k
m:f32[10000] = sub l 1.0
n:f32[1] = sub h g
o:f32[10000] = mul m n
p:f32[10000] = add o g
q:f32[10000] = max g p
in (q,) }
] r s t
from line /media/project/inference/lippl.py:108:13 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
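Since the jitted likelihood needs every step to stay traceable, one way around the error is to keep the network's forward pass in jax.numpy and pass the weights in explicitly. The sketch below assumes a plain dense ReLU network; init_mlp, mlp_forward and the random weights are hypothetical stand-ins, not the architecture of the trained model.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    # Hypothetical weights standing in for the trained model: a list of (W, b) pairs
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]

def mlp_forward(params, x):
    # Pure jax.numpy forward pass: every op accepts tracers, so it is jit-safe
    for W, b in params[:-1]:
        x = jax.nn.relu(x @ W + b)
    W, b = params[-1]
    return x @ W + b

@jax.jit
def log_vt(params, m1, m2):
    # Weights are explicit arguments, so nothing is converted to NumPy
    m1m2 = jnp.column_stack([m1, m2])           # shape (N, 2)
    return mlp_forward(params, m1m2).flatten()  # shape (N,)

params = init_mlp(jax.random.PRNGKey(0), [2, 64, 64, 1])
key1, key2 = jax.random.split(jax.random.PRNGKey(1))
m1 = jax.random.uniform(key1, (10000,), minval=5.0, maxval=50.0)
m2 = jax.random.uniform(key2, (10000,), minval=5.0, maxval=50.0)
log_vt_values = log_vt(params, m1, m2)
```

With the actual Keras 3 model on the JAX backend, the corresponding change (untested here) would be to replace self.logVT.predict(m1m2) with a direct call such as self.logVT(m1m2, training=False), or with self.logVT.stateless_call(trainable_values, non_trainable_values, m1m2), which returns (outputs, non_trainable_values). predict() is intended for eager, batched inference and, as the traceback shows, converts its input to NumPy along the way.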