
Confused with the backward of RasterizeAutograd #9

Open
Wu-Xiuchao opened this issue Mar 8, 2021 · 2 comments

import torch

# (_compute_visibility_maps, _gather_maps, _guided_scatter_maps and
# _visibility_backward are presumably the repository's custom CUDA ops)
class RasterizeAutograd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, rho, rhoValues, Ws, projPoints, boundingBoxes, inplane, Ms, cameraPoints,
                width, height, localWidth, localHeight, camFar, focalLength, mergeThreshold, considerZ, topK):
        """
        input:
            rho            BxNxhxw   rho evaluated in a BB of (h, w) around projected points
            rhoValues      BxNx1     normalizing term of rho
            Ws             BxNx3     point color
            projPoints     BxNx2or3  projected point location
            boundingBoxes  BxNx4     bounding boxes (w0, h0, width, height)
            inplane        BxNxhxwx3 window back projected in space
            Ms             BxNx2x2   inverse of gaussian variance
            cameraPoints   BxNx3     point location in camera space
            width          scalar
            height         scalar
            localWidth     scalar    used to limit gradient computation to local window
            localHeight    scalar
            camFar         scalar    used as the default (background) depth
            focalLength    scalar
            mergeThreshold scalar    depth merging threshold T
            considerZ      bool      consider Z gradient
            topK           int       front most points to consider during forward/backward rendering
        returns:
            pixels         BxHxWx3
            pointIdxMap    BxHxWx5
            rhoMap         BxHxWx5
            WsMap          BxHxWx5x3
            isBehind       BxHxWx5
        """
        batchSize, numPoints, bbHeight, bbWidth = rho.shape
        # compute visibility, return the per-pixel top-5 contributors sorted by
        # depth value
        # pointIdxMap BxHxWx5, index inside the splat's bounding box window which is to be rendered at pixel (h,w)
        # depthMap    BxHxWx5, depths of splats that are rendered at pixel (h,w)
        # rhoMap      BxHxWx5
        pointIdxMap = torch.full((batchSize, height, width, topK), -1, dtype=torch.int64, device=rho.device)
        depthMap = torch.full((batchSize, height, width, topK), camFar, dtype=rho.dtype, device=rho.device)
        bbPositionMap = torch.full((batchSize, height, width, topK, 2), -1, dtype=torch.int64, device=rho.device)
        with torch.cuda.device(rho.device):
            # call visibility kernel; outputs depthMap and pointIdxMap, which store the depth and index of
            # the 5 closest points for each pixel; if fewer than 5 points cover the pixel, the remaining idxMap entries are left at -1
            _compute_visibility_maps(boundingBoxes[:, :, :2].contiguous(), inplane, pointIdxMap, bbPositionMap, depthMap)
            # gather rho, wk
            WsMap = _gather_maps(Ws, pointIdxMap, 0.0)
            # per-batch indices for rhos: Bx(Nxhxw)
            # gather rho, Ws values from pointIdxMap and bbPositionMap, if idx < 0 (unset), then set rho=0 Ws=0
            totalIdxMap = pointIdxMap*bbHeight*bbWidth+bbPositionMap[:, :, :, :, 0]*bbWidth+bbPositionMap[:, :, :, :, 1]
            validMaps = totalIdxMap >= 0
            totalIdxMap = torch.where(validMaps, totalIdxMap, torch.full_like(totalIdxMap, -1))
            rhoMap = _gather_maps(rho.reshape(batchSize, -1, 1), totalIdxMap, 0.0).squeeze(-1)
            # check depth jump
            isBehind = torch.zeros(depthMap.shape, dtype=torch.uint8, device=depthMap.device)
            isBehind[:, :, :, 1:] = (depthMap[:, :, :, 1:]-depthMap[:, :, :, :1]) > mergeThreshold
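            # (note) entries more than mergeThreshold behind the front-most
            # depth are flagged as occluded; their rho is zeroed below so they
            # do not contribute to the pixel color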
            rhoMap_filtered = torch.where(isBehind, torch.zeros(1, 1, 1, 1, device=rhoMap.device, dtype=rhoMap.dtype), rhoMap)
            # WsMap[:, :, :, 1:, :] = torch.where(isBehind.unsqueeze(-1), torch.zeros(1, 1, 1, 1, 1, device=WsMap.device, dtype=WsMap.dtype), WsMap[:, :, :, 1:])
            # normalize rho
            sumRho = torch.sum(rhoMap_filtered, dim=-1, keepdim=True)
            sumRho = torch.where(sumRho == 0, torch.ones_like(sumRho), sumRho)
            rhoMap_normalized = rhoMap_filtered/sumRho
            # rho * w
            pixels = torch.sum(WsMap * rhoMap_normalized.unsqueeze(-1), dim=3)
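            # (note) each pixel is a convex combination of the top-K point
            # colors: pixel = sum_k Ws_k * rho_k / sum_j rho_j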
            # accumulated = WsMap[:, :, :, 0, :]
            ctx.save_for_backward(pointIdxMap, bbPositionMap, isBehind, WsMap, rhoMap, depthMap, Ws, rhoValues, projPoints, cameraPoints, boundingBoxes, pixels, Ms)
            ctx.numPoint = numPoints
            ctx.bbWidth = bbWidth
            ctx.bbHeight = bbHeight
            ctx.localHeight = localHeight
            ctx.localWidth = localWidth
            ctx.mergeThreshold = mergeThreshold
            ctx.focalLength = focalLength
            ctx.considerZ = considerZ
            ctx.rho_requires_grad = rho.requires_grad
            ctx.w_requires_grad = Ws.requires_grad
            ctx.xyz_requires_grad = projPoints.requires_grad
            ctx.mark_non_differentiable(pointIdxMap, rhoMap, WsMap, isBehind)
            return pixels, pointIdxMap, rhoMap_normalized, WsMap, isBehind

    @staticmethod
    def backward(ctx, gradPixels, dpointIdxMap, gradRhoMap, gradWsMap, gradIsBehind):
        """
        input
            gradPixels          (BxHxWx3)
        output
            dRho            (BxNxbbHxbbW)
            dW              (BxNx3)
            dP              (BxNx2) derivative wrt projected points
            dcamP           (BxNx3) derivative wrt camera points (only z-dim is nonzero)
        """
        pointIdxMap, bbPositionMap, isBehind, WsMap, rhoMap, depthMap, Ws, rhoValues, projPoints, cameraPoints, boundingBoxes, pixels, Ms = ctx.saved_tensors
        mergeThreshold = ctx.mergeThreshold
        focalLength = ctx.focalLength
        numPoint = ctx.numPoint
        considerZ = ctx.considerZ
        bbWidth = ctx.bbWidth
        bbHeight = ctx.bbHeight
        batchSize, height, width, topK, C = WsMap.shape
        if ctx.needs_input_grad[0]:  # rho will not be backpropagated
            WsMap_ = torch.where(isBehind.unsqueeze(-1), torch.zeros(1, 1, 1, 1, 1, device=WsMap.device, dtype=WsMap.dtype), WsMap)
            totalIdxMap = pointIdxMap*bbHeight*bbWidth+bbPositionMap[:, :, :, :, 0]*bbWidth+bbPositionMap[:, :, :, :, 1]
            # TODO check dNormalizeddRho
            rhoMap_filtered = torch.where(isBehind, torch.zeros(1, 1, 1, 1, device=rhoMap.device, dtype=rhoMap.dtype), rhoMap)
            sumRho = torch.sum(rhoMap_filtered, dim=-1, keepdim=True)
            dNormalizeddRho = torch.where(rhoMap > 0, 1/sumRho-rhoMap/sumRho, rhoMap)
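            # (note) for reference: d(rho_k/S)/d(rho_k) with S = sum_j rho_j is
            # 1/S - rho_k/S**2, while the line above computes 1/S - rho_k/S;
            # presumably this is what the TODO refers to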
            dRho = _guided_scatter_maps(numPoint*bbWidth*bbHeight, dNormalizeddRho.unsqueeze(-1)*gradPixels.unsqueeze(3)*WsMap_, totalIdxMap, boundingBoxes)
            dRho = torch.sum(dRho, dim=-1)
            dRho = torch.reshape(dRho, (batchSize, numPoint, bbHeight, bbWidth))
        else:
            dRho = None

        if ctx.needs_input_grad[2]:
            # dPixels/dWs = Rho
            rhoMap_filtered = torch.where(isBehind, torch.zeros(1, 1, 1, 1, device=rhoMap.device, dtype=rhoMap.dtype), rhoMap)
            sumRho = torch.sum(rhoMap_filtered, dim=-1, keepdim=True)
            sumRho = torch.where(sumRho == 0, torch.ones_like(sumRho), sumRho)  # guard against division by zero, mirroring the forward pass
            rhoMap_normalized = rhoMap_filtered/sumRho
            # BxHxWx3 -> BxHxWxKx3 -> BxNx3
            dWs = _guided_scatter_maps(numPoint, gradPixels.unsqueeze(3)*rhoMap_normalized.unsqueeze(-1), pointIdxMap, boundingBoxes)
        else:
            dWs = None

        if ctx.needs_input_grad[3]:
            localWidth = ctx.localWidth
            localHeight = ctx.localHeight
            depthValues = cameraPoints[:, :, 2].contiguous()
            # B,N,1
            dIdp = torch.zeros_like(projPoints, device=gradPixels.device, dtype=gradPixels.dtype)
            dIdz = torch.zeros(1, numPoint, device=gradPixels.device, dtype=gradPixels.dtype)
            outputs = _visibility_backward(focalLength, mergeThreshold, considerZ,
                                                             localHeight, localWidth,
                                                             gradPixels, pointIdxMap, rhoMap, WsMap, depthMap, isBehind,
                                                             pixels, boundingBoxes, projPoints, Ws, depthValues, rhoValues, dIdp, dIdz)
            dIdp, dIdz = outputs
            # outputs = _visibility_debug_backward(mergeThreshold, focalLength, considerZ,
            #                                                        localHeight, localWidth, 0,
            #                                                        gradPixels, pointIdxMap, rhoMap, WsMap, depthMap, isBehind,
            #                                                        pixels, boundingBoxes, projPoints, Ws, depthValues, rhoValues, dIdp, dIdz)
            # dIdp, dIdz, debugTensor = outputs

            dIdcam = torch.zeros_like(cameraPoints)
            dIdcam[:, :, 2] = dIdz
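            # (note) only the z-component receives a gradient here; the x/y
            # derivatives flow through dIdp w.r.t. the projected points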
            # saved_variables["dI"] = gradPixels.detach().cpu()
            # saved_variables["dIdp"] = saved_variables["dIdp"].scatter_(1, saved_variables["renderable_idx"].expand(-1, -1, dIdp.shape[-1]),
            #                                                            dIdp.cpu().detach())
            # saved_variables["projPoints"] = saved_variables["projPoints"].scatter_(1, saved_variables["renderable_idx"].expand(-1,-1,dIdp.shape[-1]),
            #                                                                        projPoints.cpu().detach())
            # saved_variables["dIdpMap"] = debugTensor[:,:,:,:2].cpu().detach()
        else:
            dIdp = dIdcam = None

        return (None, None, dWs, dIdp, None, None, dIdcam, None, None, None, None, None, None, None, None, None, None)

Hi, I am confused by the seventh return value of the backward function, which is named dIdcam and documented as "(BxNx3) derivative wrt camera points (only z-dim is nonzero)". Positionally, though, the seventh input of the forward function is Ms, not cameraPoints. Is something wrong here?
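
For context, torch.autograd.Function pairs backward's return values with forward's inputs purely by position. Laying the two out side by side (a small illustrative script based on my reading of the snippet above, not code from the repository) shows where dIdcam lands:

# 0-based positional pairing of forward inputs and backward returns
forward_inputs = ["rho", "rhoValues", "Ws", "projPoints", "boundingBoxes",
                  "inplane", "Ms", "cameraPoints", "width", "height",
                  "localWidth", "localHeight", "camFar", "focalLength",
                  "mergeThreshold", "considerZ", "topK"]
backward_returns = [None, None, "dWs", "dIdp", None, None, "dIdcam", None,
                    None, None, None, None, None, None, None, None, None]
for name, grad in zip(forward_inputs, backward_returns):
    if grad is not None:
        print(f"{grad} -> {name}")
# prints: dWs -> Ws, dIdp -> projPoints, dIdcam -> Ms  (not cameraPoints)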

@Wu-Xiuchao (Author)

Hi, I have checked the gradient values, and I think dIdcam is backpropagated in the wrong place, which means the local points cannot be optimized in depth.
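
To see why this matters, here is a minimal hypothetical reproduction (not the repository's code): when backward returns a gradient in the wrong slot, autograd silently accumulates it into the wrong input, and the intended tensor never receives a gradient.

import torch

class Misaligned(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, a, b):
        return x + b  # `a` does not influence the output at all

    @staticmethod
    def backward(ctx, g):
        # correct would be (g, None, g); here b's gradient sits in a's slot
        return g, g, None

x = torch.zeros(2, requires_grad=True)
a = torch.zeros(2, requires_grad=True)
b = torch.zeros(2, requires_grad=True)
Misaligned.apply(x, a, b).sum().backward()
print(a.grad)  # tensor([1., 1.]) -- wrongly receives the gradient
print(b.grad)  # None -- b, like cameraPoints here, can never be optimized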

@yifita (Owner) commented Mar 15, 2021

Hi @Wu-Xiuchao, I just pushed a major update, could you give it a try? The denoising feature isn't available yet, but the inverse rendering should be fully functional.
