
Unable to use vmap on a function containing the ode solver #217

Open
adam-hartshorne opened this issue Jan 16, 2023 · 1 comment

Comments

@adam-hartshorne

Suppose you have a solve method such as:

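import torch
from functorch import vmap  # torch.func.vmap in PyTorch >= 2.0
from torchdiffeq import odeint

def f(t, x):
    ...  # vector field (body elided in the original)

def solve(y0):
    t_eval = torch.linspace(0.0, 1.0, 2)
    traj = odeint(f, y0, t_eval)  # shape (len(t_eval), *y0.shape)
    return traj[-1]  # state at the final time

def other_func(X):
    mapped_traj = vmap(solve)(X)
    return mapped_traj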

where X has, for example, shape (n, m, d), so each call to solve receives one (m, d) slice.
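For comparison, vmap has no trouble with a solver built only from ordinary tensor operations. The sketch below swaps torchdiffeq's `odeint` for a hypothetical fixed-step explicit Euler integrator `euler_int` and uses an assumed vector field `f` (linear decay, dx/dt = -x). It is not torchdiffeq code; it only illustrates that the batching pattern above is fine once no old-style `autograd.Function` is involved. It assumes PyTorch 2.0 or later for `torch.func`.

```python
import torch
from torch.func import vmap

def f(t, x):
    # Hypothetical vector field: linear decay dx/dt = -x
    return -x

def euler_int(func, y0, t):
    # Hypothetical fixed-step explicit Euler solver using only tensor ops
    ys = [y0]
    y = y0
    for i in range(len(t) - 1):
        dt = t[i + 1] - t[i]
        y = y + dt * func(t[i], y)
        ys.append(y)
    return torch.stack(ys)  # shape (len(t), *y0.shape)

def solve(y0):
    t_eval = torch.linspace(0.0, 1.0, 11)
    traj = euler_int(f, y0, t_eval)
    return traj[-1]  # state at t = 1.0

torch.manual_seed(0)
X = torch.randn(4, 3, 2)   # (n, m, d)
final = vmap(solve)(X)     # solve sees one (m, d) slice per call
print(final.shape)
```

With 10 Euler steps of size 0.1, each state is scaled by (1 - 0.1) per step, so `final` should be close to `X * 0.9 ** 10`.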

This results in the following error:
RuntimeError: functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. Please rewrite your function to not use autograd.Function while we work on fixing this
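The error can be reproduced without torchdiffeq: calling an old-style `torch.autograd.Function` (one whose `forward` takes `ctx` and which has no `setup_context`) under vmap raises a `RuntimeError`. A minimal sketch, assuming PyTorch 2.0 or later; `OldStyleSquare` is a hypothetical class, and the exact message text differs between functorch and newer `torch.func` releases.

```python
import torch
from torch.func import vmap

class OldStyleSquare(torch.autograd.Function):
    # Hypothetical old-style Function: forward receives ctx, no setup_context
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

X = torch.randn(4, 3)
print(OldStyleSquare.apply(X).shape)  # a plain, untransformed call works

try:
    vmap(OldStyleSquare.apply)(X)
    raised = False
except RuntimeError as err:
    raised = True
    print(f"RuntimeError: {err}")
```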

I believe the solution is now available in the PyTorch nightly builds, as described here:

https://pytorch.org/docs/master/notes/extending.func.html

Adopting it requires a small alteration to the core methods.
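As a rough sketch of the kind of alteration the linked notes describe: `forward` drops its `ctx` argument, saving tensors moves into a separate `setup_context` staticmethod, and setting `generate_vmap_rule = True` lets `torch.func` derive the batching rule by running `forward`/`backward` under vmap. `Square` is a hypothetical stand-in, not one of torchdiffeq's own Functions, and the sketch assumes PyTorch 2.0 or later.

```python
import torch
from torch.func import grad, vmap

class Square(torch.autograd.Function):
    # Let torch.func generate the vmap rule from forward/backward
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        # New-style forward: no ctx argument, only PyTorch ops
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Saving for backward happens here instead of inside forward
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

X = torch.arange(6.0).reshape(3, 2)
batched = vmap(Square.apply)(X)                                # x ** 2 per row
per_row_grad = vmap(grad(lambda x: Square.apply(x).sum()))(X)  # 2 * x per row
```

The generated rule only works when `forward`, `setup_context`, and `backward` use PyTorch operations; a Function that calls out to NumPy or other non-PyTorch code would instead need a hand-written `vmap` staticmethod.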

@rtqichen
Owner

I likely won't have time to support this, but will gladly take a PR for it!
