
Torchscripting not possible with NeuralODE due to function redefinition #163

Open
StephenHogg opened this issue Aug 10, 2022 · 5 comments
Labels
bug Something isn't working

Comments

@StephenHogg

Describe the bug

TorchScript is unable to script the NeuralODE class because a method is redefined. This matters because the code contains control flow that tracing (torch.jit.trace) would not necessarily respect, so tracing is not a safe alternative and could produce incorrect output.
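To illustrate why tracing is not a drop-in replacement here, the toy function below (not torchdyn code) has a data-dependent branch. torch.jit.trace records only the branch taken for the example input, while torch.jit.script compiles both branches:

```python
import torch


def f(x: torch.Tensor) -> torch.Tensor:
    # Which branch runs depends on the values in x.
    if x.sum() > 0:
        return x * 2
    else:
        return -x


# Tracing runs f once on ones(3), so only the "x * 2" path is recorded
# (PyTorch emits a TracerWarning about the tensor-to-bool conversion).
traced = torch.jit.trace(f, torch.ones(3))

# Scripting compiles the Python source, keeping both branches.
scripted = torch.jit.script(f)

neg = -torch.ones(3)
print(traced(neg))    # tensor([-2., -2., -2.]), replays the wrong branch
print(scripted(neg))  # tensor([1., 1., 1.]), takes the correct branch
```

The same risk applies to any solver loop or flag check inside a model: a trace silently freezes whatever path the example input happened to take.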

Steps to Reproduce

Minimal example to reproduce:

import torch
import torch.nn as nn
from torchdyn.core import NeuralODE

f = nn.Sequential(
    nn.Linear(2, 16),
    nn.Tanh(),
    nn.Linear(16, 2)
)

model = NeuralODE(f, solver='tsit5', solver_adjoint='dopri5')
out = torch.jit.script(model)

This produces the following errors.

The first time you run it, the error is:

forward(__torch__.torch.nn.modules.container.Sequential self, Tensor input) -> (Tensor):
Expected at most 2 arguments but found 3 positional arguments.
:
  File "/home/shogg/.pyenv/versions/3.8.12/envs/mldi/lib/python3.8/site-packages/torchdyn/core/defunc.py", line 32
    def forward(self, t:Tensor, x:Tensor) -> Tensor:
        self.nfe += 1
        if self.has_time_arg: return self.vf(t, x)
                                     ~~~~~~~ <--- HERE
        else: return self.vf(x)

On subsequent attempts, the error is:

RuntimeError: Can't redefine method: forward on class: __torch__.torchdyn.core.defunc.DEFuncBase (of Python compilation unit at: 0x560bca2c9a70)

Changing the solvers does not appear to affect either error.
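The first error can be reproduced without torchdyn. The sketch below is a hypothetical, stripped-down copy of the `if self.has_time_arg:` pattern from the `defunc.py` trace above (the class name `BranchyDEFunc` is made up). Because `has_time_arg` is an ordinary runtime attribute, TorchScript type-checks both branches, so `self.vf(t, x)` is rejected for an `nn.Sequential` even when that branch would never run:

```python
import torch
import torch.nn as nn
from torch import Tensor


class BranchyDEFunc(nn.Module):
    """Hypothetical minimal version of the branching forward in DEFuncBase."""

    def __init__(self, vf: nn.Module, has_time_arg: bool):
        super().__init__()
        self.vf = vf
        self.has_time_arg = has_time_arg

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        if self.has_time_arg:
            # Compiled even when has_time_arg is False; Sequential.forward
            # takes a single input, so this call fails to type-check.
            return self.vf(t, x)
        else:
            return self.vf(x)


net = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
try:
    torch.jit.script(BranchyDEFunc(net, has_time_arg=False))
    scripted_ok = True
except Exception:  # TorchScript reports the bad call at compile time
    scripted_ok = False

print("scripting succeeded:", scripted_ok)  # scripting succeeded: False
```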

Expected behavior

TorchScript compilation should succeed.

I'm using the latest version of torchdyn available from pip and torch==1.11.0+cu102. I would be very grateful for your advice on how to TorchScript this safely.

@StephenHogg StephenHogg added the bug Something isn't working label Aug 10, 2022
@StephenHogg
Author

StephenHogg commented Aug 11, 2022

I concluded that this is likely caused by forward() being wrapped to accept a time argument. To avoid this, I gave NeuralODE a network whose forward already takes the time argument, and hit this error instead:

RuntimeError: 
Unknown type name 'Iterable':
  File "/home/shogg/.pyenv/versions/3.8.12/envs/mldi/lib/python3.8/site-packages/torchdyn/core/neuralde.py", line 92
    def forward(self, x:Tensor, t_span:Tensor=None, save_at:Iterable=(), args={}):
                                                            ~~~~~~~~ <--- HERE
        x, t_span = self._prep_integration(x, t_span)
        t_eval, sol =  super().forward(x, t_span, save_at, args)

What's odd is that Iterable is definitely imported at the top of the file.
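The workaround of giving the vector field a `(t, x)` signature up front can be sketched in isolation. Both `TimeInvariant` and `VectorField` below are hypothetical names, not torchdyn classes: the adapter gives a time-independent network the `(t, x)` signature, so the wrapper only ever has one call form to compile and no branch on `has_time_arg` is needed. torchdyn's own wrapping may look different.

```python
import torch
import torch.nn as nn
from torch import Tensor


class TimeInvariant(nn.Module):
    """Hypothetical adapter: exposes forward(t, x) for a network that ignores t."""

    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        return self.net(x)  # t is accepted but unused


class VectorField(nn.Module):
    """Hypothetical stand-in for DEFuncBase with a single, fixed call signature."""

    def __init__(self, vf: nn.Module):
        super().__init__()
        self.vf = vf
        self.nfe = 0  # number of function evaluations, as in the trace above

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        self.nfe += 1
        return self.vf(t, x)  # always (t, x): nothing unreachable to type-check


f = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
scripted = torch.jit.script(VectorField(TimeInvariant(f)))
out = scripted(torch.tensor(0.0), torch.randn(4, 2))
print(out.shape)  # torch.Size([4, 2])
```

This only removes the first error; as the trace above shows, the NeuralODE wrapper itself still fails on other annotations.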

@StephenHogg
Author

I looked into this further: TorchScript does not accept Iterable as a type annotation. Here is the list of supported type hints. Would it be possible to switch to one of these? I'm happy to write a PR if so.

https://pytorch.org/docs/stable/jit_language_reference.html#supported-type
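As one possible replacement, `Optional[Tensor]` with a `None` default is in the supported-type list, and TorchScript refines it to `Tensor` after an `is None` check. The function below is a hypothetical sketch modeled on the `_prep_integration` call in the trace; its defaults (a two-point `t_span` on [0, 1], and `save_at` falling back to `t_span`) are assumptions for illustration, not torchdyn's actual behavior.

```python
from typing import Optional, Tuple

import torch
from torch import Tensor


@torch.jit.script
def prep_integration(
    x: Tensor,
    t_span: Optional[Tensor] = None,
    save_at: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    # Hypothetical default: integrate over [0, 1], matching x's dtype and device.
    if t_span is None:
        t_span = torch.linspace(0.0, 1.0, 2, dtype=x.dtype, device=x.device)
    # After the None checks, the compiler treats both names as plain Tensors.
    if save_at is None:
        save_at = t_span
    return t_span, save_at


t_span, save_at = prep_integration(torch.zeros(2))
print(t_span)   # tensor([0., 1.])
print(save_at)  # tensor([0., 1.])
```

Using `None` rather than `()` as the default also avoids mixing a tuple default with a `Tensor` annotation.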

@StephenHogg
Author

Another update: removing that type annotation revealed a second incorrect annotation, as well as a problem with how forward is called. It seems the package would need a refactor to make it safe to TorchScript. Let me know if you're interested in doing this; I'd be happy to work with you on it, since it would be great to be able to serialise these models.
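For context on the serialisation goal: once a module scripts successfully, it can be saved and reloaded with torch.jit.save and torch.jit.load, without the original Python class definitions. Since NeuralODE itself does not script as of this thread, the sketch below serialises only the plain vector-field network from the reproduction, using an in-memory buffer instead of a file:

```python
import io

import torch
import torch.nn as nn

f = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
scripted = torch.jit.script(f)

# Serialise to an in-memory buffer; a file path works the same way.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)  # rewind before reading back
restored = torch.jit.load(buffer)

x = torch.randn(4, 2)
print(torch.allclose(scripted(x), restored(x)))  # True
```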

@Zymrael
Member

Zymrael commented Aug 12, 2022

Thanks for looking into this! I've started a refactor in this branch.

I fixed some of the typing inconsistencies and got as far as the following error:

RuntimeError: 
'Tensor' object has no attribute or method 'forward'.:
  File "/home/stefano/michael/diffeqml/torchdyn/torchdyn/core/neuralde.py", line 94
    def forward(self, x:Tensor, t_span:Tensor=None, save_at:Tensor=()):
        x, t_span = self._prep_integration(x, t_span)
        t_eval, sol =  super().forward(x, t_span, save_at)
                       ~~~~~~~~~~~~~ <--- HERE
        if self.return_t_eval: return t_eval, sol
        else: return sol

Is that the error you observed?

@StephenHogg
Author

Hi @Zymrael, yes. That one is a problem because the usual fix there would be:

super(ODEProblem, self).forward(x, t_span, save_at)

but TorchScript doesn't accept types being passed as arguments, so that fix errors too. This implies that some of the inheritance the current code structure relies on may need to be reworked, which isn't necessarily a bad thing if it also reduces the amount of indirection in the code. Hope this helps; happy to keep discussing this one.
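One common way around the missing `super()` support, sketched here under assumptions, is to move the shared logic into a separately named method on the base class and call it through `self`. TorchScript compiles inherited methods reached via `self`, so no `super()` call is needed. The class names and `_integrate` are hypothetical, and the explicit Euler loop is a placeholder rather than torchdyn's solver:

```python
import torch
import torch.nn as nn
from torch import Tensor


class ProblemBase(nn.Module):
    """Hypothetical stand-in for ODEProblem: shared logic lives in a named helper."""

    def __init__(self, vf: nn.Module):
        super().__init__()  # fine: __init__ is not compiled
        self.vf = vf

    def _integrate(self, x: Tensor, t_span: Tensor) -> Tensor:
        # Placeholder solver: explicit Euler steps over the grid t_span.
        for i in range(t_span.shape[0] - 1):
            dt = t_span[i + 1] - t_span[i]
            x = x + dt * self.vf(x)
        return x

    def forward(self, x: Tensor, t_span: Tensor) -> Tensor:
        return self._integrate(x, t_span)


class NeuralODELike(ProblemBase):
    """Hypothetical subclass that reuses the base logic without super()."""

    def forward(self, x: Tensor, t_span: Tensor) -> Tensor:
        # self._integrate(...) compiles; super().forward(...) would not.
        return self._integrate(x, t_span)


eager = NeuralODELike(nn.Linear(2, 2))
scripted = torch.jit.script(eager)
x, t_span = torch.randn(3, 2), torch.linspace(0.0, 1.0, 5)
print(torch.allclose(scripted(x, t_span), eager(x, t_span)))  # True
```

Composition (holding the solver as a submodule) would be another option; the helper-method pattern simply keeps the existing class hierarchy intact.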
