Questions about the wikiart training sets #12

Open
Rancherzhang opened this issue Oct 27, 2020 · 3 comments
Rancherzhang commented Oct 27, 2020

Hello, I want to reproduce your great work, but I have two questions.
First, I am rewriting the training phase and have started training on wikiart with content-dir 'wikiart/Rococo' and style-dir 'wikiart/Symbolism', but my intermediate results are not as good as yours. Which content-dir and style-dir did you use on the wikiart dataset?
Second, my style distribution loss does not converge; it stays between 4.2 and 4.4. My code is below:

import random

import torch
import torch.nn as nn

class StyleDistLoss(nn.Module):
    '''
    style distribution loss of s and s'
    '''
    def __init__(self, pool_size):
        super(StyleDistLoss, self).__init__()
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_style_batch = 0
            self.style_batches = []
        self.loss = nn.L1Loss()

    def __call__(self, sc, st):
        '''
            return the standard Gaussian distribution loss of the input
            style source {sc} and style target {st}, which correspond to s and s' in the paper
        '''
        styles = []
        if self.pool_size == 0:
            styles.extend([sc, st])
        else:
            styles += self.style_batches
            styles.extend([sc, st])

            detach_sc = sc.clone().detach()
            detach_st = st.clone().detach()

            if self.num_style_batch + 2 < self.pool_size:
                self.style_batches.extend([detach_sc, detach_st])
                self.num_style_batch += 2
            else:
                random_idx = [x for x in range(self.num_style_batch)]
                random.shuffle(random_idx)
                self.style_batches[random_idx[0]] = detach_sc
                self.style_batches[random_idx[1]] = detach_st
        tensor_styles = torch.squeeze(torch.cat(styles, 0))
        styles_mean = torch.mean(tensor_styles, dim=0)
        tminuss = tensor_styles - styles_mean
        cov = torch.mm(tminuss.t(), tminuss) / tensor_styles.shape[0]
        std_cov = cov.diag(diagonal=0)
        total_loss = self.loss(styles_mean, torch.zeros_like(styles_mean))
        total_loss += self.loss(cov, torch.ones_like(cov))
        total_loss += self.loss(std_cov, torch.ones_like(std_cov))
        return total_loss

Could you please give me some advice? Thanks!

@denkorzh
Collaborator

Hi! I believe the problem is in the line total_loss += self.loss(cov, torch.ones_like(cov))

torch.ones_like(cov) returns a matrix filled with 1s, not an identity matrix. Therefore, your loss does not enforce the disentanglement of the components of the style vector. Try torch.eye instead.
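
A minimal sketch of that suggestion, assuming the style vectors are stacked into an (N, D) tensor. The function name `gaussian_prior_loss` is hypothetical and is not taken from the repository. It compares the batch covariance with an identity matrix built by `torch.eye`, so off-diagonal entries are pushed toward 0 and the variances toward 1.

```python
import torch
import torch.nn.functional as F


def gaussian_prior_loss(styles: torch.Tensor) -> torch.Tensor:
    """L1 penalty pulling the batch mean to 0 and the covariance to identity.

    `styles` is assumed to have shape (N, D): N style vectors of size D.
    """
    mean = styles.mean(dim=0)
    centered = styles - mean
    cov = centered.t() @ centered / styles.shape[0]  # (D, D), biased estimate
    identity = torch.eye(cov.shape[0], dtype=cov.dtype, device=cov.device)
    loss = F.l1_loss(mean, torch.zeros_like(mean))
    loss = loss + F.l1_loss(cov, identity)  # identity target, not all ones
    loss = loss + F.l1_loss(cov.diag(), torch.ones_like(mean))
    return loss


# Compare the two targets on a toy covariance that is already the identity.
cov = torch.eye(4)
err_ones = F.l1_loss(cov, torch.ones_like(cov)).item()  # all-ones target
err_eye = F.l1_loss(cov, torch.eye(4)).item()           # identity target
```

With the all-ones target, a perfectly decorrelated covariance is still penalized on every off-diagonal entry, which may be one reason the loss plateaus instead of approaching zero.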

@Rancherzhang
Author

Thank you for your advice, I will give it a try! I also have a question about training on the wikiart dataset. I noticed that in the inference step in your readme.md, you use landscape images as content images and wikiart images as style images. Should I use the same strategy when training on wikiart?

@belkakari
Collaborator

Sorry for the delayed reply. For the style transfer model, both content and style images are sampled from the wikiart dataset. We observe that a model trained in this manner can also be applied to real images.
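
As a rough illustration of sampling both roles from one pool, here is a small sketch. The helper `sample_pairs` and the file names are hypothetical, and drawing content and style independently with replacement is an assumption; the thread does not say how the authors' data loader draws its pairs.

```python
import random
from typing import List, Tuple


def sample_pairs(
    paths: List[str], batch_size: int, rng: random.Random
) -> List[Tuple[str, str]]:
    """Draw (content, style) path pairs, both from the same image pool."""
    content = rng.choices(paths, k=batch_size)  # with replacement
    style = rng.choices(paths, k=batch_size)    # drawn independently
    return list(zip(content, style))


# Hypothetical file names standing in for the whole wikiart image list.
pool = [f"wikiart/img_{i}.jpg" for i in range(10)]
pairs = sample_pairs(pool, batch_size=4, rng=random.Random(0))
```

In this setup, content-dir and style-dir would simply point at the same image list rather than at two separate genre folders.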
