How to pass extra parameters of func to odeint? #246

Open
shifttttttt opened this issue Jan 23, 2024 · 2 comments

Comments

@shifttttttt

I looked at the definition of odeint in torchdiffeq, but did not find a parameter for passing extra parameters, like the args parameter in scipy.integrate.odeint. Is there any other way to pass parameters to odeint besides defining a global variable?

@rtqichen
Owner

rtqichen commented Mar 12, 2024

Yeah, just define them anywhere. To use odeint_adjoint, it's good practice to define them as part of the module.


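import torch.nn as nn

global_params = ...  # anything defined outside the module

class ODEfunc(nn.Module):

  def __init__(self):
    super().__init__()  # required before assigning nn.Parameter attributes
    # Named `p`, not `parameters`, to avoid shadowing nn.Module.parameters().
    self.p = nn.Parameter(some_tensor_we_want_to_optimize_ie_compute_gradients_for)

  def forward(self, t, x):
    p = self.p
    external_p = global_params
    # some ops regarding t, x, p, external_p
    return ...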

If you use odeint, gradients will also be computed w.r.t. external_p, but odeint_adjoint will only compute them for p.
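
As a concrete version of the pattern above, here is a minimal sketch that passes the extra arguments through the module's constructor instead of a global. The names `DecayFunc`, `rate`, `offset`, and `solve` are made up for illustration and are not part of torchdiffeq; the ODE itself (dx/dt = -rate * (x - offset)) is just a toy. It assumes `rate` should be trainable and `offset` is a fixed extra argument.

```python
import torch
import torch.nn as nn


class DecayFunc(nn.Module):
    """Toy right-hand side dx/dt = -rate * (x - offset) (hypothetical example)."""

    def __init__(self, rate, offset):
        super().__init__()
        # Trainable: registered as a module parameter.
        self.rate = nn.Parameter(torch.tensor(rate))
        # Fixed extra argument: stored as a buffer, so it moves with the
        # module (.to(device)) but is not returned by .parameters().
        self.register_buffer("offset", torch.tensor(offset))

    def forward(self, t, x):
        # Same (t, x) signature that odeint expects; the extra
        # arguments are read from the module itself.
        return -self.rate * (x - self.offset)


def solve(func, x0, t, use_adjoint=False):
    # Imported lazily so the module above can be used without torchdiffeq.
    from torchdiffeq import odeint, odeint_adjoint
    solver = odeint_adjoint if use_adjoint else odeint
    return solver(func, x0, t)


func = DecayFunc(rate=0.5, offset=1.0)
x0 = torch.tensor([3.0])
dx = func(torch.tensor(0.0), x0)  # evaluate the right-hand side once at t = 0
```

Calling `solve(func, x0, torch.linspace(0., 1., 5))` would then integrate from `x0` over those time points. With this layout, `rate` is what `func.parameters()` returns, so it is the quantity covered by the remark above about odeint_adjoint, while `offset` plays the role of a plain extra argument.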

@shifttttttt
Author

Thanks for your answer!
