
RuntimeError: derivative for grid_sampler_2d_backward is not implemented #34704

Open

xieshuqin opened this issue Mar 13, 2020 · 43 comments

Labels: module: interpolation, triaged (this issue has been looked at by a team member and triaged into an appropriate module)

@xieshuqin commented Mar 13, 2020

Hi

When trying to compute the second-order derivative of grid_sampler, the following error occurs: RuntimeError: derivative for grid_sampler_2d_backward is not implemented.
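A minimal snippet that reproduces it (a sketch; shapes are arbitrary, chosen only for illustration):

import torch
import torch.nn.functional as F

image = torch.randn(1, 1, 8, 8, requires_grad=True)
grid = torch.randn(1, 4, 4, 2, requires_grad=True)

out = F.grid_sample(image, grid, align_corners=True)
# The first-order gradient works; create_graph=True records its graph.
grad_grid, = torch.autograd.grad(out.sum(), grid, create_graph=True)
# Differentiating the first-order gradient raises:
# RuntimeError: derivative for grid_sampler_2d_backward is not implemented
grad_grid.sum().backward()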

It seems useful to support second-order derivatives for bilinear interpolation operations like this, given that bilinear interpolation is used in common operations such as deformable convolutions and RoIAlign.

A side question: if I write a custom operation in CUDA, do I need to write a function myself to support the second-order derivative of that operation? For example, the deformable conv operation in mmdet has its own CUDA implementation (see here); is it possible to get the second-order derivative of this operation without having to write my own implementation?
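(For context, a toy illustration of the general rule, using a hypothetical Square op rather than deformable conv: if a custom Function's backward is written out of differentiable torch ops, autograd can differentiate it a second time automatically; if backward calls an opaque CUDA kernel, the second derivative must be provided by hand.)

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Composed of differentiable ops, so it is itself differentiable
        # and double backward works with no extra code.
        return 2 * x * grad_output

x = torch.randn(3, requires_grad=True)
g, = torch.autograd.grad(Square.apply(x).sum(), x, create_graph=True)
g.sum().backward()  # second-order pass succeeds
print(x.grad)       # d/dx of sum(2x) = 2 everywhere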

Thanks a lot

@antonfrancois

Hello,
I got the same issue. Why is this issue closed? And what is the workaround?

Best regards

@lucidrains

@Fuggy I am running into this issue as well :( Have you found a solution to this?

@antonfrancois

Hello @lucidrains.
I ended up not using autograd's hessian with grid_sample, and instead did a finite-differences implementation with the kornia library. In autograd, the jacobian and hessian APIs are still experimental; we can hope they will implement this in the near future.
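For illustration, here is the same idea in plain torch (a minimal sketch; kornia's finite-difference filters play the same role): approximate the derivative of the sampled values w.r.t. a grid coordinate by central differences. The result is built from ordinary grid_sample calls, so autograd can still differentiate it once more.

import torch
import torch.nn.functional as F

def dsample_dx(image, grid, eps=1e-4):
    # Central difference along the x grid coordinate:
    # d(out)/d(grid_x) ~ (f(x + eps) - f(x - eps)) / (2 * eps)
    offset = torch.zeros_like(grid)
    offset[..., 0] = eps
    f_plus = F.grid_sample(image, grid + offset, align_corners=True)
    f_minus = F.grid_sample(image, grid - offset, align_corners=True)
    return (f_plus - f_minus) / (2 * eps)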

Best of luck,

@sb0115basavaraju

Hello,

I got the same issue. Any solution?

@trancelestial

Is this issue closed because of the experimental jacobian and hessian APIs?

@luminohope

Getting the same issue "RuntimeError: derivative for grid_sampler_2d_backward is not implemented" when using torch::nn::functional::grid_sample with torch.autograd.grad. Any workaround?

@YuliangXiu

I got the same issue,

RuntimeError: derivative for cudnn_grid_sampler_backward is not implemented

Any solution?

@pedrovfigueiredo

Any updates on this? I'm having the same issue.

@BaldrLector

@xieshuqin @YuliangXiu @pedrovfigueiredo
Hey guys, is there any solution? I got the same problem on v1.9.

@xieshuqin (Author) commented Jul 12, 2021

@BaldrLector
I had to implement the double backward function manually. Unfortunately I no longer have access to that code. FYI, the second-order derivative doesn't help much; setting it to 0 is probably sufficient if you just want double backward to run. At least in my use case, I didn't see any improvement from the true double backward.

@AliaksandrSiarohin

Dear @xieshuqin, thank you for sharing your experience. Do you remember how to register this function so it can be set to zero?

@xieshuqin (Author)

@AliaksandrSiarohin Sorry, I don't remember. You can clone the PyTorch source code and use an IDE (e.g., VSCode) to find where the function is used; that should save you some time finding the correct file. Another alternative is to create a custom GridSample Function/Op and use it in your code.

@BaldrLector commented Jul 13, 2021

> @BaldrLector
> I had to implement the double backward function manually. Unfortunately I no longer have access to that code. FYI, the second-order derivative doesn't help much; setting it to 0 is probably sufficient if you just want double backward to run. At least in my use case, I didn't see any improvement from the true double backward.

Got it, thanks. I rewrote the bilinear sampling in pure PyTorch, and now it can compute the second-order gradient.

The code:

import torch

def grid_sample(feature, uv):
        B, C, H, W = feature.shape
        _, pH, pW, _ = uv.shape
        floor_uv = torch.floor(uv).long().reshape(B, -1, 2)
        ceil_uv = floor_uv + 1
        floor_uv[..., 0] = torch.clamp(floor_uv[..., 0], 0, H - 1)
        floor_uv[..., 1] = torch.clamp(floor_uv[..., 1], 0, W - 1)
        ceil_uv[..., 0] = torch.clamp(ceil_uv[..., 0], 0, H - 1)
        ceil_uv[..., 1] = torch.clamp(ceil_uv[..., 1], 0, W - 1)
        res = []
        for i in range(B):
            pff = feature[i][:, floor_uv[i][:, 1], floor_uv[i][:, 0]]
            pfc = feature[i][:, floor_uv[i][:, 1], ceil_uv[i][:, 0]]
            pcf = feature[i][:, ceil_uv[i][:, 1], floor_uv[i][:, 0]]
            pcc = feature[i][:, ceil_uv[i][:, 1], ceil_uv[i][:, 0]]

            df = uv.reshape(B, -1, 2)[i] - floor_uv[i]
            dff = df[:, 0] * df[:, 1]
            dfc = df[:, 0] * (1 - df[:, 1])
            dcf = (1 - df[:, 0]) * df[:, 1]
            dcc = (1 - df[:, 0]) * (1 - df[:, 1])
            p = pff * dff + pfc + dfc + pcf + dcf + pcc + dcc
            res.append(p)
        res = torch.stack(res, 0).reshape(B, C, pH, pW)
        return res

@AliaksandrSiarohin

@BaldrLector I guess this line is not correct: p = pff * dff + pfc + dfc + pcf + dcf + pcc + dcc (the last six terms are added instead of multiplied pairwise).

Here is my implementation if you are interested:


import torch
import torch.nn.functional as F

def grid_sample(image, optical):
    N, C, IH, IW = image.shape
    _, H, W, _ = optical.shape

    ix = optical[..., 0]
    iy = optical[..., 1]

    # Unnormalize from [-1, 1] to pixel coordinates (align_corners=True).
    ix = ((ix + 1) / 2) * (IW - 1)
    iy = ((iy + 1) / 2) * (IH - 1)
    with torch.no_grad():
        # Corner indices are computed under no_grad, so autograd treats
        # them as constants; only the weights below carry gradients.
        ix_nw = torch.floor(ix)
        iy_nw = torch.floor(iy)
        ix_ne = ix_nw + 1
        iy_ne = iy_nw
        ix_sw = ix_nw
        iy_sw = iy_nw + 1
        ix_se = ix_nw + 1
        iy_se = iy_nw + 1

    # Bilinear weights; these stay differentiable w.r.t. ix and iy.
    nw = (ix_se - ix)    * (iy_se - iy)
    ne = (ix    - ix_sw) * (iy_sw - iy)
    sw = (ix_ne - ix)    * (iy    - iy_ne)
    se = (ix    - ix_nw) * (iy    - iy_nw)
    
    with torch.no_grad():
        # Clamp indices in place (border padding behaviour).
        torch.clamp(ix_nw, 0, IW-1, out=ix_nw)
        torch.clamp(iy_nw, 0, IH-1, out=iy_nw)

        torch.clamp(ix_ne, 0, IW-1, out=ix_ne)
        torch.clamp(iy_ne, 0, IH-1, out=iy_ne)
 
        torch.clamp(ix_sw, 0, IW-1, out=ix_sw)
        torch.clamp(iy_sw, 0, IH-1, out=iy_sw)
 
        torch.clamp(ix_se, 0, IW-1, out=ix_se)
        torch.clamp(iy_se, 0, IH-1, out=iy_se)

    # Flatten spatial dims, then gather the four corner values per location.
    image = image.view(N, C, IH * IW)

    nw_val = torch.gather(image, 2, (iy_nw * IW + ix_nw).long().view(N, 1, H * W).repeat(1, C, 1))
    ne_val = torch.gather(image, 2, (iy_ne * IW + ix_ne).long().view(N, 1, H * W).repeat(1, C, 1))
    sw_val = torch.gather(image, 2, (iy_sw * IW + ix_sw).long().view(N, 1, H * W).repeat(1, C, 1))
    se_val = torch.gather(image, 2, (iy_se * IW + ix_se).long().view(N, 1, H * W).repeat(1, C, 1))

    out_val = (nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) + 
               ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) +
               sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) +
               se_val.view(N, C, H, W) * se.view(N, 1, H, W))

    return out_val


if __name__ == "__main__":
    image = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).view(1, 3, 1, 3)
    

    optical = torch.Tensor([0.9, 0.5, 0.6, -0.7]).view(1, 1, 2, 2)

    print (grid_sample(image, optical))

    print (F.grid_sample(image, optical, padding_mode='border', align_corners=True))
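Why this supports double backward: the corner indices are computed under torch.no_grad(), so autograd only sees gather and differentiable arithmetic on the weights, all of which have second derivatives. An illustrative check (names and shapes arbitrary):

img = torch.randn(1, 2, 5, 5, requires_grad=True)
grd = torch.rand(1, 3, 3, 2, requires_grad=True)
out = grid_sample(img, grd)
g, = torch.autograd.grad(out.sum(), grd, create_graph=True)
h, = torch.autograd.grad(g.sum(), grd)  # no RuntimeError here
print(h.shape)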

@BaldrLector

@AliaksandrSiarohin You are right, and thanks for your code :)

@JavierCoronaA01023063

Hey @AliaksandrSiarohin and @BaldrLector, in what exact file did you change def grid_sample? Please help, I have the same RuntimeError. Hope you answer.

@BaldrLector

> Hey @AliaksandrSiarohin and @BaldrLector, in what exact file did you change def grid_sample? Please help, I have the same RuntimeError. Hope you answer.

Hi @JavierCoronaA01023063, I just used @AliaksandrSiarohin's new function to replace the official grid_sample.

@vincenthesiyuan

Hi @AliaksandrSiarohin,
Thanks for your code. I currently suffer from a similar problem: RuntimeError: derivative for grid_sampler_3d_backward is not implemented

I implemented a 3D grid sample function following your code, but the output is inconsistent with F.grid_sample and I have no idea where the error is.
Thank you in advance for any help you might be able to provide.

import torch

def grid_sampler(feature_3d, grid):
    N, C, iD, iH, iW = feature_3d.shape
    _, D, H, W, _ = grid.shape

    ix = grid[..., 0]
    iy = grid[..., 1]
    iz = grid[..., 2]

    ix = ((ix + 1) / 2) * (iW - 1)
    iy = ((iy + 1) / 2) * (iH - 1)
    iz = ((iz + 1) / 2) * (iD - 1)

    with torch.no_grad():
        ix_tnw = torch.floor(ix)
        iy_tnw = torch.floor(iy)
        iz_tnw = torch.floor(ix)

        ix_tne = ix_tnw + 1
        iy_tne = iy_tnw
        iz_tne = iz_tnw

        ix_tsw = ix_tnw
        iy_tsw = iy_tnw + 1
        iz_tsw = iz_tnw

        ix_tse = ix_tnw + 1
        iy_tse = iy_tnw + 1
        iz_tse = iz_tnw

        ix_bnw = ix_tnw
        iy_bnw = iy_tnw
        iz_bnw = iz_tnw + 1

        ix_bne = ix_tnw + 1
        iy_bne = iy_tnw
        iz_bne = iz_tnw + 1

        ix_bsw = ix_tnw
        iy_bsw = iy_tnw + 1
        iz_bsw = iz_tnw + 1

        ix_bse = ix_tnw + 1
        iy_bse = iy_tnw + 1
        iz_bse = iz_tnw + 1

    bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse)
    bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw)
    bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne)
    bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw)

    tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz)
    tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz)
    tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz)
    tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz)

    with torch.no_grad():
        torch.clamp(ix_bnw, 0, iW - 1, out=ix_bnw)
        torch.clamp(iy_bnw, 0, iH - 1, out=iy_bnw)
        torch.clamp(iz_bnw, 0, iD - 1, out=iz_bnw)

        torch.clamp(ix_bne, 0, iW - 1, out=ix_bne)
        torch.clamp(iy_bne, 0, iH - 1, out=iy_bne)
        torch.clamp(iz_bne, 0, iD - 1, out=iz_bne)

        torch.clamp(ix_bsw, 0, iW - 1, out=ix_bsw)
        torch.clamp(iy_bsw, 0, iH - 1, out=iy_bsw)
        torch.clamp(iz_bsw, 0, iD - 1, out=iz_bsw)

        torch.clamp(ix_bse, 0, iW - 1, out=ix_bse)
        torch.clamp(iy_bse, 0, iH - 1, out=iy_bse)
        torch.clamp(iz_bse, 0, iD - 1, out=iz_bse)

        torch.clamp(ix_tnw, 0, iW - 1, out=ix_tnw)
        torch.clamp(iy_tnw, 0, iH - 1, out=iy_tnw)
        torch.clamp(iz_tnw, 0, iD - 1, out=iz_tnw)

        torch.clamp(ix_tne, 0, iW - 1, out=ix_tne)
        torch.clamp(iy_tne, 0, iH - 1, out=iy_tne)
        torch.clamp(iz_tne, 0, iD - 1, out=iz_tne)

        torch.clamp(ix_tsw, 0, iW - 1, out=ix_tsw)
        torch.clamp(iy_tsw, 0, iH - 1, out=iy_tsw)
        torch.clamp(iz_tsw, 0, iD - 1, out=iz_tsw)

        torch.clamp(ix_tse, 0, iW - 1, out=ix_tse)
        torch.clamp(iy_tse, 0, iH - 1, out=iy_tse)
        torch.clamp(iz_tse, 0, iD - 1, out=iz_tse)

    feature_3d = feature_3d.view(N, C, iH * iW * iD)

    bnw_val = torch.gather(feature_3d, 2, (iy_bnw * iW + ix_bnw * iH + iz_bnw * iD).long().view(N, 1, D * H * W).repeat(1, C, 1))
    bne_val = torch.gather(feature_3d, 2, (iy_bne * iW + ix_bne * iH + iz_bnw * iD).long().view(N, 1, D * H * W).repeat(1, C, 1))
    bsw_val = torch.gather(feature_3d, 2, (iy_bsw * iW + ix_bsw * iH + iz_bsw * iD).long().view(N, 1, D * H * W).repeat(1, C, 1))
    bse_val = torch.gather(feature_3d, 2, (iy_bse * iW + ix_bse * iH + iz_bse * iD).long().view(N, 1, D * H * W).repeat(1, C, 1))

    tnw_val = torch.gather(feature_3d, 2, (iy_tnw * iW + ix_tnw * iH + iz_tnw * iD).long().view(N, 1, D * H * W).repeat(1, C, 1))
    tne_val = torch.gather(feature_3d, 2, (iy_tne * iW + ix_tne * iH + iz_tnw * iD).long().view(N, 1, D * H * W).repeat(1, C, 1))
    tsw_val = torch.gather(feature_3d, 2, (iy_tsw * iW + ix_tsw * iH + iz_tsw * iD).long().view(N, 1, D * H * W).repeat(1, C, 1))
    tse_val = torch.gather(feature_3d, 2, (iy_tse * iW + ix_tse * iH + iz_tse * iD).long().view(N, 1, D * H * W).repeat(1, C, 1))

    out_val = (
        bnw_val.view(N, C, D, H, W) * bnw.view(N, 1, D, H, W) + 
        bne_val.view(N, C, D, H, W) * bne.view(N, 1, D, H, W) + 
        bsw_val.view(N, C, D, H, W) * bsw.view(N, 1, D, H, W) +
        bse_val.view(N, C, D, H, W) * bse.view(N, 1, D, H, W) +
        tnw_val.view(N, C, D, H, W) * tnw.view(N, 1, D, H, W) +
        tne_val.view(N, C, D, H, W) * tne.view(N, 1, D, H, W) +
        tsw_val.view(N, C, D, H, W) * tsw.view(N, 1, D, H, W) +
        tse_val.view(N, C, D, H, W) * tse.view(N, 1, D, H, W)
    )

    return out_val

@DongJT1996

Here is my implementation if you are interested. @vincenthesiyuan

import time

import torch
import torch.nn.functional as F

def grid_sample_3d(image, optical):
    N, C, ID, IH, IW = image.shape
    _, D, H, W, _ = optical.shape

    ix = optical[..., 0]
    iy = optical[..., 1]
    iz = optical[..., 2]

    ix = ((ix + 1) / 2) * (IW - 1);
    iy = ((iy + 1) / 2) * (IH - 1);
    iz = ((iz + 1) / 2) * (ID - 1);
    with torch.no_grad():
        
        ix_tnw = torch.floor(ix);
        iy_tnw = torch.floor(iy);
        iz_tnw = torch.floor(iz);

        ix_tne = ix_tnw + 1;
        iy_tne = iy_tnw;
        iz_tne = iz_tnw;

        ix_tsw = ix_tnw;
        iy_tsw = iy_tnw + 1;
        iz_tsw = iz_tnw;

        ix_tse = ix_tnw + 1;
        iy_tse = iy_tnw + 1;
        iz_tse = iz_tnw;

        ix_bnw = ix_tnw;
        iy_bnw = iy_tnw;
        iz_bnw = iz_tnw + 1;

        ix_bne = ix_tnw + 1;
        iy_bne = iy_tnw;
        iz_bne = iz_tnw + 1;

        ix_bsw = ix_tnw;
        iy_bsw = iy_tnw + 1;
        iz_bsw = iz_tnw + 1;

        ix_bse = ix_tnw + 1;
        iy_bse = iy_tnw + 1;
        iz_bse = iz_tnw + 1;

    tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
    tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
    tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
    tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
    bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
    bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
    bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
    bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);


    with torch.no_grad():

        torch.clamp(ix_tnw, 0, IW - 1, out=ix_tnw)
        torch.clamp(iy_tnw, 0, IH - 1, out=iy_tnw)
        torch.clamp(iz_tnw, 0, ID - 1, out=iz_tnw)

        torch.clamp(ix_tne, 0, IW - 1, out=ix_tne)
        torch.clamp(iy_tne, 0, IH - 1, out=iy_tne)
        torch.clamp(iz_tne, 0, ID - 1, out=iz_tne)

        torch.clamp(ix_tsw, 0, IW - 1, out=ix_tsw)
        torch.clamp(iy_tsw, 0, IH - 1, out=iy_tsw)
        torch.clamp(iz_tsw, 0, ID - 1, out=iz_tsw)

        torch.clamp(ix_tse, 0, IW - 1, out=ix_tse)
        torch.clamp(iy_tse, 0, IH - 1, out=iy_tse)
        torch.clamp(iz_tse, 0, ID - 1, out=iz_tse)

        torch.clamp(ix_bnw, 0, IW - 1, out=ix_bnw)
        torch.clamp(iy_bnw, 0, IH - 1, out=iy_bnw)
        torch.clamp(iz_bnw, 0, ID - 1, out=iz_bnw)

        torch.clamp(ix_bne, 0, IW - 1, out=ix_bne)
        torch.clamp(iy_bne, 0, IH - 1, out=iy_bne)
        torch.clamp(iz_bne, 0, ID - 1, out=iz_bne)

        torch.clamp(ix_bsw, 0, IW - 1, out=ix_bsw)
        torch.clamp(iy_bsw, 0, IH - 1, out=iy_bsw)
        torch.clamp(iz_bsw, 0, ID - 1, out=iz_bsw)

        torch.clamp(ix_bse, 0, IW - 1, out=ix_bse)
        torch.clamp(iy_bse, 0, IH - 1, out=iy_bse)
        torch.clamp(iz_bse, 0, ID - 1, out=iz_bse)

    image = image.view(N, C, ID * IH * IW)

    tnw_val = torch.gather(image, 2, (iz_tnw * IW * IH + iy_tnw * IW + ix_tnw).long().view(N, 1, D * H * W).repeat(1, C, 1))
    tne_val = torch.gather(image, 2, (iz_tne * IW * IH + iy_tne * IW + ix_tne).long().view(N, 1, D * H * W).repeat(1, C, 1))
    tsw_val = torch.gather(image, 2, (iz_tsw * IW * IH + iy_tsw * IW + ix_tsw).long().view(N, 1, D * H * W).repeat(1, C, 1))
    tse_val = torch.gather(image, 2, (iz_tse * IW * IH + iy_tse * IW + ix_tse).long().view(N, 1, D * H * W).repeat(1, C, 1))
    bnw_val = torch.gather(image, 2, (iz_bnw * IW * IH + iy_bnw * IW + ix_bnw).long().view(N, 1, D * H * W).repeat(1, C, 1))
    bne_val = torch.gather(image, 2, (iz_bne * IW * IH + iy_bne * IW + ix_bne).long().view(N, 1, D * H * W).repeat(1, C, 1))
    bsw_val = torch.gather(image, 2, (iz_bsw * IW * IH + iy_bsw * IW + ix_bsw).long().view(N, 1, D * H * W).repeat(1, C, 1))
    bse_val = torch.gather(image, 2, (iz_bse * IW * IH + iy_bse * IW + ix_bse).long().view(N, 1, D * H * W).repeat(1, C, 1))

    out_val = (tnw_val.view(N, C, D, H, W) * tnw.view(N, 1, D, H, W) +
               tne_val.view(N, C, D, H, W) * tne.view(N, 1, D, H, W) +
               tsw_val.view(N, C, D, H, W) * tsw.view(N, 1, D, H, W) +
               tse_val.view(N, C, D, H, W) * tse.view(N, 1, D, H, W) +
               bnw_val.view(N, C, D, H, W) * bnw.view(N, 1, D, H, W) +
               bne_val.view(N, C, D, H, W) * bne.view(N, 1, D, H, W) +
               bsw_val.view(N, C, D, H, W) * bsw.view(N, 1, D, H, W) +
               bse_val.view(N, C, D, H, W) * bse.view(N, 1, D, H, W))

    return out_val

if __name__ == "__main__":
    image = torch.rand(1, 3, 200, 300, 100)
    grid = torch.rand(1, 100, 100, 2, 3)

    start = time.time()
    output1 = grid_sample_3d(image, grid)
    end = time.time()
    print('using {}'.format(end - start))

    start = time.time()
    output2 = F.grid_sample(image, grid, padding_mode='border', align_corners=True)
    end = time.time()
    print('using {}'.format(end - start))
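    # Illustrative addition (not part of the original comment): for in-range
    # grids the two outputs should match up to floating-point error.
    print(torch.allclose(output1, output2, atol=1e-5))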

@ngimel (Collaborator) commented Feb 17, 2022

Reopening due to user activity.

ngimel reopened this Feb 17, 2022
ngimel added the module: interpolation and triaged labels Feb 17, 2022
@hachreak commented Mar 1, 2022

+1

I have the same issue with torch 1.10.0.

@Tanzman commented Mar 10, 2022

RuntimeError: derivative for aten::grid_sampler_2d_backward is not implemented

@ken012git

Similar issue. How should I work around this?
RuntimeError: derivative for aten::hardswish_backward is not implemented

@vasudev-sharma

Same issue

@vaibhavnayel

same issue

@MontaEllis

Same issue

@NIRVANALAN

same issue

1 similar comment
@cocoakang

same issue

@cocoakang commented May 10, 2022

Hi everyone,
Here is an alternative way to fix this problem.
The official stylegan2 implementation provides its own implementation of grid_sampler_2d_backward; in most cases you can use it directly. You just need to import grid_sample_gradfix.py and use its grid_sample function.
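Roughly, the usage looks like this (a sketch; it assumes the torch_utils/ops layout from the NVlabs repos is on your Python path, and as far as I know the gradfix version supports only the default 2D bilinear mode):

# Assumes torch_utils/ops/grid_sample_gradfix.py from the
# stylegan2-ada-pytorch (or stylegan3) repo is importable.
from torch_utils.ops import grid_sample_gradfix

# Drop-in replacement for F.grid_sample(image, grid); its backward is
# decomposed into primitive ops, so the second-order pass does not hit
# the missing grid_sampler_2d_backward derivative.
out = grid_sample_gradfix.grid_sample(image, grid)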

@NIRVANALAN

Thanks @cocoakang!
Also, since the stylegan-ada version only supports PyTorch <= 1.9, we can use the updated version in stylegan3.

@cocoakang

@NIRVANALAN Thank you!

@SJoJoK commented Dec 2, 2022

same issue for 3d

@linye-boli

Same issue for 3D: "RuntimeError: derivative for aten::grid_sampler_3d_backward is not implemented"

@CeoiZidung

@linye-boli Hi, same issue for me. Has your problem been solved? Asking for your help!

@AliaksandrSiarohin

Here is a CUDA-based implementation with second-order derivatives: https://github.com/AliaksandrSiarohin/cuda-gridsample-grad2. As far as I understand, the stylegan grad_fix implementation assumes that the derivative w.r.t. the grid is always zero, which makes sense if the grid is constant; however, when the grid is predicted it will be incorrect. So this implementation should be more complete. Tell me if you find any problems with it.

@MontaEllis

> Here is a CUDA-based implementation with second-order derivatives: https://github.com/AliaksandrSiarohin/cuda-gridsample-grad2. [...]

Has anyone tested this code on DDP?

@artths commented Aug 29, 2023

If @AliaksandrSiarohin's implementation crashes for you, add this line at the beginning:
optical = torch.nan_to_num(optical, nan=-1)
to match the PyTorch behavior: "NaN values in the grid will be interpreted as -1."

@pzpzpzp2

Same issue on pytorch version '2.0.1+cu117'.
It's been 3 years since the start of this thread. With bilinear interpolation, the second derivative is just a constant matrix; this shouldn't be so hard for pytorch to fix. I'm not too sure how autograd works under the hood, but shouldn't there be a computation graph for the first derivative that it can just autograd through for the second?

@naabdi commented Dec 7, 2023

I want to use https://github.com/facebookresearch/jacobian_regularizer, but I encounter these errors:

derivative for aten::grid_sampler_2d_backward is not implemented
derivative for aten::cudnn_sampler_2d_backward is not implemented

I am on torch '2.1.1+cu121'

@Hommoner

I solved this issue by updating two files in the folder ..torch_utils\ops\ : conv2d_gradfix.py and grid_sample_gradfix.py, taken from NVlabs/stylegan3.
I am on torch '1.12.0+cu113'.

@eoozbq commented Feb 27, 2024

> same issue for 3d

Hello, I have the same question. How did you solve it? Thank you very much.

@shukdevtroy

> [quoting @AliaksandrSiarohin's grid_sample implementation from earlier in the thread]

Where do I put this code?

@J-Jul commented May 2, 2024

Any fix or update on this issue? I'm trying to get StyleGAN2 working again in Colab, and this error stops it from training.

> [quoting @AliaksandrSiarohin's grid_sample implementation and @shukdevtroy's question from earlier in the thread]
