Sampling from VAE's latent space #566

Open
seakforzq opened this issue Oct 11, 2023 · 1 comment
@seakforzq

Is your feature request related to a problem? Please describe.
How can I sample from the VAE's latent space?


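        # Reuse the input's coordinate map key as the generation target
        target_key = sin.coords_key
        out_cls, targets, sout, means, log_vars, zs = net(sin, target_key)
        # Reconstruction term: per-layer occupancy BCE, averaged over decoder layers
        num_layers, BCE = len(out_cls), 0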
        losses = []
        for out_cl, target in zip(out_cls, targets):
            curr_loss = crit(out_cl.F.squeeze(), target.type(out_cl.F.dtype).to(device))
            losses.append(curr_loss.item())
            BCE += curr_loss / num_layers

        # KL divergence between N(means, exp(log_vars)) and the N(0, I) prior, averaged over the batch
        KLD = -0.5 * torch.mean(
            torch.sum(1 + log_vars.F - means.F.pow(2) - log_vars.F.exp(), 1)
        )
        loss = KLD + BCE

        print(loss)

        # Draw each reconstruction (shifted +x) beside its input cloud (shifted -x)
        batch_coords, batch_feats = sout.decomposed_coordinates_and_features
        for b, (coords, feats) in enumerate(zip(batch_coords, batch_feats)):
            pcd = PointCloud(coords)
            pcd.estimate_normals()
            pcd.translate([0.6 * config.resolution, 0, 0])
            pcd.rotate(M)
            opcd = PointCloud(data_dict["xyzs"][b])
            opcd.translate([-0.6 * config.resolution, 0, 0])
            opcd.estimate_normals()
            opcd.rotate(M)
            o3d.visualization.draw_geometries([pcd, opcd])

            n_vis += 1
            if n_vis > config.max_visualization:
                return
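
For reference, the loss printed above combines two terms: the binary cross-entropy of each decoder layer's occupancy logits, averaged over layers, and the closed-form KL divergence between the diagonal Gaussian N(means, exp(log_vars)) and the standard normal prior, summed over latent dimensions and averaged over the batch. The plain-Python sketch below mirrors that arithmetic on lists of floats. The function names are hypothetical, and it assumes crit is torch.nn.BCEWithLogitsLoss with its default mean reduction (crit is not defined in the snippet).

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy on raw logits (assumed to match BCEWithLogitsLoss defaults)."""
    total = 0.0
    for x, y in zip(logits, targets):
        # Numerically stable form: max(x, 0) - x * y + log(1 + exp(-|x|))
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def kl_to_standard_normal(means, log_vars):
    """KL(N(mu, exp(log_var)) || N(0, I)), summed over latent dims, averaged over the batch."""
    per_sample = [
        -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))
        for mu, log_var in zip(means, log_vars)
    ]
    return sum(per_sample) / len(per_sample)


def vae_loss(layer_logits, layer_targets, means, log_vars):
    """Average the per-layer BCE terms, then add the KL term, as in the snippet."""
    num_layers = len(layer_logits)
    bce = sum(
        bce_with_logits(logits, targets)
        for logits, targets in zip(layer_logits, layer_targets)
    ) / num_layers
    return kl_to_standard_normal(means, log_vars) + bce


# A unit shift of the mean with unit variance costs 0.5 nats of KL.
print(kl_to_standard_normal([[1.0]], [[0.0]]))  # 0.5
```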

Describe the solution you'd like
The code above only performs reconstruction, and it requires target_key taken from the input sparse tensor sin. Can you help me sample from the latent space instead?
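
A common way to generate new shapes from a trained VAE is to skip the encoder: draw z from the N(0, I) prior that the KL term pulls the posterior toward, then run only the decoder and prune by its predicted occupancy. The plain-Python sketch below shows that prior draw next to the reparameterized posterior sample z = mean + exp(0.5 * log_var) * eps. The function names and sizes are hypothetical. Using it with the MinkowskiEngine example would additionally require (a) packing each z as the feature of a sparse tensor with one coordinate per batch item, assuming the encoder pools every shape to a single latent point, and (b) a decoder variant that skips target generation when no target_key is available. Neither step is shown here, and both are assumptions about how the example would need to change.

```python
import math
import random


def sample_prior(batch_size, latent_dim, rng):
    """Draw z ~ N(0, I): one latent vector per shape to generate (no encoder, no target)."""
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(batch_size)]


def reparameterize(mean, log_var, rng):
    """Posterior sample z = mean + exp(0.5 * log_var) * eps with eps ~ N(0, 1), elementwise."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mean, log_var)]


rng = random.Random(0)  # fixed seed so the draws are reproducible
zs = sample_prior(batch_size=4, latent_dim=8, rng=rng)  # hypothetical sizes
print(len(zs), len(zs[0]))  # 4 8
```

The only difference between generation and reconstruction at this point is where z comes from: the prior for new shapes, the encoder's means and log_vars for reconstructing a given input.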

Describe alternatives you've considered
None.

Additional context
None.

@seakforzq (author)

Also, why do we need the ground-truth target (gt_target) to prune the upsampled sparse tensor?


        # If training, force target shape generation, use net.eval() to disable
        if self.training:
            keep1 += target

Doesn't this mean the network never learns how to prune the output by itself?
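
One way to read that line is on plain boolean masks. Assume keep1 starts from the decoder's own prediction (logits above zero; that part is not quoted here). During training, adding the boolean target acts as a logical OR, so every ground-truth voxel survives to be upsampled at the next level even when the classifier misses it. The classifier is still supervised, because the BCE loop in the first snippet compares every out_cl against its target; only the pruning decision is overridden. Under net.eval() the mask is the prediction alone. The sketch below is hypothetical (the function name and the zero threshold are assumptions) and uses Python lists in place of tensors.

```python
def keep_mask(logits, target, training, threshold=0.0):
    """Which upsampled voxels survive pruning at one decoder level.

    logits: predicted occupancy logits for each candidate voxel.
    target: True where the voxel exists in the ground-truth shape.
    The 0.0 threshold (probability 0.5) is an assumption for illustration.
    """
    keep = [x > threshold for x in logits]
    if training:
        # Mirrors `keep1 += target`: OR in the ground truth so true voxels always survive
        keep = [k or t for k, t in zip(keep, target)]
    return keep


logits = [2.0, -1.5, 0.3, -0.2]
target = [True, True, False, False]

print(keep_mask(logits, target, training=True))   # [True, True, True, False]
print(keep_mask(logits, target, training=False))  # [True, False, True, False]
```

In the training case the second voxel is kept although its logit of -1.5 rejected it, yet that logit still incurs BCE loss against a target of True, so the classifier is pushed toward predicting it. Forcing the ground truth only keeps the next level's targets reachable while the classifier is still inaccurate.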
