
Questions about training and training stability #11

Open

greeneggsandyaml opened this issue Oct 21, 2023 · 3 comments

@greeneggsandyaml

Hello,

Thank you so much for all your awesome work. It is really great stuff!

I have some questions about training (not answered by #2). First, if possible, would you be able to release your training code? It would be super helpful.

I'm asking because I'm looking to train some autoencoders that combine multiple types of image information (e.g. RGB, depth, segmentation, etc.). I have a training script at the moment (based on diffusers here), but I'm finding that training is very unstable: I'm constantly getting NaNs. Did you find that any tricks were necessary to finetune the VAE without getting NaNs? What learning rate / batch size did you use?

Also, if you have any more experiments that you want to run, I have access to some good resources (8 A100s), so let me know. I feel like these autoencoders are quite under-explored and a lot more (semantic and geometric) information could be integrated into them.

Best,
greeneggsandyaml

@madebyollin (Owner)

If I can get VAE training code that works reliably and isn't a mess to read, I'll definitely release it - still experimenting with various approaches for that though...

For the TAESD weights in the repo, my optimizer was `th.optim.Adam(model.parameters(), 3e-4, betas=(0.9, 0.9), amsgrad=True)` with batch size 16 - nothing fancy (trained on 1xA10 :P). My adversarial loss was "relativistic" (penalizing `distance(disc(real).mean(), disc(fake).mean())` rather than just `disc(fake).mean()`), and I used a replay buffer of fakes for the discriminator training; both of these may have helped with stability. I also used several auxiliary losses (LPIPS + frequency-amplitude matching + MSE at 1/8 resolution) for the most recent model, which helped reduce the dependence on the adversarial loss a bit. I don't remember any persistent instability issues with this setup.
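For concreteness, here's a minimal sketch of that setup. This is an illustration of the description above, not code from the TAESD repo: `model` and `disc` are stand-in modules, the exact distance used for the relativistic loss is an assumption (absolute difference shown), and `FakeReplayBuffer` is a generic CycleGAN-style image pool.

```python
import random
import torch

# Optimizer as described above; `model` is the autoencoder being trained.
opt = torch.optim.Adam(model.parameters(), 3e-4, betas=(0.9, 0.9), amsgrad=True)

def relativistic_g_loss(disc, real, fake):
    # Penalize the distance between mean scores on real vs. fake batches,
    # rather than just disc(fake).mean(). The choice of abs() here is a guess.
    return (disc(real).mean() - disc(fake).mean()).abs()

class FakeReplayBuffer:
    """Pool of past generator outputs. Mixing stale fakes into discriminator
    batches can damp generator/discriminator oscillation."""
    def __init__(self, size=256):
        self.size, self.buf = size, []

    def push_and_sample(self, fakes):
        out = []
        for f in fakes.detach():
            if len(self.buf) < self.size:
                self.buf.append(f)   # fill the pool first
                out.append(f)
            elif random.random() < 0.5:
                i = random.randrange(self.size)
                out.append(self.buf[i])  # return a stale fake...
                self.buf[i] = f          # ...and store the fresh one
            else:
                out.append(f)
        return torch.stack(out)
```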

@madebyollin (Owner)

BTW, it looks like the MosaicML folks are working on a cleaned-up version of the Latent Diffusion VAE training code here: mosaicml/diffusion#79

@madebyollin (Owner)

Adding some more info here.

Changes in 1.2

For TAESD 1.2, I removed the LPIPS and other icky hand-coded losses (now just using adversarial + very faint low-res MSE). I also added adversarial loss to the encoder training (though I'm not sure it made a difference).

Various questions I've seen

  1. Are the decoder targets GT images or SD-decoded images? GT; TAESD's decoder is a standalone conditional GAN, not a distilled model.
  2. What dataset was used? It depends on the model version, but usually some mix of photos (e.g. laion-aesthetic) and illustrations (e.g. danbooru2021), with some color / geometric augmentations.
  3. Do you delay adversarial loss until a certain number of steps (like the SD VAE does)? I usually prefer to start from a pretrained decoder model, but I don't have a specific number of steps in mind.
  4. What do you mean by low-res MSE loss? Something like `F.mse_loss(F.avg_pool2d(decoded, 8), F.avg_pool2d(real_images, 8))` - just making sure that the color of each 8x8 patch is approximately correct (see the sketch after this list).
  5. Which reference VAEs do you use? https://huggingface.co/stabilityai/sd-vae-ft-ema and https://huggingface.co/madebyollin/sdxl-vae-fp16-fix - these are used to supervise the encoder and also serve as a gold standard for decoder quality.
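To make items 4 and 5 concrete, here's a hedged sketch of how those pieces might fit into a training step. The diffusers-style `reference_vae.encode(...).latent_dist.mean` call, the 0.1 loss weight, and the `taesd_encoder` / `taesd_decoder` names are assumptions; latent scaling factors are omitted for brevity.

```python
import torch
import torch.nn.functional as F

# Item 5: supervise the encoder against a frozen reference VAE's latents
# (e.g. stabilityai/sd-vae-ft-ema loaded as a diffusers AutoencoderKL).
with torch.no_grad():
    ref_latents = reference_vae.encode(real_images).latent_dist.mean
enc_loss = F.mse_loss(taesd_encoder(real_images), ref_latents)

# Item 4: the low-res MSE term. Only the average color of each 8x8 patch
# of the decoded image is constrained, leaving detail to the adversarial loss.
decoded = taesd_decoder(ref_latents)
lowres_mse = F.mse_loss(F.avg_pool2d(decoded, 8), F.avg_pool2d(real_images, 8))

# "Very faint" low-res MSE alongside the adversarial loss (computed as in
# the earlier sketch); the 0.1 weight is an illustrative guess.
g_loss = adv_loss + 0.1 * lowres_mse
```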

Various figures

Color Augmentation

Color augmentation (occasional hue / saturation shifting of input images) helped improve reproduction of saturated colors (which are otherwise rare in aesthetic datasets).

[image]
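As an aside, one plausible way to implement that kind of augmentation is with torchvision; the jitter ranges and application probability below are illustrative guesses, not values from the repo.

```python
from torchvision import transforms

# Occasional hue / saturation shifting of input images, per the description
# above. hue=0.5 allows the full hue range; saturation varies from 0.5x to 2x.
color_aug = transforms.RandomApply(
    [transforms.ColorJitter(hue=0.5, saturation=(0.5, 2.0))], p=0.25
)
augmented = color_aug(images)  # `images`: float image batch in [0, 1], NCHW
```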

Downsides of different losses

MSE/MAE can make everything very smooth (top is GT, bottom is a simple MSE-only decoder):

[image]

LPIPS can cause recognizable artifacts on faces & eyes (top is from a run with LPIPS, bottom is a run without it)

[image]

Adversarial loss can cause divergence if not handled properly:

[image]

Blue eyes

I don't remember what caused this.

[image]
