
Discriminator loss converges to zero early in training #16

Open
jpfeil opened this issue Nov 21, 2023 · 9 comments

@jpfeil
Contributor

jpfeil commented Nov 21, 2023

I compared v0.1.26 without the GAN and v0.1.36 with the GAN on the Fashion-MNIST data, and I was able to get better reconstructions without the GAN:
https://api.wandb.ai/links/pfeiljx/f7wdueh0

Do you have any suggestions for improving training?

I'm using a cosine scheduler for the model and discriminator. Should I use a different learning rate schedule for the discriminator?

I saw similar discriminator collapse with the VQ-GAN, and I read that delaying the discriminator until the generator model is optimized may help. Maybe delaying the discriminator until a certain reconstruction loss is achieved?
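Roughly what I'm picturing, as a minimal sketch with made-up names (none of this is magvit2-pytorch API, and the threshold value is a guess that would need tuning):

import torch

class DiscriminatorGate:
    # Hypothetical helper: a latch that flips on once the reconstruction loss
    # first drops below a threshold, after which discriminator updates and the
    # adversarial term stay enabled for the rest of training.
    def __init__(self, recon_threshold: float = 0.05):
        self.recon_threshold = recon_threshold  # assumed value, tune per dataset
        self.active = False

    def update(self, recon_loss: torch.Tensor) -> bool:
        if not self.active and recon_loss.item() < self.recon_threshold:
            self.active = True
        return self.active

# usage inside a training step (illustrative):
# gate = DiscriminatorGate()
# adv_weight = 1. if gate.update(recon_loss) else 0.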

After googling some strategies, I saw the unrolled GAN where the generator stays a few steps ahead of the discriminator. I'm not sure how difficult it would be to implement a similar strategy here.
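For reference, here is a crude sketch of the simplified variant of that idea: train a throwaway copy of the discriminator a few steps ahead, then score the generator against the copy. The full unrolled-GAN method (Metz et al., 2016) also backpropagates through the unrolled updates, which this skips, so it is only an approximation, and all names are illustrative:

import copy

def lookahead_discriminator(discr, make_opt, d_loss_fn, real, fake, k = 3):
    # Train a deep copy of the discriminator for k extra steps on the current
    # batch, leaving the real discriminator untouched. The generator's
    # adversarial loss can then be computed against the returned copy.
    unrolled = copy.deepcopy(discr)
    opt = make_opt(unrolled.parameters())
    for _ in range(k):
        opt.zero_grad()
        d_loss_fn(unrolled, real, fake.detach()).backward()
        opt.step()
    return unrolled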

I'm just brainstorming, so feel free to address or ignore any of these comments.

import torch
from datetime import datetime

from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

RUNTIME = datetime.now().strftime("%y%m%d_%H%M%S")

tokenizer = VideoTokenizer(
    image_size = 32,
    channels = 1,
    use_gan = True,
    use_fsq = False,
    codebook_size = 2 ** 13,
    init_dim = 64,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder = '/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type = 'images',    # 'videos' or 'images'; prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 10,
    grad_accum_every = 5,
    num_train_steps = 5_000,
    num_frames = 1,
    max_grad_norm = 1.0,
    learning_rate = 2e-5,
    accelerate_kwargs = {"split_batches": True, "mixed_precision": "fp16"},
    random_split_seed = 85,
    optimizer_kwargs = {"betas": (0.9, 0.99)},  # from the paper
    ema_kwargs = {},
    use_wandb_tracking = True,
    checkpoints_folder = f'./runs/{RUNTIME}/checkpoints',
    results_folder = f'./runs/{RUNTIME}/results',
)

with trainer.trackers(project_name = 'magvit', run_name = f'MNIST v0.1.26 W/ GAN 2**13 {RUNTIME}'):
    trainer.train()

@lucidrains
Owner

@jpfeil can you screenshot the paper section where they propose delaying the discriminator training? (and link the paper too)

@lucidrains
Owner

lucidrains commented Nov 21, 2023

@jpfeil do you have adversarial_loss_weight set greater than 0? Also try another run where your perceptual_loss_weight is 0.1.
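(For concreteness, the suggestion above translates to something like the following. The two weight keywords are quoted from the comment; that they belong on VideoTokenizer rather than the trainer is an assumption to verify against the installed version:)

from magvit2_pytorch import VideoTokenizer

tokenizer = VideoTokenizer(
    image_size = 32,
    channels = 1,
    use_gan = True,
    adversarial_loss_weight = 1.,   # must be > 0 for the GAN term to contribute
    perceptual_loss_weight = 0.1,   # suggested value for the next run
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)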

@jpfeil
Contributor Author

jpfeil commented Nov 21, 2023

Thanks @lucidrains. I'll try again with those parameters. I saw it in the taming implementation here:
https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/losses/vqperceptual.py#L51
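(The linked line zeroes the discriminator's loss factor until a configured start step. Paraphrasing the pattern from the taming-transformers file linked above:)

def adopt_weight(weight, global_step, threshold = 0, value = 0.):
    # returns `value` (zero by default) until `global_step` reaches `threshold`,
    # so the discriminator contributes nothing early in training
    if global_step < threshold:
        weight = value
    return weight

# in the loss module, roughly:
# disc_factor = adopt_weight(self.disc_factor, global_step, threshold = self.discriminator_iter_start)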

@lucidrains
Owner

@jpfeil welp.. whatever Robin and Patrick do goes; they are the best in the world.

let me add that

@lucidrains
Owner

lucidrains commented Nov 21, 2023

@jpfeil ok, added that same functionality here. Try removing the learning rate schedule in your next run too; you shouldn't need it for something this easy.
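(A usage sketch of the newly added taming-style delay. The keyword name below is an assumption and should be checked against the current magvit2-pytorch release:)

from magvit2_pytorch import VideoTokenizerTrainer

# `tokenizer` as constructed earlier in the thread
trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder = '/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type = 'images',
    batch_size = 10,
    num_train_steps = 5_000,
    discr_start_after_step = 1_000,   # assumed kwarg: skip discriminator updates before this step
)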

@lucidrains
Owner

@jpfeil you don't happen to have relatives in Massachusetts, do you?

@jpfeil
Contributor Author

jpfeil commented Nov 21, 2023

@lucidrains Nice. Let me try it out again. No, I don't have any relatives in Massachusetts. Did you meet someone with the last name Pfeil?

@lucidrains
Owner

yea, I knew someone back in high school with the Pfeil family name. Tragedy struck and they moved away though. You are the second Pfeil I've met!

@jpfeil
Contributor Author

jpfeil commented Nov 21, 2023

That's amazing. It's not a common name. Sorry to hear about your friend.
