Argument mismatch and hard-coded return_all_eval #195

Open
cantabile-kwok opened this issue Jul 13, 2023 · 0 comments
Labels
bug Something isn't working


cantabile-kwok commented Jul 13, 2023

There are mismatched arguments in problems.ODEProblem.odeint.
My torchdyn version is 1.0.3.
Steps to Reproduce
I want to see how many steps the adaptive dopri5 solver took, so, following issue #131, I looked for the return_all_eval argument. The NeuralODE class does not expose it as a keyword argument, so after some digging into the source code I passed args={'return_all_eval': True} instead. However, this still does not give the desired result. The code snippet is:

```python
from torchdyn.core import NeuralODE
import torch
import torch.nn as nn


class VectorField(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(2, 2)

    def forward(self, t, x):
        # Print t on every call to see which times the solver evaluates.
        print(f"In VectorField, t is fed as {t}")
        return self.net(t + x)


vf = VectorField()
ode = NeuralODE(vf, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
time = torch.linspace(0, 1, 10)
initial = torch.randn(16, 20, 2)
# Attempt to enable return_all_eval through args (does not give the desired result).
eval_time, sol = ode(initial, time, args={'return_all_eval': True})
print(sol.shape)
```

I then found that the return_all_eval entry in args is never actually passed on to the numerics.odeint.odeint function. The signature of that function is:

```python
def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, nn.Module], atol:float=1e-3, rtol:float=1e-3,
           t_stops:Union[List, Tensor, None]=None, verbose:bool=False, interpolator:Union[str, Callable, None]=None, return_all_eval:bool=False,
           save_at:Union[List, Tensor]=(), args:Dict={}, seminorm:Tuple[bool, Union[int, None]]=(False, None)) -> Tuple[Tensor, Tensor]:
    ...  # body omitted
```
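To see why putting the flag inside args cannot work, here is a small pure-Python sketch. `odeint_stub` is a hypothetical stand-in (not torchdyn code) that copies the parameter names and defaults of the signature above, drops the type hints, and simply reports how its arguments were bound:

```python
def odeint_stub(f, x, t_span, solver, atol=1e-3, rtol=1e-3, t_stops=None,
                verbose=False, interpolator=None, return_all_eval=False,
                save_at=(), args={}, seminorm=(False, None)):
    """Hypothetical stand-in mirroring the parameter list of
    numerics.odeint.odeint (the mutable default mirrors the original and is
    never mutated here). It returns the flag and the args dict it received."""
    return return_all_eval, args


# Pass the flag the way the reproduction script does: inside `args`.
flag, extra = odeint_stub(None, None, [0.0, 1.0], "dopri5",
                          args={"return_all_eval": True})
print(flag)   # False: the explicit parameter keeps its default
print(extra)  # {'return_all_eval': True}: the flag just sits in the dict
```

An entry in a dict argument never binds to a separately named parameter, so the solver only sees return_all_eval=True if a caller passes it explicitly.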

So return_all_eval is an explicit parameter of odeint, separate from args. However, in numerics.sensitivity._gather_odefunc_adjoint._ODEProblemFunc.forward the value passed to generic_odeint for it is hard-coded as False:

```python
def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
    t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B,
                                False, maxiter, fine_steps, save_at)  # this False is return_all_eval
    ctx.save_for_backward(sol, t_sol)
    return t_sol, sol
```

So there is no way to switch it on short of modifying the source code.
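One possible shape for a fix is to capture the flag in the enclosing factory, the same way forward already picks up solver, atol and rtol from its enclosing scope, instead of hard-coding False. The sketch below illustrates only that closure pattern in plain Python; `gather_forward` and `fake_generic_odeint` are hypothetical names, not torchdyn's code, and the "internal times" are made-up placeholder values:

```python
def fake_generic_odeint(x, t_span, return_all_eval=False):
    """Hypothetical stand-in for generic_odeint: reports either every
    internal evaluation time or only the requested ones."""
    # Placeholder values standing in for an adaptive solver's accepted steps.
    internal_times = [0.0, 0.13, 0.37, 0.71, 1.0]
    return internal_times if return_all_eval else list(t_span)


def gather_forward(return_all_eval=False):
    """Hypothetical factory: the returned forward closes over the flag
    rather than passing a literal False."""
    def forward(x, t_span):
        return fake_generic_odeint(x, t_span, return_all_eval)
    return forward


default_times = gather_forward()(None, [0.0, 1.0])
all_times = gather_forward(return_all_eval=True)(None, [0.0, 1.0])
print(default_times)  # [0.0, 1.0]
print(all_times)      # [0.0, 0.13, 0.37, 0.71, 1.0]
```

The flag would still need to be threaded from NeuralODE down to the factory; this sketch does not show that part.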

The second problem is an argument mismatch in numerics.sensitivity._gather_odefunc_adjoint._ODEProblemFunc.forward. ODEProblem.odeint calls it as:

```python
return self._autograd_func()(self.vf_params, x, t_span, save_at, args)
```

These positional arguments do not line up with the signature of that forward function: save_at lands in the B slot and args lands in the save_at slot. In other words, the save_at parameter actually receives a dict, while B (whose purpose I do not understand) receives the real save_at. This has not caused any problems in my code so far, but I don't believe it is the intended behavior, and I suggest someone debug this code path more thoroughly.
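The mismatch follows from ordinary positional binding and can be reproduced without torchdyn. In this sketch `forward` is a hypothetical stand-in with the same signature as _ODEProblemFunc.forward, and `None` plays the role of ctx, which autograd supplies automatically in the real call:

```python
def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
    """Stand-in with the same signature as _ODEProblemFunc.forward;
    it just reports what the last two slots received."""
    return {"B": B, "save_at": save_at}


save_at = [0.5]            # times the user wants the solution saved at
args = {"some_key": 1.0}   # hypothetical extra arguments for the vector field

# Mirrors the positional call: (vf_params, x, t_span, save_at, args).
received = forward(None, "vf_params", "x", "t_span", save_at, args)
print(received)     # {'B': [0.5], 'save_at': {'some_key': 1.0}}

# Binding by keyword puts save_at where it belongs.
received_kw = forward(None, "vf_params", "x", "t_span", save_at=save_at)
print(received_kw)  # {'B': None, 'save_at': [0.5]}
```

Note that the keyword call simply drops args, because this forward has no parameter for it; an actual fix would also have to decide where args should go.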

Screenshots
[Screenshot: traceback showing the problem (image not available)]

Expected behavior

The return_all_eval option should be settable by the user and should control whether the ODE solver returns all evaluation time points.
Also, the meaning of these arguments and the functionality they provide is largely undocumented. For example, I did not realize there was a way to return all evaluation time stamps until I found issue #131.
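For context on what "all evaluation time points" of an adaptive solver means, here is a self-contained toy integrator that records every accepted step. It is an embedded Heun-Euler pair, deliberately much simpler than dopri5 and unrelated to torchdyn's implementation; the function name, the initial step h0, and the controller constants (safety factor 0.9, growth bounds 0.2 to 2.0) are illustrative assumptions:

```python
import math


def adaptive_heun_euler(f, y0, t0, t1, atol=1e-4, rtol=1e-4, h0=0.1):
    """Integrate dy/dt = f(t, y) for scalar y with an embedded Heun (order 2)
    / Euler (order 1) pair. Returns (accepted_times, accepted_values)."""
    t, y, h = t0, y0, h0
    times, values = [t], [y]
    while t < t1:
        h = min(h, t1 - t)                 # do not step past the end point
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        y_euler = y + h * k1               # first-order estimate
        y_heun = y + 0.5 * h * (k1 + k2)   # second-order estimate
        err = abs(y_heun - y_euler)        # local error estimate
        tol = atol + rtol * max(abs(y), abs(y_heun))
        if err <= tol:
            # Accept the step (keeping the higher-order value) and record it.
            t, y = t + h, y_heun
            times.append(t)
            values.append(y)
        # Step-size update for a first-order error estimate, with a safety
        # factor and bounds on how quickly h may shrink or grow.
        factor = 0.9 * math.sqrt(tol / err) if err > 0 else 2.0
        h *= min(2.0, max(0.2, factor))
    return times, values


times, values = adaptive_heun_euler(lambda t, y: -y, 1.0, 0.0, 1.0)
print(len(times) - 1, "accepted steps")
print(values[-1], math.exp(-1.0))
```

Here the returned times list plays the role the issue expects return_all_eval to play: the number of accepted steps can be read off from its length.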

cantabile-kwok added the bug (Something isn't working) label on Jul 13, 2023