
Passing in multiple arguments #137

Open
varunagrawal opened this issue May 5, 2022 · 4 comments
Labels
question Further information is requested

Comments


varunagrawal commented May 5, 2022

Additional Description

I have a network f(x, x_dot, theta) that I wish to train, where x and x_dot are the inputs and theta are the network weights. This is a slightly odd problem: x_dot is a corrupted derivative of x, and I wish to train the network to give me the correct x_dot. To solve the ODE, I need to pass in x at t=0, but the network itself doesn't use x in its forward pass, only x_dot.

How would I pass in multiple arguments like this to a NeuralODE in torchdyn? I am guessing the way to do this is to concatenate the two, i.e. x_x_dot = torch.cat((x, x_dot)), but I am not sure if this is correct.

In torchdiffeq, what I did was call the solver like so:

class Network(nn.Module):
    def forward(self, t, state):
        # torchdiffeq passes the tuple state straight through; the return
        # value must mirror it, one derivative per entry of the tuple
        x, x_dot = state
        return torch.zeros_like(x), self.mlp(x_dot)

y = odeint_adjoint(network, (x_i, x_dot), t_span)

What would be the equivalent approach in torchdyn?

@varunagrawal varunagrawal added the question Further information is requested label May 5, 2022
@joglekara
Contributor

Thanks for the Q!

The concat approach should work just fine, but I agree that it's not necessarily the most transparent.

Using a pytree-based approach is likely the most flexible in the long run, though, so that's also something we're keeping an eye on.
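For the concat approach, here is a minimal sketch in plain PyTorch (not torchdyn's API; the class, sizes, and MLP here are hypothetical) of a vector field that unpacks a concatenated state and feeds only x_dot to the network:

```python
import torch
import torch.nn as nn

class ConcatField(nn.Module):
    """Hypothetical vector field over the stacked state z = cat([x, x_dot])."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(nn.Linear(dim, 64), nn.Tanh(), nn.Linear(64, dim))

    def forward(self, t, z):
        x, x_dot = z.split(self.dim, dim=-1)  # unpack the concatenated state
        dx = self.mlp(x_dot)                  # the network only sees x_dot
        dx_dot = torch.zeros_like(x_dot)      # placeholder dynamics for x_dot
        return torch.cat([dx, dx_dot], dim=-1)

x, x_dot = torch.randn(8, 3), torch.randn(8, 3)
z0 = torch.cat([x, x_dot], dim=-1)  # pack once at t=0
dz = ConcatField(dim=3)(0.0, z0)    # a NeuralODE would wrap this field
```

After solving, a matching split on the trajectory recovers x and x_dot again.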

@varunagrawal
Author

A pytree seems like a large hammer for a small nail. If there is no direct support for passing in tuples as arguments, I imagine that would be easier to add in the short term. Just a couple of checks (isinstance(x, tuple) and isinstance(x[i], torch.Tensor)) and then continue from there.
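As a hedged sketch of that idea (the wrapper name and checks are hypothetical, not torchdyn code), tuple support could be layered on top of a single-tensor solver like this:

```python
import torch
import torch.nn as nn

class TupleAdapter(nn.Module):
    """Hypothetical wrapper: lets a tuple-state vector field run on a
    solver that only understands one flat tensor."""
    def __init__(self, field, sizes):
        super().__init__()
        self.field = field
        self.sizes = list(sizes)

    def forward(self, t, z):
        parts = z.split(self.sizes, dim=-1)  # tensor -> tuple of tensors
        derivs = self.field(t, parts)        # the wrapped field sees a tuple
        assert isinstance(derivs, tuple)     # the check suggested above
        return torch.cat(derivs, dim=-1)     # tuple -> tensor for the solver

# toy field: x stays put, x_dot decays
field = lambda t, s: (torch.zeros_like(s[0]), -s[1])
adapter = TupleAdapter(field, sizes=(3, 3))
dz = adapter(0.0, torch.randn(2, 6))
```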

@joglekara
Contributor

Update here: fixed_odeint supports the state as a dict for now. We have yet to extend this to the adaptive solvers.


zjowowen commented May 6, 2024

Hi, I have a similar need in my repo, which is based on torchdyn.

I need odeint to support passing a state of the form (dx, dlog(x)) for building generative models such as continuous normalizing flows. The variable x can be a tensor of any shape, while dlog(x) is simply a scalar.
(I tried reshaping and concatenating these tensors into one tensor and reversing the operation when calling modules, but this interferes with the grad_fn backward pass.)

For now, it seems I have to fall back to torchdiffeq, which accepts tuple inputs.

I suggest that torchdyn support tree-like tensor inputs. One possible implementation is https://github.com/opendilab/treevalue.
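For reference, the reshape-and-concat workaround described above can be written so that autograd still flows through it; reshape and cat are both differentiable. This is a generic PyTorch sketch (not torchdyn's API; pack/unpack are hypothetical helpers):

```python
import torch

def pack(x, dlogp):
    # flatten x per sample and append the per-sample scalar log-density
    return torch.cat([x.reshape(x.shape[0], -1), dlogp.reshape(-1, 1)], dim=1)

def unpack(z, x_shape):
    # inverse of pack: split off the scalar and restore x's original shape
    x = z[:, :-1].reshape(x_shape)
    dlogp = z[:, -1]
    return x, dlogp

x = torch.randn(4, 2, 3, requires_grad=True)   # arbitrary-shape sample
dlogp = torch.zeros(4, requires_grad=True)     # one scalar per sample
z = pack(x, dlogp)                             # flat state of size 2*3 + 1
x2, d2 = unpack(z, x.shape)
x2.sum().backward()                            # gradients flow back to x
```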
