
Low MIG and AAM metrics #52

Open
Justin-Tan opened this issue Nov 10, 2019 · 6 comments

@Justin-Tan

Hello,

Firstly, just wanted to state that this is a great repo with a very understandable code base!

I seem to be getting extremely low MIG / AAM scores (around 1e-3 to 1e-2) when training with any of the pretrained models, even using the recommended hyperparameters in the .ini file in the main directory. Is this something you noticed in your own tests?

Visual inspection of the traversals on dSprites seems to show that the network is learning quite disentangled representations (attached, with rows arranged in order of descending KL divergence from the Gaussian prior), so I am quite confused as to why the MIG score is so low.

Even when introducing supervision (matching latent factors to generative factors), the maximum MIG score I have been able to attain is around 0.01, but AAM is a lot higher, at around 0.6 for the model that produced the attached latent traversals.
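For context, the row ordering in the attached figure is just the average per-dimension KL to the unit Gaussian prior, computed roughly as below; mu and logvar here stand for the encoder outputs collected over the evaluation set (my own sketch, not code from this repo):

# Rank latent dimensions by mean KL(q(z_j|x) || N(0, 1)) over the data,
# then plot traversal rows in descending order of that KL.
kl_per_dim = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)    # shape (N, K)
row_order = kl_per_dim.mean(dim=0).argsort(descending=True)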

Cheers,
Justin

[Image: latent traversals]

@YannDubs
Owner

YannDubs commented Nov 24, 2019

The small MIG is definitely (and unfortunately) something we always had in our experiments. Importantly, I got the same results when using the author's implementation. This is one of the reasons we introduced AAM, which measures only disentanglement rather than disentanglement plus the amount of information about v contained in z. I am surprised you get a small AAM, though.
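To make the distinction concrete: both metrics start from the matrix of mutual informations I(z_j; v_k). MIG takes, for each factor, the gap between the two largest entries and normalises it by the factor entropy H(v_k) (Chen et al., 2018). A rough sketch of that step from a precomputed MI matrix (illustrative only, not the exact code in this repo):

import numpy as np

def mig_from_mi(mi, factor_entropies):
    # mi: (n_latents, n_factors) array of I(z_j; v_k)
    # factor_entropies: (n_factors,) array of H(v_k)
    sorted_mi = np.sort(mi, axis=0)[::-1]          # descending per factor
    gap = sorted_mi[0] - sorted_mi[1]              # top-1 minus top-2 MI
    return np.mean(gap / factor_entropies)         # normalise, average over factors

AAM factors out the amount of information the latents actually carry about v and keeps only the alignment part, which is why it can be high even when MIG is low.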

Here are the results we were getting:

[Screenshot: MIG / AAM results across β values]

We see that increasing β by a small amount (from 1 to 4) greatly increases axis alignment (from 20% to 65%) due to the regularisation of the total correlation, while increasing β by a large amount (from 4 to 50) decreases axis alignment due to the penalisation of the dimension-wise KL. I.e., it is not monotonic.

@Justin-Tan
Author

Justin-Tan commented Nov 24, 2019 via email

@YannDubs
Owner

YannDubs commented Nov 24, 2019

Yes it is; if you get an answer or insights, please post them here. I would be interested, and other people might be too.

And just to be clear, I have not tried rerunning the author's code. I only tried using their MIG code to compute the MIG for our results :/ I.e., it does not seem that the issue comes from the computation of MIG, but to be honest I have not spent too much time on MIG, as it was a late addition before a deadline.

@Justin-Tan
Author

Justin-Tan commented Dec 9, 2019

After some digging, I am getting better results using the author's MIG calculation code - around 0.3-0.8 for most of my trained models. Perhaps the problem lies in shuffling the dataloader? I notice that when I shuffle the dataloader I get a very low MIG (on dSprites).

# Load the full dSprites evaluation dataloader in its native order
# (shuffle=False); `device` and `model` are assumed to be defined already.
import torch
from torch.utils.data import DataLoader

all_loader = DataLoader(..., shuffle=False)

vae = model
N = len(all_loader.dataset)     # number of data samples
K = vae.latent_dim              # number of latent dimensions
nparams = 2                     # (mu, logvar) of the diagonal Gaussian posterior
vae.eval()

# Posterior parameters of q(z|x) for every sample in the dataset
qzCx_params = torch.Tensor(N, K, nparams)

n = 0
with torch.no_grad():
    for x, gen_factors in all_loader:
        batch_size = x.size(0)
        x, gen_factors = x.to(device, dtype=torch.float), gen_factors.to(device)
        # Encoder returns (mu, logvar); stack them along a trailing parameter dim
        qzCx_stats = torch.stack(vae.encoder(x)['continuous'], dim=2)
        qzCx_params[n:n + batch_size] = qzCx_stats.data
        n += batch_size

# Reshape so the leading axes index the known dSprites generative factors
# (3 shapes, 6 scales, 40 orientations, 32 x-positions, 32 y-positions);
# this only lines up if the dataset was iterated in its native order.
qzCx_params = qzCx_params.view(3, 6, 40, 32, 32, K, nparams).to(device)

# Sample from the diagonal Gaussian posterior q(z|x) using the given parameters (mu, logvar)
qzCx_samples = qzCx_sample(params=qzCx_params)

I think the reshape on the second-to-last line requires the dataset to be in its native order so that the generative factors line up correctly. It's not obvious that they should be in that order, though; this is a quirk of the dSprites dataset.
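One way to avoid depending on the native order would be to also store the ground-truth factor indices during the loop and re-sort before the reshape. Rough sketch (mine, untested here), assuming gen_factors holds the integer class indices (color, shape, scale, orientation, posX, posY) as in the dSprites latents_classes array, and reusing qzCx_params, K, nparams, device from the snippet above:

import numpy as np

factor_idx = np.zeros((N, 6), dtype=np.int64)   # filled inside the loop above:
#   factor_idx[n:n + batch_size] = gen_factors.cpu().numpy()

# Recover the native dSprites ordering from the factor indices
# (color is constant, so the factor sizes are 1, 3, 6, 40, 32, 32).
sizes = (1, 3, 6, 40, 32, 32)
flat_idx = np.ravel_multi_index(tuple(factor_idx.T), sizes)
order = np.argsort(flat_idx)

qzCx_params = qzCx_params[torch.as_tensor(order)]
qzCx_params = qzCx_params.view(3, 6, 40, 32, 32, K, nparams).to(device)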

@YannDubs
Owner

YannDubs commented Dec 10, 2019

Thanks for digging into it. What exactly do you mean by shuffling? We do not shuffle the test loader (shuffle=False), if that's your point.

BTW: I'm more than happy to accept PRs.

@Justin-Tan
Author

Justin-Tan commented Dec 11, 2019

Yeah, that's what I mean about the shuffling; thanks for confirming, so that is not the cause. I am still confused because, AFAIK, this repository and the author's code appear to be doing exactly the same thing when calculating MIG (on dSprites at least), but this repository gives a much lower MIG when loading the same models. I'll look into it more over the weekend.

I also looked into the discrete estimation of MIG used in [1], appendix C (essentially, discretize samples from z into ~20 bins and use sklearn to estimate the discrete MI between the latents z and the generative factors v). Unfortunately, it does not agree with the MIG computed using the sampling-based estimation (it is consistently lower, and reasonably insensitive to the number of bins), irrespective of whether we use the mean of the latent representation or sample from q(z|x). So the jury is still out on how best to estimate MIG, I suppose.
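For reference, the binned estimate I tried looks roughly like this (my own sketch of the appendix C procedure, not the code from [1]); the number of bins and whether z holds means or samples are the things I varied:

import numpy as np
from sklearn.metrics import mutual_info_score

def discrete_mig(z, v, n_bins=20):
    # z: (N, K) latent means or samples; v: (N, F) integer generative factors.
    N, K = z.shape
    F = v.shape[1]
    # Discretise each latent dimension into equal-width bins.
    z_binned = np.stack(
        [np.digitize(z[:, j], np.histogram_bin_edges(z[:, j], bins=n_bins)[1:-1])
         for j in range(K)], axis=1)
    # Discrete MI between every latent dimension and every factor.
    mi = np.zeros((K, F))
    for j in range(K):
        for k in range(F):
            mi[j, k] = mutual_info_score(v[:, k], z_binned[:, j])
    # H(v_k) = I(v_k; v_k) for a discrete factor.
    entropies = np.array([mutual_info_score(v[:, k], v[:, k]) for k in range(F)])
    sorted_mi = np.sort(mi, axis=0)[::-1]
    return np.mean((sorted_mi[0] - sorted_mi[1]) / entropies)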

Edit: It seems like the MIG scores reported in [1] are consistently lower anyway, around 0.2 for the best models, so perhaps this is expected.
[1]: https://arxiv.org/abs/1811.12359
