Visualize trained plenoxels #19

Open
FrederikWarburg opened this issue Jan 25, 2023 · 1 comment
FrederikWarburg commented Jan 25, 2023

Hi,

I would like to visualize one of your trained plenoxels. Ideally, I want to just load a checkpoint and render views from a spherical path around the central object, without having to download CO3D. However, I find this challenging with the current code.

I was able to load your model using your on_load_checkpoint, which dequantizes the checkpoint and loads the model. From there I want to render views.

I define an intrinsic matrix:

import numpy as np
import torch
import matplotlib.pyplot as plt
# (spherical_poses, batchified_get_rays, convert_to_ndc, Rays, SparseGrid are
#  imported from your project code.)

near, far = 0., 1.
ndc_coeffs = (-1., -1.)
image_sizes = (200, 200)
focal = (100., 100.)
intrinsics = np.array(
    [
        [focal[0], 0.0, image_sizes[0] / 2, 0.0],   # fx, cx
        [0.0, focal[1], image_sizes[1] / 2, 0.0],   # fy, cy
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
)
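
For reference, a focal length of 100 px at a 200 px image width corresponds to a 90° field of view; a tiny helper of my own (not from your repo) makes that relationship explicit:

def focal_from_fov(width_px, fov_deg):
    # Pinhole model: half the image width subtends half the field of view,
    # i.e. focal = 0.5 * width / tan(0.5 * fov).
    return 0.5 * width_px / np.tan(0.5 * np.deg2rad(fov_deg))

# focal_from_fov(200, 90.0) == 100.0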

and use your function spherical_poses to get the extrinsics

cam_trans = np.diag(np.array([-1, -1, 1, 1], dtype=np.float32))
render_poses = spherical_poses(cam_trans)

I then try to create the rays for the first pose using several of your functions:

extrinsics_idx = render_poses[:1]     # render only the first pose for now
N_render = len(extrinsics_idx)        # batch sizes must match the extrinsics
intrinsics_idx = np.stack(
    [intrinsics for _ in range(N_render)]
)
image_sizes_idx = np.stack(
    [image_sizes for _ in range(N_render)]
)

rays_o, rays_d = batchified_get_rays(
    intrinsics_idx,
    extrinsics_idx,
    image_sizes_idx,
    True,
)
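
In case the problem is a convention mismatch, this is the pinhole ray generation I have in mind for a single pose (just a sketch of my assumptions, not your batchified_get_rays; I assume OpenCV-style intrinsics with +z forward and camera-to-world extrinsics):

def get_rays_single(intrinsic, extrinsic, image_size):
    H, W = image_size
    fx, fy = intrinsic[0, 0], intrinsic[1, 1]
    cx, cy = intrinsic[0, 2], intrinsic[1, 2]
    j, i = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")  # pixel rows / cols
    dirs = np.stack([(i - cx) / fx, (j - cy) / fy, np.ones_like(i)], -1)
    rays_d = dirs @ extrinsic[:3, :3].T                 # rotate into world space
    rays_o = np.broadcast_to(extrinsic[:3, 3], rays_d.shape)
    return rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)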

rays_o = torch.tensor(rays_o, dtype=torch.float32)
rays_d = torch.tensor(rays_d, dtype=torch.float32)
rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)

rays = torch.stack(
    convert_to_ndc(rays_o, rays_d, ndc_coeffs), dim=1
)

rays_o = rays[:,0,:].contiguous()
rays_d = rays[:,1,:].contiguous()

rays_o = rays_o.to("cuda")
rays_d = rays_d.to("cuda")

and then try to render

rays = Rays(rays_o, rays_d)
grid = grid.to(device="cuda")
depth = grid.volume_render_depth(rays, 1e-5)
target = torch.zeros_like(rays_o)

rgb, mask = grid.volume_render_fused(rays, target)

but when I visualize the rendering it looks like I did something wrong:

depth = depth.reshape(200, 200)
rgb = rgb.reshape(200, 200, 3)

plt.imshow(depth.cpu().numpy())
plt.show()

plt.imshow(rgb.cpu().numpy())
plt.show()
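
Ultimately what I am after is a full fly-around video over all of render_poses, roughly like the sketch below (imageio is just my choice here, and render_frame is a hypothetical wrapper around the ray-generation and volume_render_fused steps above):

import imageio.v2 as imageio

frames = []
for pose in render_poses:
    # render_frame is a placeholder for the per-pose ray setup + rendering shown above.
    rgb = render_frame(grid, intrinsics, pose, image_sizes)        # (H, W, 3) in [0, 1]
    frames.append((np.clip(rgb, 0.0, 1.0) * 255).astype(np.uint8))

imageio.mimwrite("spherical_path.mp4", frames, fps=30)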

Could you please help me? It would be very useful to check that the model loading is correct and to see how good the reconstructions are.

FrederikWarburg commented Jan 25, 2023

For completeness, I'll also share my code for loading the model. I moved some of the functions into a Jupyter notebook to keep the evaluation code disentangled from the training code.

def dequantize_data(data, data_min, data_scale, quant_bit=8, logarithmic_quant=False):

    if quant_bit == 8 or quant_bit == 16:
        # 8/16-bit values are stored directly; just undo the affine quantization.
        data_tensor = data.type(torch.FloatTensor) * data_scale + data_min
    elif quant_bit == 4:
        # Two 4-bit values are packed per byte: high nibble first, low nibble second.
        data_blank = torch.zeros(len(data) * 2, *data.shape[1:], device=data.device)
        data_blank[0::2] = data // 16
        data_blank[1::2] = data % 16
        if torch.all(data_blank[-1] == 0):
            data_blank = data_blank[:-1]
        data_tensor = data_blank.type(torch.FloatTensor) * data_scale + data_min
    elif quant_bit == 2:
        # Four 2-bit values are packed per byte, most significant pair first.
        data_blank = torch.zeros(len(data) * 4, *data.shape[1:], device=data.device)
        data_blank[0::4] = data // 64
        data_blank[1::4] = data % 64 // 16
        data_blank[2::4] = data % 16 // 4
        data_blank[3::4] = data % 4
        for _ in range(4):
            if torch.all(data_blank[-1] == 0):
                data_blank = data_blank[:-1]
        data_tensor = data_blank.type(torch.FloatTensor) * data_scale + data_min

    if logarithmic_quant:
        data_tensor = torch.exp(-data_tensor)

    return data_tensor
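
As a quick sanity check of the 8-bit path, I round-trip a manual quantization (the min/scale convention here is my assumption about how the checkpoints were written):

x = torch.randn(1000, 1)
x_min = x.min()
scale = (x.max() - x_min) / 255.0
q = ((x - x_min) / scale).round().clamp(0, 255).to(torch.uint8)

x_rec = dequantize_data(q, x_min, scale, quant_bit=8)
print((x - x_rec).abs().max())   # should be on the order of scale / 2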

def load_checkpoint(grid, checkpoint, quantize=True, quantize_density=False) -> None:

    state_dict = checkpoint["state_dict"]

    grid.reso_idx = checkpoint["reso_idx"]

    del grid.basis_data
    del grid.density_data
    del grid.sh_data
    del grid.links

    grid.register_parameter(
        "basis_data", nn.Parameter(state_dict["model.basis_data"])
    )

    if "model.background_data_min" in checkpoint.keys():
        del grid.background_data
        bgd_data = state_dict["model.background_data"]
        if quantize:
            bgd_min = checkpoint["model.background_data_min"]
            bgd_scale = checkpoint["model.background_data_scale"]
            bgd_data = dequantize_data(bgd_data, bgd_min, bgd_scale)

        grid.register_parameter("background_data", nn.Parameter(bgd_data))
        checkpoint["state_dict"]["model.background_data"] = bgd_data

    density_data = state_dict["model.density_data"]
    if quantize_density:
        density_min = checkpoint["model.density_data_min"]
        density_scale = checkpoint["model.density_data_scale"]
        density_data = dequantize_data(density_data, density_min, density_scale)

    grid.register_parameter("density_data", nn.Parameter(density_data))
    checkpoint["state_dict"]["model.density_data"] = density_data

    sh_data = state_dict["model.sh_data"]
    if quantize:
        sh_data_min = checkpoint["model.sh_data_min"]
        sh_data_scale = checkpoint["model.sh_data_scale"]
        sh_data = dequantize_data(sh_data, sh_data_min, sh_data_scale)

    grid.register_parameter("sh_data", nn.Parameter(sh_data))
    checkpoint["state_dict"]["model.sh_data"] = sh_data

    reso_list = [[128, 128, 128], [256, 256, 256]]
    reso = reso_list[checkpoint["reso_idx"]]

    # Rebuild the dense `links` grid from the sparse voxel indices: inactive voxels
    # are marked with -1, active voxels get consecutive indices into the density/SH data.
    links = torch.zeros(reso, dtype=torch.int32) - 1
    links_sparse = state_dict["model.links_idx"]
    links_idx = torch.stack(
        [
            links_sparse // (reso[1] * reso[2]),
            links_sparse % (reso[1] * reso[2]) // reso[2],
            links_sparse % reso[2],
        ]
    ).long()
    links[links_idx[0], links_idx[1], links_idx[2]] = torch.arange(
        len(links_idx[0]), dtype=torch.int32
    )
    checkpoint["state_dict"].pop("model.links_idx")
    checkpoint["state_dict"]["model.links"] = links
    grid.register_buffer("links", links)

    state_dict = checkpoint["state_dict"]
    state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}
    grid.load_state_dict(state_dict)

    return grid

Having defined these, I run:

grid = SparseGrid(background_nlayers=28, background_reso=512)
ckpt = torch.load('../../data/co3d/PeRFception-v1-1/00/plenoxel_co3d_30_1091_3400/last.ckpt', map_location='cpu')
grid = load_checkpoint(grid, ckpt)
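
and then do a couple of quick consistency checks (my own additions, just to convince myself the load worked):

# The number of active voxels in `links` should match the number of density / SH rows.
n_active = int((grid.links >= 0).sum())
print(n_active, grid.density_data.shape, grid.sh_data.shape)
assert n_active == grid.density_data.shape[0]
assert n_active == grid.sh_data.shape[0]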
