Batch training of NODE with varying external input (forcing) per batch element #365

Open
FabiJa opened this issue Jan 31, 2024 · 4 comments
Labels: question (User queries)

FabiJa commented Jan 31, 2024

Hi,

Sorry for the slightly uninformed question, but I am new to the JAX ecosystem.

I have several datasets of measurements with different excitations u(t) for one dynamical system, whose dynamics I want to learn. So the excitation changes, but the system (ODE -> NODE) stays the same.

I want to use equinox + diffrax to train a neural ODE via batching, where the ODE has an external input u, i.e. it is described by xdot = f(x, u(t)). The dependence of u on time is not known explicitly (it has to be interpolated from data) and it varies per batch element.
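
Roughly, the vector field I have in mind looks like this (just a sketch; u_interp stands for some interpolation of the measured input, e.g. a diffrax.CubicInterpolation):

import jax.numpy as jnp

def vector_field(t, x, args):
    mlp, u_interp = args
    # Evaluate the data-interpolated external input at the current time...
    u = u_interp.evaluate(t)
    # ...and feed state and input jointly to the network: xdot = f(x, u(t)).
    return mlp(jnp.concatenate([x, u]))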

Looking at the docs I found the forcing-term example and the batched neural ODE training example.
My problem is how to combine the two. My first hack was to map the u(t) of every batch element onto non-overlapping time intervals, to get a unique mapping from time to the correct input time series. Then I can use vmap directly via

@eqx.filter_value_and_grad
def grad_loss(model, ti, yi):
    # All forcing signals are concatenated onto disjoint time intervals, so the
    # same input_concatenated is shared across the whole batch (in_axes=None).
    y_pred = jax.vmap(model, in_axes=(0, 0, None))(ti, yi[:, 0], input_concatenated)
    return jnp.mean((yi - y_pred) ** 2)

Are there better options to handle this? Note that the gradient should not be calculated with respect to the parameters of the interpolation object representing u(t).

Thanks. If there are any questions, let me know.

patrick-kidger (Owner) commented Jan 31, 2024

Considering the final "forcing term" example: try replacing the jax.grad there with a jax.vmap.

Then you should be able to pass in a batch-of-points, so that each batch element gets a different forcing term.
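
In sketch form (this is not the exact code from the docs; the toy dynamics and names here are made up):

import diffrax
import jax
import jax.numpy as jnp

def solve(ts, y0, us):
    # Build this batch element's forcing term by interpolating its own samples.
    coeffs = diffrax.backward_hermite_coefficients(ts, us)
    u = diffrax.CubicInterpolation(ts, coeffs)
    term = diffrax.ODETerm(lambda t, y, args: -y + u.evaluate(t))  # toy dynamics
    sol = diffrax.diffeqsolve(term, diffrax.Tsit5(), ts[0], ts[-1], dt0=0.1,
                              y0=y0, saveat=diffrax.SaveAt(ts=ts))
    return sol.ys

ts = jnp.linspace(0.0, 5.0, 50)
y0s = jnp.zeros((16,))   # batch of initial conditions
us = jnp.ones((16, 50))  # batch of forcing samples, one time series per element
ys = jax.vmap(solve, in_axes=(None, 0, 0))(ts, y0s, us)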

Does that help?

patrick-kidger added the question label Jan 31, 2024

FabiJa commented Jan 31, 2024

Thank you. I will test it and give feedback here.

If I understand it correctly, the interpolation coefficients are computed "on the fly", i.e. during training. If so, it would be nice to separate this from the training process, also for modularity, in case one wants to change the interpolation scheme.
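
E.g. something like this, done once before the training loop (a sketch; us_array is a hypothetical array holding all measured input series):

import diffrax
import jax

# One set of interpolation coefficients per input time series, computed once
# up front. Changing the interpolation scheme then only touches this line.
coeff_array = jax.vmap(diffrax.backward_hermite_coefficients)(ts_inputs_array, us_array)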

By the way, great work! An astonishing pace of new JAX packages from you :-O


FabiJa commented Feb 1, 2024

OK, a short update:
I opted for an alternative approach, based on the neural ODE example:

@eqx.filter_value_and_grad
def grad_loss(model, ti, yi, ts_inputs_i, coeff_array_i):
    # Every batch element carries its own input time points and interpolation
    # coefficients, so everything is vmapped over axis 0.
    y_pred = jax.vmap(model, in_axes=(0, 0, 0, 0))(ti, yi[:, 0], ts_inputs_i, coeff_array_i)
    return jnp.mean((yi - y_pred) ** 2)

@eqx.filter_jit
def make_step(ti, yi, ts_inputs_i, coeff_array_i, model, opt_state):
    # The batched forcing data is passed in explicitly, rather than captured
    # from the enclosing scope, so the jitted function sees the current batch.
    loss, grads = grad_loss(model, ti, yi, ts_inputs_i, coeff_array_i)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state


...
data_for_batching = (_ys, _ts, ts_inputs_array, coeff_array)
for step, batch_data in zip(
    range(steps), dataloader(data_for_batching, batch_size, key=loader_key)
):
    yi, ti, ts_inputs_i, coeff_array_i = batch_data
    start = time.time()
    loss, model_train, opt_state = make_step(
        ti, yi, ts_inputs_i, coeff_array_i, model_train, opt_state
    )
    end = time.time()
    if (step % print_every) == 0 or step == steps - 1:
        print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

I do not know if this makes sense, but at least it seems to work (so far). coeff_array_i holds the coefficients from e.g. diffrax.backward_hermite_coefficients.
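
For context, inside the model I rebuild the (fixed) interpolation from those precomputed coefficients, roughly like this (a sketch, not my exact model):

import diffrax
import equinox as eqx
import jax.numpy as jnp

class NeuralODE(eqx.Module):
    mlp: eqx.nn.MLP

    def __call__(self, ts, y0, ts_inputs, coeffs):
        # The forcing is rebuilt from precomputed coefficients; as the gradient
        # is only taken w.r.t. the model, none flows into the interpolation.
        u_interp = diffrax.CubicInterpolation(ts_inputs, coeffs)

        def vector_field(t, y, args):
            return self.mlp(jnp.concatenate([y, u_interp.evaluate(t)]))

        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(vector_field), diffrax.Tsit5(),
            ts[0], ts[-1], dt0=ts[1] - ts[0], y0=y0,
            saveat=diffrax.SaveAt(ts=ts),
        )
        return sol.ys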

patrick-kidger (Owner) commented

This looks reasonable to me!
