Two questions about the WGAN-GP source code #45

Open
donlv1997 opened this issue Mar 3, 2020 · 2 comments
donlv1997 commented Mar 3, 2020

While reading the source code I noticed a couple of small issues.

1. The wgan_train.py source still applies a sigmoid followed by a cross-entropy loss, but WGAN should use the discriminator's raw output logits directly in the loss:
import tensorflow as tf

def d_loss_fn(generator, discriminator, batch_z, real_image):
    fake_image = generator(batch_z, training=True)
    d_fake_score = discriminator(fake_image, training=True)
    d_real_score = discriminator(real_image, training=True)

    # Wasserstein critic loss: mean score on fakes minus mean score on reals.
    loss = tf.reduce_mean(d_fake_score - d_real_score)
    # Gradient penalty term, weighted by lambda = 10
    gp = gradient_penalty(discriminator, real_image, fake_image) * 10.

    loss = loss + gp
    return loss, gp

def g_loss_fn(generator, discriminator, batch_z):
    fake_image = generator(batch_z, training=True)
    d_fake_logits = discriminator(fake_image, training=True)
    # loss = celoss_ones(d_fake_logits)  # original cross-entropy version
    # Generator maximizes the critic score on fakes, i.e. minimizes its negation.
    loss = -tf.reduce_mean(d_fake_logits)
    return loss
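
For contrast, the cross-entropy helper being dropped (the commented-out celoss_ones above) presumably looked something like the following; this is my reconstruction of a standard GAN loss, not the repo's actual code. It treats the discriminator output as a classification logit, which is exactly what WGAN avoids: the critic output is an unbounded score, not a probability.

# Hypothetical reconstruction of celoss_ones (not the repo's source):
# a standard GAN loss pushing sigmoid(logits) toward 1.
def celoss_ones(logits):
    loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(logits), logits=logits)
    return tf.reduce_mean(loss)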

2. After switching the loss over to raw logits as WGAN requires, I found that training did not converge. After repeated checking, the culprit turned out to be the gradient penalty computation; after replacing the original function with the version below, training converges nicely (a sketch of how the pieces fit into a training step follows the function):

def gradient_penalty(discriminator, real_image, fake_image):
    batchsz = real_image.shape[0]
    # Per-sample interpolation coefficient. Did a dtype mismatch here cause the divergence?
    t = tf.random.uniform([batchsz, 1, 1, 1], minval=0., maxval=1., dtype=tf.float32)
    # Random points on the straight lines between real and fake samples.
    x_hat = t * real_image + (1. - t) * fake_image
    with tf.GradientTape() as tape:
        # x_hat is a plain tensor, not a variable, so it must be watched explicitly.
        tape.watch(x_hat)
        Dx = discriminator(x_hat, training=True)
    grads = tape.gradient(Dx, x_hat)
    # L2 norm of the critic's gradient per sample, penalized toward 1.
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    gp = tf.reduce_mean((slopes - 1.) ** 2)
    return gp
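
For completeness, here is a minimal sketch of how these three functions might plug into one training step. It assumes a z_dim noise dimension and the defaults from Algorithm 1 of the WGAN-GP paper (lambda = 10, n_critic = 5, Adam with lr = 1e-4, beta_1 = 0, beta_2 = 0.9); it is not the repo's actual loop.

# Optimizer settings from WGAN-GP Algorithm 1 (assumed, not from the repo).
d_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0., beta_2=0.9)
g_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0., beta_2=0.9)

def train_step(generator, discriminator, real_image, z_dim, n_critic=5):
    batchsz = real_image.shape[0]
    # Update the critic n_critic times per generator update.
    for _ in range(n_critic):
        batch_z = tf.random.normal([batchsz, z_dim])
        with tf.GradientTape() as tape:
            d_loss, gp = d_loss_fn(generator, discriminator, batch_z, real_image)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

    # One generator update.
    batch_z = tf.random.normal([batchsz, z_dim])
    with tf.GradientTape() as tape:
        g_loss = g_loss_fn(generator, discriminator, batch_z)
    grads = tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    return d_loss, g_loss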
donlv1997 (Author) commented:
[image: wgan_gp-160000]
Before the fix: at around 50k epochs the gradients explode, and the generator can only produce noise.
After the fix: WGAN's characteristic training stability shows through; I have trained for 160k epochs so far and the output is still improving steadily.

donlv1997 (Author) commented:

Another improvement: with deconvolution, if you magnify the output and look closely, you can apparently make out faint checkerboard-patterned artifacts, probably caused by the overlap of the transposed convolutions (Conv2DTranspose). Changing the generator to an upsampling + Conv2D structure should eliminate them; I am still training with this change, so the actual effect remains to be confirmed.
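
For reference, a minimal sketch of that substitution (layer widths and kernel sizes are placeholders, not the repo's actual architecture):

import tensorflow as tf
from tensorflow.keras import layers

# Checkerboard-prone upsampling block: stride-2 transposed convolution.
deconv_block = tf.keras.Sequential([
    layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same'),
    layers.BatchNormalization(),
    layers.ReLU(),
])

# Artifact-reducing alternative: nearest-neighbor upsampling + plain Conv2D.
upsample_block = tf.keras.Sequential([
    layers.UpSampling2D(size=2, interpolation='nearest'),
    layers.Conv2D(64, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.ReLU(),
])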
