Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RuntimeError in odeint_adjoint #226

Open
chooron opened this issue Mar 9, 2023 · 1 comment
Open

RuntimeError in odeint_adjoint #226

chooron opened this issue Mar 9, 2023 · 1 comment

Comments

@chooron
Copy link

chooron commented Mar 9, 2023

Hello, my code runs successfully with odeint; however, when I switch to odeint_adjoint, it raises the following error:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Here is my code:

class M50_Func(nn.Module):
    """Right-hand side ``dS/dt = f(t, S)`` combining two networks with forcing data.

    The state ``S`` is indexed as ``S[0][0]`` (``S_snow``) and ``S[0][1]``
    (``S_water``).  Forcing values are read at time ``t`` from three
    interpolators that expose an ``evaluate(t)`` method.

    Args:
        ET_net: Module called on the stacked ``[S_snow, S_water, temp]`` vector.
        Q_net: Module called on the stacked ``[S_water, precp]`` vector.
        params: Six-element iterable unpacked as
            ``(f, Smax, Qmax, Df, Tmax, Tmin)``.
        interps: Three-element iterable unpacked as
            ``(precp_interp, temp_interp, lday_interp)``.
        ode_lib: Name of the ODE backend; stored but not used by this class.
    """

    def __init__(self, ET_net, Q_net, params, interps, ode_lib='torchdiffeq'):
        super().__init__()
        self.f, self.Smax, self.Qmax, self.Df, self.Tmax, self.Tmin = params
        self.ET_net = ET_net
        self.ET_net.train()
        self.ode_lib = ode_lib
        self.Q_net = Q_net
        self.Q_net.train()
        self.precp_interp, self.temp_interp, self.lday_interp = interps

    @staticmethod
    def _stack_scalars(*values):
        """Stack scalar-like tensors into a 1-D tensor without leaving autograd.

        ``torch.tensor([a, b])`` copies the values into a new leaf tensor and
        therefore drops ``grad_fn``; ``torch.stack`` keeps the graph intact.
        Each value is reshaped to 0-d so that both 0-d tensors and
        single-element tensors (e.g. a network output of shape ``(1,)``) work.
        """
        return torch.stack([value.reshape(()) for value in values])

    def forward(self, t, S):
        """Evaluate the state derivative at time ``t``.

        Args:
            t: Time passed to each interpolator's ``evaluate``.
            S: State tensor; only ``S[0][0]`` and ``S[0][1]`` are read.

        Returns:
            Tensor of shape ``(1, 2)`` holding ``[dS_1, dS_2]``, still attached
            to the autograd graph so adjoint/backward passes can reach the
            network parameters.
        """
        from models.common_net import Ps, Pr, M, step_fct
        S_snow, S_water = S[0][0], S[0][1]
        precp = self.precp_interp.evaluate(t).to(torch.float32)
        temp = self.temp_interp.evaluate(t).to(torch.float32)
        lday = self.lday_interp.evaluate(t).to(torch.float32)
        # Earlier NumPy-based interpolation, kept for reference:
        # precp = torch.from_numpy(self.precp_interp(t.numpy()).astype(np.float32)).to(device)
        # temp = torch.from_numpy(self.temp_interp(t.numpy()).astype(np.float32)).to(device)
        # lday = torch.from_numpy(self.lday_interp(t.numpy()).astype(np.float32)).to(device)

        # Build the network inputs with stack (not torch.tensor) so gradients
        # w.r.t. the state flow through the networks.
        ET_output = self.ET_net(self._stack_scalars(S_snow, S_water, temp))
        Q_output = self.Q_net(self._stack_scalars(S_water, precp))

        melt_output = M(S_snow, temp, self.Df, self.Tmax)
        dS_1 = Ps(precp, temp, self.Tmin) - melt_output
        dS_2 = Pr(precp, temp, self.Tmin) + melt_output - step_fct(S_water) * lday * torch.exp(
            ET_output) - step_fct(S_water) * torch.exp(Q_output)
        # Same reason as above: the returned derivative must keep its grad_fn,
        # otherwise odeint_adjoint fails with "element 0 of tensors does not
        # require grad and does not have a grad_fn".
        return self._stack_scalars(dS_1, dS_2).unsqueeze(0)


class M50_Solver(BaseLearner):
    """Learner that integrates ``solve_func`` over time and maps the result to a prediction.

    Args:
        solve_func: Module used as the ODE right-hand side; its ``Q_net``
            attribute is also applied to the integrated state in ``forward``.
        rtol: Relative tolerance forwarded to the ODE solver.
        atol: Absolute tolerance forwarded to the ODE solver.
        ode_lib: Name of the ODE backend; stored but not used by this class.
        loss_metric, eval_metric_list, lr, optimizer: Forwarded unchanged to
            ``BaseLearner.__init__``.
    """

    def __init__(self, solve_func: nn.Module, rtol=1e-6, atol=1e-6, ode_lib='torchdiffeq',
                 loss_metric=torch.nn.MSELoss(), eval_metric_list=None, lr=0.01, optimizer=None):
        super().__init__(solve_func, loss_metric, eval_metric_list, lr, optimizer)
        # Solver settings.
        self.rtol = rtol
        self.atol = atol
        self.ode_lib = ode_lib
        # The ODE function is kept in training mode.
        self.solve_func = solve_func
        self.solve_func.train()

    def forward(self, x, t_eval):
        """Integrate from the first row of ``x`` and return ``exp(Q_net(...))``.

        Args:
            x: Input tensor; if it has more than two dimensions only ``x[0]``
                is used.  Columns 0 and 1 of the first row form the initial
                state and column 2 is fed to ``Q_net`` with the solution.
            t_eval: Evaluation times; if it has more than one dimension only
                ``t_eval[0]`` is used.

        Returns:
            ``torch.exp`` of ``Q_net`` applied to the second solved state
            component paired column-wise with ``x[:, 2]``.
        """
        # Strip a leading axis from either input when present.
        if x.ndim > 2:
            x = x[0]
        if t_eval.ndim > 1:
            t_eval = t_eval[0]
        times = t_eval.to(torch.float32)

        initial_state = torch.tensor([[x[0, 0], x[0, 1]]])
        trajectory = odeint_adjoint(
            self.solve_func,
            y0=initial_state,
            t=times,
            rtol=self.rtol,
            atol=self.atol,
            adjoint_options={"norm": "seminorm"},
        )
        # Alternatives tried by the author:
        # adjoint_params=list(self.solve_func.ET_net.parameters())
        #                + list(self.solve_func.Q_net.parameters()))
        # sol = odeint(self.solve_func, y0=y0, t=t_eval, rtol=self.rtol, atol=self.atol)

        # Second state component at every evaluation time.
        second_state = trajectory[:, 0, 1]
        q_input = torch.concat([second_state[:, None], x[:, 2][:, None]], dim=1)
        return torch.exp(self.solve_func.Q_net(q_input))

BaseLearner is a subclass of pytorch_lightning.LightningModule.

@haonanhe
Copy link

I ran into the same problem... Have you solved it?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants