
How to use event function when each time is to be pre-processed differently? #224

Open
xlk369293141 opened this issue Mar 5, 2023 · 1 comment


xlk369293141 commented Mar 5, 2023

Thanks for sharing this work!

I use a GNN as the ODE function and initialize it with a different graph at each time point. A simplified version of the code is below.

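import torch
from torchdiffeq import odeint_adjoint

odefunc = GNN()  # user-defined nn.Module with a set_graph() method
times = torch.linspace(0., 1., 10)

z = torch.randn()  # initial state; its shape is omitted in this simplified example
for i in range(len(times) - 1):  # one solve per interval [t_i, t_{i+1}]
    odefunc.set_graph(edge[i])  # edge[i]: the graph used on the i-th interval
    integration_time = times[i:i + 2]  # tensor [t_i, t_{i+1}]
    solution = odeint_adjoint(odefunc, z, integration_time)
    z = solution[-1]  # state at t_{i+1} becomes the next initial state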

Because odefunc has to be updated at each time point, I can only run odeint over one pair of adjacent times at a time. How can I add an event function in this setting? It seems difficult to use odeint_event directly here.
Any help is appreciated.

Owner

rtqichen commented Mar 6, 2023

Not entirely sure I follow the issue.

If the issue is that the ODE needs to be updated once you solve past t_{i+1}, you can also add the end of the time interval as an event (using g(t, x) = t - t_{i+1}). If this event triggers (you can check the event time returned by odeint_event), update the odefunc and continue solving from there. This effectively lets you define a different event function within each time interval.
