About GumbelQuantization training #67

Open
kobiso opened this issue Jun 30, 2021 · 26 comments

@kobiso

kobiso commented Jun 30, 2021

Thank you for the great work!

I tried to reproduce the VQGAN OpenImages (f=8), 8192, GumbelQuantization model based on the config file from the cloud (the detailed config is below).

VQGAN OpenImages (f=8), 8192, GumbelQuantization
model:
  base_learning_rate: 4.5e-06
  target: taming.models.vqgan.GumbelVQ
  params:
    kl_weight: 1.0e-08
    embed_dim: 256
    n_embed: 8192
    monitor: val/rec_loss
    temperature_scheduler_config:
      target: taming.lr_scheduler.LambdaWarmUpCosineScheduler
      params:
        warm_up_steps: 0
        max_decay_steps: 1000001
        lr_start: 0.9
        lr_max: 0.9
        lr_min: 1.0e-06
    ddconfig:
      double_z: false
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 1
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions:
      - 32
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.vqperceptual.DummyLoss

However, I encountered some errors when training with GumbelQuantization.
The first error was an unexpected keyword argument error, as shown below.

  File "/opt/conda/envs/taming/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 57, in forward
    output = self.module.validation_step(*inputs, **kwargs)
  File "/home/shared/workspace/dalle/taming-transformers/taming/models/vqgan.py", line 336, in validation_step
    xrec, qloss = self(x, return_pred_indices=True)
  File "/opt/conda/envs/taming/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'return_pred_indices'

I could fix this error by removing return_pred_indices=True from the line below.

xrec, qloss = self(x, return_pred_indices=True)
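So the call in validation_step becomes simply (a minimal sketch of the workaround; everything else in the method stays the same):

xrec, qloss = self(x)  # GumbelVQ.forward() does not accept return_pred_indices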

The second error occurs because of DummyLoss, as shown below.

  File "/opt/conda/envs/taming/lib/python3.8/site-packages/pytorch_lightning/trainer/optimizers.py", line 34, in init_optimizers
    optim_conf = model.configure_optimizers()
  File "/home/shared/workspace/dalle/taming-transformers/taming/models/vqgan.py", line 129, in configure_optimizers
    opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
  File "/opt/conda/envs/taming/lib/python3.8/site-packages/torch/nn/modules/module.py", line 947, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'DummyLoss' object has no attribute 'discriminator'

This can be fixed by changing target: taming.modules.losses.vqperceptual.DummyLoss to target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator.
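For now, I am trying a lossconfig along these lines; the params are just copied from one of the repo's other VQGAN configs (disc_start in particular is a guess), so I have no idea whether they match what was actually used for the OpenImages Gumbel model:

    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: false
        disc_in_channels: 3
        disc_start: 250001
        disc_weight: 0.8
        codebook_weight: 1.0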

But I'm not sure whether the VQGAN OpenImages (f=8), 8192, GumbelQuantization model was actually trained with the discriminator loss, and if so, when it was enabled and with what parameters.
Could you share the detailed config file of the VQGAN OpenImages (f=8), 8192, GumbelQuantization model and fix the above issues so that the model can be reproduced?

Thank you in advance!

@TomoshibiAkira

TomoshibiAkira commented Jun 30, 2021

This config is probably intended for inference only; DummyLoss simply doesn't compute any loss.
Another thing is that kl_weight is 1e-8, which makes the quantization loss extremely small. I don't think this is correct for training either.

I hope the author releases the training config soon. For now, I'll stick to vanilla VQ even though it suffers from index collapse. The utilization of the 16384 model is ~6% (~1000 valid codes), and that of the 1024 model is around 50% (~500 valid codes).
You can visualize the codes by decoding the VQ dictionary.

EDIT: The quantization loss for GumbelVQ is not big, around 0.005 when training from scratch with kl_weight=1.

@crowsonkb

I would also like some clarity on the best KL weight for training from scratch (and whether it should be warmed up over time).

@borisdayma

@TomoshibiAkira Why do you expect to have a better utilization of the codes with Gumbel Quantization?

@TomoshibiAkira

@borisdayma Because the codebook in the f=8 GumbelVQ model does not contain invalid codes, unlike the IN model.

By "invalid codes", I mean:
In the IN model's codebook, there are several thousand codes with a very small L2 norm (around 5e-4) compared with the valid codes (around 15). These codes usually don't contain any interesting information, as shown in the visualization (the first 1024 codes of the 16384 model).
[image: first 1024 codes of the 16384 IN model, decoded]

Here's the visualization of the first 1024 codes in the f=8 GumbelVQ's codebook for comparison.
[image: first 1024 codes of the f=8 GumbelVQ codebook, decoded]
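For what it's worth, the norm check I mean is just something like this (a quick sketch, assuming a loaded taming.models.vqgan.VQModel named model; the 1e-2 threshold is arbitrary):

norms = model.quantize.embedding.weight.norm(dim=1)   # one L2 norm per code
dead = (norms < 1e-2).sum().item()                     # "invalid" codes sit near 5e-4, valid ones near 15
print(f"{dead} / {norms.numel()} codes look dead")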

@borisdayma

Thanks @TomoshibiAkira for this great explanation!
Btw how did you create those visualizations?

@TomoshibiAkira

You're welcome! @borisdayma
I simply treat every code as a 1x1-size patch and forward it through the pretrained decoder.
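Roughly like this (a quick sketch, assuming a loaded taming.models.vqgan.VQModel named model; the GumbelVQ codebook sits under a slightly different attribute name):

import torch
from torchvision.utils import save_image

with torch.no_grad():
    codes = model.quantize.embedding.weight        # [n_embed, embed_dim]
    z = codes[:1024].unsqueeze(-1).unsqueeze(-1)   # each code becomes a 1x1 latent "patch"
    imgs = model.decode(z)                         # decoder upsamples 1x1 -> f x f pixels
    save_image((imgs.clamp(-1, 1) + 1) / 2, "codes.png", nrow=32)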

@sczhou

sczhou commented Sep 21, 2021

Thanks for your explanation, @TomoshibiAkira.
How do you set the temperature_scheduler_config?

@borisdayma

@TomoshibiAkira wondering if you looked at codebook utilization for other models (like OpenAI dVAE).

@TomoshibiAkira

TomoshibiAkira commented Sep 29, 2021

@borisdayma
I haven't been playing around with VQ for a while, but hey, we're here. Why not :)
Here's the visualization of DALL-E's discrete codes (the first 1024 codes out of 8192 in total).

[image: dic0.png, the first 1024 DALL-E codes decoded]

DALL-E suffers much harsher information loss than VQ. The reason might be that every code in DALL-E is an integer (or "a class") and thus carries much less information than VQ's codes (feature vectors).
As a result, although they're both f=8, the visualization of DALL-E's codes is much muddier and less detailed than GumbelVQ's codebook.

As for utilization, DALL-E's discretization method is different from VQ's.
For VQ, I could compute the norm of a code and tell from that whether it's valid, since the invalid codes always look very different from the valid ones.
For DALL-E, there's no such way to explicitly determine whether a code is valid or not. Every code can be decoded into a patch, and from the look of it, every code seems to be occupied. Although there are some duplications, that also happens in GumbelVQ's codebook, so one might say DALL-E's codes are 100% utilized.

EDIT: Here is the code for the visualization; you can use it directly in the usage.ipynb provided by DALL-E.

import torch
import torch.nn.functional as F
from torchvision.utils import save_image

num_classes = enc.vocab_size  # enc and dec come from DALL-E's usage.ipynb
batch_size = 1024
for i in range(num_classes // batch_size):
    # one one-hot code per dictionary entry, shaped as a 1x1 "image"
    ind = torch.arange(i * batch_size, (i + 1) * batch_size)
    z = F.one_hot(ind, num_classes=num_classes).unsqueeze(-1).unsqueeze(-1).float()
    x_stats = dec(z.cuda()).float()
    x_recs = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
    save_image(x_recs, "viz/dic{}.png".format(i), nrow=32)

@sczhou
The default parameters in GumbelVQ's model.yaml look okay to me. That may be a good starting point.

@sczhou

sczhou commented Sep 29, 2021

Thanks, @TomoshibiAkira. Where could I find GumbelVQ's model.yaml? I didn't see this config file in this repo.

Many thanks.

@TomoshibiAkira

Thanks, @TomoshibiAkira. Where could I find GumbelVQ's model.yaml? I didn't see this config file in this repo.

Many thanks.

It's in the pretrained model zoo.
https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/?p=%2F&mode=list

@borisdayma

borisdayma commented Sep 29, 2021

DALL-E suffers much harsher information loss than VQ. The reason might be that every code in DALL-E is an integer (or "a class") and thus carries much less information than VQ's codes (feature vectors). As a result, although they're both f=8, the visualization of DALL-E's codes is much muddier and less detailed than GumbelVQ's codebook.

@TomoshibiAkira Don't they both use a codebook where you can use either the codebook index or the corresponding feature vector?
I thought the one from OpenAI is blurrier mainly because it averages over patches (with an MSE loss), whereas the GAN and perceptual losses in VQGAN force it to sharpen the image.

@TomoshibiAkira

TomoshibiAkira commented Sep 29, 2021

@TomoshibiAkira Don't they both use a codebook where you can use either the codebook index or the corresponding feature vector?

@borisdayma I personally don't think so.
In the image reconstruction example from usage.ipynb, DALL-E's discretization method is the argmax function.
The output feature of DALL-E's encoder is argmaxed directly along the channel dimension and then turned into a one-hot vector.
The one-hot vector is then sent to the decoder, followed by normal Conv2D ops. One could say that DALL-E's decoder is actually decoding the INDEX of the encoder's feature.
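Concretely, the reconstruction path in usage.ipynb looks roughly like this (paraphrased from memory, so take the exact names with a grain of salt):

import torch
import torch.nn.functional as F

z_logits = enc(x)                                      # [B, vocab_size, H/8, W/8]
z = torch.argmax(z_logits, dim=1)                      # hard argmax over the channel ("class") dim
z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()
x_stats = dec(z).float()                               # the decoder consumes the one-hots directly
x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))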

To put it in VQ terms, you could say that the 8192 distinct one-hot vectors are DALL-E's codebook.
VQGAN maps the continuous feature into a codebook of 8192 codes, where every code is a vector in $\mathbb{R}^{256}$.
DALL-E also maps the continuous feature into a codebook of 8192 codes, but every code is a one-hot vector.
The amount of information a single code can represent in these two methods is vastly different, IMO.

I thought the one from OpenAI is blurrier mainly because it averages over patches (with an MSE loss), whereas the GAN and perceptual losses in VQGAN force it to sharpen the image.

Well, the GAN and perceptual losses definitely help, but I think even without them VQGAN (or a plain VQ-VAE) could achieve better reconstruction results.
Here's one thought: if we keep everything else in VQGAN intact and change the codebook to DALL-E's style, will it have the same performance?
If it does, it means the actual feature DOES NOT matter at all and indices are good enough for feature discretization, which is pretty counterintuitive. But neural networks are pretty counterintuitive as a whole, so yeah :D

@crowsonkb

@borisdayma I personally don't think so. In the image reconstruction example from usage.ipynb, DALL-E's discretization method is the argmax function... Here's one thought: if we keep everything else in VQGAN intact and change the codebook to DALL-E's style, will it have the same performance? If it does, it means the actual feature DOES NOT matter at all and indices are good enough for feature discretization, which is pretty counterintuitive. But neural networks are pretty counterintuitive as a whole, so yeah :D

I think the new Gumbel VQGAN type already has a DALL-E style codebook. It does indeed seem better, but I think this comes down to the quantization method preventing codebook collapse. The DALL-E decoder just uses a simple 1x1 Conv2D layer to transform the one-hots into feature vectors (it's a one-to-one mapping); I have opened the decoder up and used the features directly instead.

indices are good enough for feature discretization

They have to be because the second stage transformer models only produce indices, not features.

@TomoshibiAkira

TomoshibiAkira commented Sep 30, 2021

The DALL-E decoder just uses a simple 1x1 Conv2D layer to transform the one-hots into feature vectors (it's a one-to-one mapping); I have opened the decoder up and used the features directly instead.

Ah, now I see. If we fold the 1x1 Conv2D layer into the one-hots (i.e., look only at the output of that layer), it's actually the same as VQ's codebook (with or without Gumbel); the codebook here is simply the weight of the Conv2D layer.
Both @borisdayma and you are right. My bad!
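To convince myself, here's a toy check that a bias-free 1x1 Conv2D over one-hots is exactly an embedding lookup (toy sizes, not DALL-E's actual layers):

import torch
import torch.nn.functional as F

vocab, dim = 8192, 128
conv = torch.nn.Conv2d(vocab, dim, kernel_size=1, bias=False)
idx = torch.randint(0, vocab, (1, 4, 4))
onehot = F.one_hot(idx, vocab).permute(0, 3, 1, 2).float()
codebook = conv.weight.squeeze(-1).squeeze(-1).t()      # [vocab, dim] -- this is the "codebook"
via_conv = conv(onehot)                                 # [1, dim, 4, 4]
via_lookup = codebook[idx].permute(0, 3, 1, 2)          # same result as a plain table lookup
assert torch.allclose(via_conv, via_lookup, atol=1e-6)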

Since they're the same,

If it does, it means the actual feature DOES NOT matter at all and indices are good enough for feature discretization.

This hypothesis is invalid from the start.

They have to be because the second stage transformer models only produce indices, not features.

Sorry if my earlier statement was unclear.
What I really wanted to say is that "if we throw away the contents of the codebook, the indices of the codes alone are good enough to reconstruct the original image". But it doesn't matter now, since the hypothesis isn't valid in the first place 🤣

It does indeed seem better but I think this comes down to the quantization method preventing codebook collapse.

I'm not sure at this point. From my personal experience with an AE plus VQ, the network's behavior at f=8 and f=16 is vastly different on the same dataset.
If someone could train an f=8 model without Gumbel and check the codebook utilization, that would be very helpful.

@fnzhan

fnzhan commented Dec 2, 2021

Hi @TomoshibiAkira, this is really a valuable discussion! May I know whether you validated the performance of f=8 without Gumbel? I just want to see the effect of Gumbel, i.e., whether adding Gumbel to vanilla VQ always improves reconstruction and codebook utilization (e.g., at f=8 and f=16), or whether there is a trade-off such as high codebook utilization but relatively low accuracy of code matching. Do you have any thoughts on that?

@TomoshibiAkira

@fnzhan I didn't run that experiment, so I can't give a concrete answer. Personally, I'd like to believe that Gumbel improves performance without any trade-off, since it's basically a better method for sampling discrete data. But I didn't dive deep into the theory of Gumbel-Softmax, so please take that with a grain of salt.

@crowsonkb

Hi @TomoshibiAkira, this is really a valuable discussion! May I know whether you validated the performance of f=8 without Gumbel? I just want to see the effect of Gumbel, i.e., whether adding Gumbel to vanilla VQ always improves reconstruction and codebook utilization (e.g., at f=8 and f=16), or whether there is a trade-off such as high codebook utilization but relatively low accuracy of code matching. Do you have any thoughts on that?

I think the tradeoff is during training: you have to train longer because you have to slowly decrease the Gumbel-Softmax temperature to 0, or very near 0. But I think it is straightforwardly better at inference time.
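For reference, the schedule implied by the yaml at the top of this thread is roughly a cosine decay from 0.9 down to 1e-6 over 1M steps, something like this (my own rough paraphrase of what LambdaWarmUpCosineScheduler does after warmup, not the repo's exact code):

import math

def gumbel_temperature(step, t_start=0.9, t_min=1e-6, decay_steps=1_000_000):
    # cosine decay from t_start to t_min over decay_steps, then hold at t_min
    frac = min(step, decay_steps) / decay_steps
    return t_min + 0.5 * (t_start - t_min) * (1.0 + math.cos(math.pi * frac))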

@fnzhan

fnzhan commented Dec 4, 2021

@TomoshibiAkira @crowsonkb Thanks for sharing your insights. I am working on this at the moment and will post an update once a concrete conclusion is reached.

@EmaadKhwaja

@fnzhan any updates? Really interested to see if there are any key improvements.

@fnzhan

fnzhan commented Feb 22, 2022

Hi @EmaadKhwaja, I am preparing a paper on this. Here is a brief observation: comparing the original VQ and GumbelVQ (both f=16), the improvement with Gumbel tends to be marginal, although its codebook utilization is nearly 100%.

@TomoshibiAkira

TomoshibiAkira commented Feb 22, 2022

@fnzhan
Hmm, that's interesting! It might mean that the actual usage of the codes is very unbalanced regardless of codebook utilization (e.g., the network tends to rely on a few "special" codes rather than the others), which would unfortunately mean that the codebook collapse issue is still very much present. Statistics on the indices of the used codes would help verify this, for instance with the sketch below.
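(A rough sketch, assuming a loaded VQModel named model and a dataloader yielding images preprocessed the way the model expects; attribute names follow the vanilla VectorQuantizer, and GumbelQuantize names things slightly differently.)

import torch

counts = torch.zeros(model.quantize.n_e, dtype=torch.long)
with torch.no_grad():
    for x in dataloader:
        _, _, (_, _, idx) = model.encode(x)   # VQModel.encode returns (quant, emb_loss, info)
        counts += torch.bincount(idx.flatten(), minlength=counts.numel())

probs = counts.float() / counts.sum()
usage_perplexity = torch.exp(-(probs * (probs + 1e-10).log()).sum())
print(f"codes used at least once: {(counts > 0).sum().item()}, usage perplexity: {usage_perplexity.item():.1f}")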
Anyway, good luck with the paper!

@Zyriix

Zyriix commented Mar 9, 2023

@TomoshibiAkira Thanks for sharing!

@Zyriix

Zyriix commented Mar 17, 2023

Hey guys, if you're still interested in optimizing codebooks: I tried using the codebook with projection and L2 normalization from https://arxiv.org/pdf/2110.04627v3.pdf, and it works well. Here is a codebook:
[image: decoded codebook visualization]
It has a great range of colors and various shapes.
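For reference, here's roughly how I read that trick (a sketch of the lookup only, not the paper's code; names and sizes are my own): project the encoder feature to a low-dimensional space, l2-normalize both the projected features and the codes, and pick the nearest code by that normalized distance.

import torch
import torch.nn.functional as F

class L2ProjectedCodebook(torch.nn.Module):
    def __init__(self, n_codes=8192, z_dim=256, proj_dim=32):
        super().__init__()
        self.proj = torch.nn.Linear(z_dim, proj_dim)                 # factorized, low-dim lookup space
        self.codes = torch.nn.Parameter(torch.randn(n_codes, proj_dim))

    def forward(self, z):                                            # z: [B, z_dim, H, W]
        zf = F.normalize(self.proj(z.permute(0, 2, 3, 1)), dim=-1)   # l2-normalize projected features
        cb = F.normalize(self.codes, dim=-1)                         # l2-normalize codes
        idx = torch.cdist(zf.flatten(0, 2), cb).argmin(dim=-1)       # nearest code per spatial position
        return idx.view(z.shape[0], z.shape[2], z.shape[3])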

@function2-llx

@TomoshibiAkira @crowsonkb Thanks for sharing your insights. I am working on this at the moment and will post an update once a concrete conclusion is reached.

@fnzhan Congratulations on your article being accepted at CVPR 2023! Would you kindly share your code and pre-trained weights? It would help us better understand and follow up on your work.

@OrangeSodahub

OrangeSodahub commented Sep 1, 2023

Hi guys, I have tried training another VQModel (first stage) on my own dataset (with the encoder and decoder modified a little). However, during training the vector quantization loss rises and the KL loss also rises. Any suggestions?
