
improve error message and docs for custom_jvp / custom_vjp nondiff_argnums escaped tracer error #20889

Open
mattjj opened this issue Apr 23, 2024 · 0 comments
Labels: better_errors (Improve the error reporting), documentation

mattjj (Member) commented Apr 23, 2024

When using custom_jvp or custom_vjp, don't use nondiff_argnums for array-valued arguments. It'll often lead to "encountered an unexpected tracer" errors.

But we should raise a better error, and make the docs more discoverable (and clearer).
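For contrast, `nondiff_argnums` is still a reasonable fit for arguments that are not arrays, such as a Python callable, because a plain function is not itself a traced value. The sketch below follows the pattern used in JAX's custom-derivatives documentation; the name `apply_fn` and the deliberately artificial derivative rule (`2.0 * x_dot`, which ignores `f`) are illustrative only.

```python
from functools import partial

import jax

# `f` is a Python callable, not an array, so passing it through
# nondiff_argnums does not hand a traced array value to the JVP rule.
@partial(jax.custom_jvp, nondiff_argnums=(0,))
def apply_fn(f, x):
    return f(x)

# With nondiff_argnums, the non-differentiable arguments come first in the rule.
@apply_fn.defjvp
def apply_fn_jvp(f, primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    # Artificial rule: always report a derivative of 2, regardless of f,
    # just to make it visible that the custom rule is being used.
    return f(x), 2.0 * x_dot

primal_out = apply_fn(lambda x: x ** 3, 3.0)             # uses f: 27.0
grad_x = jax.grad(apply_fn, argnums=1)(lambda x: x ** 3, 3.0)  # uses the rule: 2.0
```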

Here's a repro from a user:

```python
import jax
from jax import numpy as jnp

def func_fwd(arr, mask):
  return arr * mask

# With nondiff_argnums=(1,), `mask` is passed first to the JVP rule.
def func_jvp(mask, primals, tangents):
  def f(arr):
    return arr * mask
  return jax.jvp(f, primals, tangents)

func = jax.custom_jvp(func_fwd, nondiff_argnums=(1,))
func.defjvp(func_jvp)

def step(carry, _):
  return (func(*carry), carry[1]), None

def loss(x, mask):
  carry, _ = jax.lax.scan(step, (x, mask), [None] * 2, length=2)
  return carry[0].sum()

x = jnp.ones(10)
mask = jnp.zeros(10, dtype=bool)

# Per the report, this fails with an "encountered an unexpected tracer" error:
# inside scan, `mask` is a tracer, and nondiff_argnums hands it to the JVP rule.
gradients = jax.grad(loss)(x, mask)
```
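A minimal sketch of the workaround suggested above, assuming the goal is only to differentiate with respect to `arr`: pass `mask` as an ordinary argument instead of through `nondiff_argnums`, and ignore its tangent in the JVP rule. The names `masked_mul` and `masked_mul_jvp` are hypothetical, and this is not presented as the fix adopted for this issue.

```python
import jax
import jax.numpy as jnp

# `mask` is a regular argument here, so JAX traces it like any other input
# rather than closing over it in the JVP rule.
@jax.custom_jvp
def masked_mul(arr, mask):
    return arr * mask

@masked_mul.defjvp
def masked_mul_jvp(primals, tangents):
    arr, mask = primals
    arr_dot, _ = tangents  # mask is boolean, so its tangent is ignored
    return arr * mask, arr_dot * mask

def step(carry, _):
    arr, mask = carry
    return (masked_mul(arr, mask), mask), None

def loss(x, mask):
    (out, _), _ = jax.lax.scan(step, (x, mask), None, length=2)
    return out.sum()

x = jnp.ones(10)
mask = jnp.zeros(10, dtype=bool)
grads = jax.grad(loss)(x, mask)  # gradient w.r.t. x only
```

Because `jax.grad` differentiates only the first argument by default, the boolean `mask` never needs a meaningful tangent; the rule simply drops whatever zero tangent it receives for it.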
mattjj added the documentation and better_errors labels Apr 23, 2024
mattjj self-assigned this Apr 23, 2024