Low MIG values bug found & solution #64

Open
DianeBouchacourt opened this issue Feb 26, 2021 · 3 comments


@DianeBouchacourt

I trained a beta TCVAE using the code from https://github.com/rtqichen/beta-tcvae, which gives a MIG of ~0.50 for this model. When computing MIG for the same (MLP-based) model with your code, I got values close to 0.0008.
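
For reference, this is the mutual information gap (MIG) as I understand it from the beta TCVAE paper by Chen et al. The symbols are the usual ones, not names from either codebase: K is the number of ground-truth factors v_k that are included, z_j are the latent dimensions, H is entropy, I is mutual information, and j^(k) is the latent dimension with the highest mutual information with v_k.

```latex
\mathrm{MIG} = \frac{1}{K} \sum_{k=1}^{K} \frac{1}{H(v_k)}
  \left( I\bigl(z_{j^{(k)}}; v_k\bigr) - \max_{j \neq j^{(k)}} I\bigl(z_j; v_k\bigr) \right),
\qquad j^{(k)} = \arg\max_{j} I(z_j; v_k)
```

Since each term is the gap between the best and second-best latent dimension, a value near 0 means no single dimension stands out for any factor, which is also what a scrambled sample tensor would produce.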

The differences from Chen's code that I found important:

  • MIG values are not computed on shape in Chen's code (not considered a factor of variation). I had to modify the dsprites dataset to remove shape from dSprites lat_names, and write a custom _estimate_H_zCv function. I can share if you want.

  • Chen uses samples, not the mean as you do here, since self.training is False

  • The most important change is that I replaced these lines

    samples_zCx = samples_zCx.expand(len_dataset, latent_dim, n_samples)

    with

      samples_zCx = samples_zCx.permute(1,0)
      samples_zCx = samples_zCx.index_select(1, samples_x).view(latent_dim, n_samples)
      samples_zCx = samples_zCx.view(1, latent_dim, n_samples).expand(len_dataset, latent_dim, n_samples)
      mean = params_zCX[0].view(len_dataset, latent_dim, 1).expand(len_dataset, latent_dim, n_samples)
      log_var = params_zCX[1].view(len_dataset, latent_dim, 1).expand(len_dataset, latent_dim, n_samples)
    

These are closer to Chen's code, and I now get values of ~0.50 too. I don't know exactly why the original lines were not expanding the correct way.
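
To make the "samples vs. mean" difference above concrete, here is a minimal sketch in plain Python (no PyTorch needed). The function name `encode_latents` and the `use_samples` flag are hypothetical and do not come from either codebase; the sketch only assumes a diagonal Gaussian posterior parameterised by a mean and a log-variance.

```python
import math
import random


def encode_latents(mean, log_var, use_samples, rng):
    """Return one latent vector for a single data point.

    mean, log_var: per-dimension parameters of a diagonal Gaussian.
    use_samples=False returns the mean unchanged (what happens when
    sampling is switched off at evaluation time).
    use_samples=True draws z = mean + std * eps with eps ~ N(0, 1).
    """
    if not use_samples:
        return list(mean)
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mean, log_var)]


rng = random.Random(0)          # fixed seed so the run is repeatable
mean = [0.5, -1.0, 2.0]         # toy values, latent_dim = 3
log_var = [0.0, -2.0, -4.0]

z_mean = encode_latents(mean, log_var, use_samples=False, rng=rng)
z_sample = encode_latents(mean, log_var, use_samples=True, rng=rng)

print(z_mean)    # identical to `mean`
print(z_sample)  # `mean` plus Gaussian noise scaled by exp(0.5 * log_var)
```

Which of the two goes into the MIG estimate changes the empirical distribution of z, so the two codebases are not expected to agree exactly even once the indexing is fixed.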

@DianeBouchacourt
Author

After a bit more digging, I think the problem is that the indices are selected, but the view() operation

samples_zCx = samples_zCx.expand(len_dataset, latent_dim, n_samples)
reorders the values incorrectly; permute is the correct way to do it. I checked this by taking the first 10000 samples, which are also the first 10000 entries of mean. With the view() operation, the result does not match the values in mean (which it should, since the samples are the mean values). After replacing it with permute, it works as expected.

Thus I think it should be:

    samples_zCx = samples_zCx.index_select(0, samples_x).permute(1,0)
    samples_zCx = samples_zCx.unsqueeze(0).expand(len_dataset, latent_dim, n_samples)

in place of

samples_zCx = samples_zCx.index_select(0, samples_x).view(latent_dim, n_samples)
and
samples_zCx = samples_zCx.expand(len_dataset, latent_dim, n_samples)
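
The difference between the two versions can be reproduced on a toy example. The sketch below uses plain Python lists so it runs without PyTorch; `row_major_view` and `transpose` are hypothetical helpers that mimic what `view(latent_dim, n_samples)` and `permute(1, 0)` do to a contiguous 2-D tensor, and the numbers are made up so that each entry shows which sample and latent dimension it came from.

```python
def row_major_view(rows, new_n_rows, new_n_cols):
    """Mimic tensor.view(new_n_rows, new_n_cols) on a contiguous 2-D tensor:
    the flat row-major order is kept and only regrouped into new rows."""
    flat = [x for row in rows for x in row]
    assert len(flat) == new_n_rows * new_n_cols
    return [flat[r * new_n_cols:(r + 1) * new_n_cols]
            for r in range(new_n_rows)]


def transpose(rows):
    """Mimic tensor.permute(1, 0): swap the two axes."""
    return [list(col) for col in zip(*rows)]


# Toy tensor of shape (n_samples, latent_dim) = (3, 2).
# Entry (i, j) is 10 * i + j: sample i, latent dimension j.
n_samples, latent_dim = 3, 2
samples_zCx = [[10 * i + j for j in range(latent_dim)]
               for i in range(n_samples)]
# samples_zCx == [[0, 1], [10, 11], [20, 21]]

viewed = row_major_view(samples_zCx, latent_dim, n_samples)
permuted = transpose(samples_zCx)

print(viewed)    # [[0, 1, 10], [11, 20, 21]]  rows mix both latent dims
print(permuted)  # [[0, 10, 20], [1, 11, 21]]  row j holds latent dim j only
```

Both results have shape (latent_dim, n_samples), which is why the view() version runs without error: only the permuted version keeps each row tied to a single latent dimension.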

@YannDubs
Owner

Awesome, thanks Diane! A few others had opened issues about this, but there were no solutions (I don't actively work on this anymore). Do you think you could send a PR?

Ideally, it would use samples instead of the mean, but I think it's more urgent to have the correct permutations!

@DianeBouchacourt
Author

Let me check my code again, and I will! I also need to make a PR with just these changes (I edited quite a lot in the fork I created).
