Batch training of NODE with varying external input (forcing) per batch element #365
Comments
Considering the final "forcing term" example: try replacing the Then you should be able to pass in a batch-of-points, so that each batch element gets a different forcing term. Does that help?
Thank you. I will test it and give feedback here. If I understand it correctly, the interpolation coefficients are computed "on the fly", i.e. during training. If so, it would be nice to have this separated from the training process, also for the sake of modularity in case one wants to change the interpolation scheme. By the way, great work! Astonishing pace of new JAX packages from you :-O
Ok, short update:
I do not know if this makes sense. At least it seems to work (so far).
This looks reasonable to me!
Hi,
sorry for my slightly uninformed question, but I am new to the JAX ecosystem.
I have several data sets of measurements with different excitations `u(t)` for one dynamic system whose dynamics I want to learn. So the excitation changes, but the system (ODE -> NODE) is the same.
I want to use equinox + diffrax to train a neural ODE via batching. The ODE has an external input `u`, meaning it is described by `xdot = f(x, u(t))`. The dependency of `u` on time is not known explicitly (interpolation from data has to be used) and varies per batch element.
Looking in the documentation I found the forcing term example and the batch training of NODEs. My problem is how to combine both. My first hack was to map each `u(t)` of every batch element to non-overlapping time periods to get a unique mapping from time to the correct input time series. Then I am able to use `vmap` directly via
Are there any better options to handle this? Note that the gradient should not be calculated with respect to the parameters of the interpolation object representing `u(t)`.
Thanks. If there are any questions, let me know.