If you have, say, a solve method

    def f(t,x):
        ....

    def solve(y0):
        t_eval = torch.linspace(0.0, 1.0, 2)
        traj = ode_int(f, y0, t_eval)
        return traj[-1]

    def other_func(X):
        mapped_traj = vmap(solve)(X)

where X is, for example, of shape (n,m,d), this will result in a

RuntimeError: functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. Please rewrite your function to not use autograd.Function while we work on fixing this

I believe the solution is now provided in the nightly builds, as described here:
https://pytorch.org/docs/master/notes/extending.func.html
Adopting it requires a small alteration to the core methods.
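For illustration, here is a minimal sketch of the kind of alteration the linked note describes, assuming PyTorch 2.0 or later. `Square` is a hypothetical stand-in for the solver's custom autograd.Function, not the library's actual code. The two changes are that `forward` no longer takes `ctx` (saving tensors moves into `setup_context`), and that `generate_vmap_rule = True` asks PyTorch to derive the batching rule automatically.

```python
import torch
from torch.func import vmap


class Square(torch.autograd.Function):
    """Hypothetical stand-in for a solver's custom autograd.Function."""

    # Let PyTorch derive the vmap rule from forward/backward automatically.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        # New-style forward: no ctx argument.
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Saving for backward happens here instead of in forward.
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


def solve(y0):
    # Placeholder for the real solve; it only calls the custom Function.
    return Square.apply(y0)


# X has shape (n, m, d), as in the report above.
X = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
mapped = vmap(solve)(X)
```

With the old-style `forward(ctx, x)` signature, the same `vmap` call is what triggers the error quoted above. A real solver would probably need more than this, since `generate_vmap_rule = True` only works when `forward`, `setup_context`, and `backward` are built from operations that themselves support `vmap`.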