Implicit Euler method is not compatible with neural ode training #156

Open
cyx96 opened this issue Jul 9, 2022 · 2 comments
Labels: bug (Something isn't working)

Comments

cyx96 commented Jul 9, 2022

First of all, thank you for this amazing library!

For my own research, I want to test implicit integrators with the adjoint sensitivity method to train a neural network. While the implicit Euler method provided by the library can be used to integrate ODEs (when no gradients are required), it is incompatible with the adjoint sensitivity method, which requires gradient information. A code snippet to reproduce the error is provided below.

import torch
import torch.nn as nn
import torch.utils.data as data

import pytorch_lightning as pl

from torchdyn.core import NeuralODE
from torchdyn.datasets import *
from torchdyn import *

device = torch.device("cpu")  # all of this works on GPU as well :)

# Toy dataset, as in the torchdyn quickstart (assumed here; the original snippet omitted its definition)
d = ToyDataset()
X, yn = d.generate(n_samples=512, noise=1e-1, dataset_type='moons')

X_train = torch.Tensor(X).to(device)
y_train = torch.LongTensor(yn.long()).to(device)
train = data.TensorDataset(X_train, y_train)
trainloader = data.DataLoader(train, batch_size=len(X), shuffle=True)

class Learner(pl.LightningModule):
    def __init__(self, t_span:torch.Tensor, model:nn.Module):
        super().__init__()
        self.model, self.t_span = model, t_span

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        t_eval, y_hat = self.model(x, self.t_span)
        y_hat = y_hat[-1] # select last point of solution trajectory
        loss = nn.CrossEntropyLoss()(y_hat, y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.01)

    def train_dataloader(self):
        return trainloader

f = nn.Sequential(
        nn.Linear(2, 16),
        nn.Tanh(),
        nn.Linear(16, 2)
    )

model = NeuralODE(f, sensitivity='adjoint', solver='ieuler').to(device)

t_span = torch.linspace(0, 1, 5)  # integration span (assumed; as in the torchdyn quickstart)
learn = Learner(t_span, model)
trainer = pl.Trainer(min_epochs=200, max_epochs=300)
trainer.fit(learn)

Running the snippet above produces the following RuntimeError:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Changing the sensitivity from adjoint to autograd prevents the RuntimeError, but the loss does not decrease during training. I also tried modifying the implicit Euler method by changing retain_graph from False to True, but that did not solve the issue either. I suspect the problem has something to do with the LBFGS optimizer the library uses to find roots for the implicit integrator, but I don't know how to fix it.
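
To illustrate what I suspect is happening, here is a minimal, self-contained sketch of why backpropagating through an optimizer-based root find is problematic in general. This is my own illustration with placeholder names, not torchdyn's actual implementation: an implicit Euler step must solve z1 = z0 + h * f(z1), and if that root find is done by an inner LBFGS whose closure calls backward(), the resulting state has no autograd connection back to f's parameters, which would also be consistent with the loss not decreasing in autograd mode.

import torch
import torch.nn as nn

# Hypothetical stand-in for the vector field; not torchdyn's internals.
f = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))

def implicit_euler_step_lbfgs(f, z0, h, max_iter=30):
    """Solve z1 = z0 + h * f(z1) by minimizing the squared residual with LBFGS."""
    z1 = z0.clone().detach().requires_grad_(True)  # z1 is an optimization variable, i.e. a graph leaf
    opt = torch.optim.LBFGS([z1], max_iter=max_iter)

    def closure():
        opt.zero_grad()
        residual = z1 - z0 - h * f(z1)
        loss = residual.pow(2).sum()
        loss.backward()  # inner backward builds and then frees its own graph
        return loss

    opt.step(closure)
    # The solution is a leaf tensor: there is no autograd path from z1 back to
    # z0 or to the parameters of f, so an outer training loss cannot reach them.
    return z1.detach()

z0 = torch.randn(8, 2)
z1 = implicit_euler_step_lbfgs(f, z0, h=0.1)
print(z1.grad_fn)  # None: gradient flow stops at the implicit step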

Any help on this issue is much appreciated! Thanks in advance!

cyx96 added the bug label Jul 9, 2022
Zymrael (Member) commented Jul 11, 2022

Providing an update here based on our private chat, in case someone else is interested: this should be done using the IFT at the fixed point, implementing a custom backward for the implicit step.

@cyx96 graciously offered to help implement this.

joglekara (Contributor) commented

For future readers who spend a few minutes trying to figure out what IFT is...

IFT = Implicit Function Theorem
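
For anyone looking for a concrete starting point, below is a rough, self-contained sketch of what an IFT-based backward for a single implicit Euler step could look like. This is an illustration only, not torchdyn's planned implementation; the class name ImplicitEulerStepIFT, the fixed iteration counts, and the plain fixed-point solves are placeholder choices. The idea: at a solution of z1 = z0 + h * f(z1), the implicit function theorem gives gradients without differentiating through the solver, using only vector-Jacobian products of the step map F(z) = z0 + h * f(z) evaluated at the fixed point.

import torch
import torch.nn as nn

class ImplicitEulerStepIFT(torch.autograd.Function):
    """One implicit Euler step, z1 = z0 + h * f(z1), with an IFT-based backward.

    Backward: given v = dL/dz1, solve (I - dF/dz1)^T w = v for w, where
    F(z) = z0 + h * f(z), then return the vjps w^T dF/dz0 and w^T dF/dtheta.
    """

    @staticmethod
    def forward(ctx, z0, h, f, *params):
        with torch.no_grad():
            z1 = z0.clone()
            for _ in range(50):          # plain fixed-point iteration; a real
                z1 = z0 + h * f(z1)      # implementation would use Newton or LBFGS
        ctx.save_for_backward(z0, z1, *params)
        ctx.h, ctx.f = h, f
        return z1

    @staticmethod
    def backward(ctx, v):
        z0, z1, *params = ctx.saved_tensors
        h, f = ctx.h, ctx.f

        # Rebuild one application of F at the fixed point so autograd has a graph.
        z0 = z0.detach().requires_grad_(True)
        z1 = z1.detach().requires_grad_(True)
        with torch.enable_grad():
            F = z0 + h * f(z1)

        # Solve (I - dF/dz1)^T w = v by iterating w <- v + (dF/dz1)^T w
        # (converges for small h), using vjps instead of materializing Jacobians.
        w = v
        for _ in range(50):
            (Jw,) = torch.autograd.grad(F, z1, w, retain_graph=True)
            w = v + Jw

        # Push w through F once more to get gradients w.r.t. z0 and f's parameters.
        grad_z0, *grad_params = torch.autograd.grad(F, (z0,) + tuple(params), w)
        return (grad_z0, None, None, *grad_params)

# Usage sketch: gradients now reach z0 and f's parameters through the implicit step.
f = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
z0 = torch.randn(8, 2, requires_grad=True)
z1 = ImplicitEulerStepIFT.apply(z0, 0.1, f, *f.parameters())
z1.sum().backward()
print(z0.grad.shape, f[0].weight.grad.shape)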
