Are there jit compile-time guards?
#20737
Are there compile-time guards that would allow something like this?

```python
>>> def custom_log(x):
...     if not jax.is_jitting:  # Or a better name ...
...         np.testing.assert_array_less(0, x)
...     return jnp.log(x)
>>> jax.jit(custom_log)(jnp.zeros(1))
Array([-inf], dtype=float32)
>>> custom_log(jnp.zeros(1))
AssertionError:
Arrays are not less-ordered

Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.
Max relative difference: inf
 x: array(0)
 y: array([0.], dtype=float32)
```

Currently, I just comment out all my assertions before jitting and then uncomment them for more debugging. Is there a better approach?
No, we deliberately avoid providing any API along those lines: as soon as you write code that operates differently at trace time and at run time, you break JAX's tracing model and open the door to difficult-to-track bugs (e.g. …).

Even if it were possible, the "compile-time only" assertion approach wouldn't work correctly, because we don't recompile on subsequent function calls. For example:

```python
jax.jit(custom_log)(jnp.array([1.0]))  # passes the assertion
jax.jit(custom_log)(jnp.array([0.0]))  # hits JIT cache, so compile-time assertion doesn't run
```

It sounds like what you're after is not compile-time assertions, but rather runtime value assertions. There are essentially two supported ways to do that right now: …

For what it's worth, JAX itself avoids this kind of runtime assertion entirely, instead returning invalid values like …
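One runtime-assertion mechanism that does exist in JAX is `jax.experimental.checkify`; whether it is one of the two approaches the answer had in mind is not visible in this copy of the thread. The sketch below reuses the question's `custom_log`, with a made-up check message. Because the predicate is evaluated on actual values at every call, it still fires on the second call even though that call reuses the cached compilation.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def custom_log(x):
    # Runtime check: evaluated on the actual input values on every call,
    # including calls that reuse a cached compilation.
    checkify.check(jnp.all(x > 0), "custom_log expects strictly positive input")
    return jnp.log(x)


# checkify.checkify turns the check into an explicit error output,
# which lets the checked function be wrapped in jax.jit.
checked_log = jax.jit(checkify.checkify(custom_log))

err, out = checked_log(jnp.array([1.0]))
err.throw()  # no-op: the check passed

err, out = checked_log(jnp.array([0.0]))  # same shape/dtype, so no recompilation
# err.throw() would now raise checkify.JaxRuntimeError with the message above
```

The design trade-off is that the caller receives an error value alongside the result and decides when to raise it, instead of the check being silently skipped under `jit`.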
It's certainly possible to write code that does something like what you have in mind using a number of mechanisms, but we do not recommend it.