
How many tags can this project train at the same time? #27

Open
datar001 opened this issue Oct 6, 2021 · 3 comments

@datar001

datar001 commented Oct 6, 2021

Hi, thanks for sharing.
How many tags have you tried training at once? What's the relation between the number of tags and the number of training iterations?
And how many tags would you recommend training at one time?

@imlixinyang
Owner

I've succeeded in training 6 tags at the same time. In my experiments, I found 50k iterations per tag is enough (i.e., 20k for 6 tags).
HiSD supports various numbers of tags, but you should increase the number of training iterations and the model capacity.
Using gradient accumulation and training all tags in one iteration is also important (so you need to change the code a little).
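
For illustration, a minimal sketch of what the loop-side change could look like (tag_num, max_iter, sample_batch() and the extra iterations argument passed to update() are assumptions for this sketch, not the repository's exact code):

    # Sketch: cycle through the tags, one tag per iteration, and let the
    # trainer accumulate gradients until every tag has contributed once.
    tag_num = 6                # assumed number of tags trained together
    max_iter = 200000          # assumed total number of iterations

    for iterations in range(max_iter):
        i = iterations % tag_num              # a different tag each iteration
        x, y, j, j_trg = sample_batch(i)      # hypothetical data-loading helper
        trainer.update(x, y, i, j, j_trg, iterations)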

@datar001
Author

datar001 commented Oct 6, 2021

Thanks for your reply.
Is this the right way to do "gradient accumulation and all tags in one iteration"?
[screenshots of the modified update code]
And is '20k for 6 tags' a typo? The official repo uses 200k for 3 tags with 7 attributes.
Also, do we get better performance when training fewer tags?

@imlixinyang
Owner

Sorry for the typo, it should be 200k for 3 tags with 7 attributes.
You've understood the gradient accumulation correctly, and you can write the update code like this:

    def update(self, x, y, i, j, j_trg, iterations):

        this_model = self.models.module if self.multi_gpus else self.models

        # gen: freeze the discriminator so that only the generator receives gradients
        for p in this_model.dis.parameters():
            p.requires_grad = False
        for p in this_model.gen.parameters():
            p.requires_grad = True

        # the generator losses are back-propagated inside the model's forward
        # (mode='gen'), so their gradients accumulate on the parameters here
        self.loss_gen_adv, self.loss_gen_sty, self.loss_gen_rec, \
        x_trg, x_cyc, s, s_trg = self.models((x, y, i, j, j_trg), mode='gen')

        self.loss_gen_adv = self.loss_gen_adv.mean()
        self.loss_gen_sty = self.loss_gen_sty.mean()
        self.loss_gen_rec = self.loss_gen_rec.mean()
        

        # dis: freeze the generator so that only the discriminator receives gradients
        for p in this_model.dis.parameters():
            p.requires_grad = True
        for p in this_model.gen.parameters():
            p.requires_grad = False


        # the discriminator loss is likewise back-propagated inside the forward pass (mode='dis')
        self.loss_dis_adv = self.models((x, x_trg, x_cyc, s, s_trg, y, i, j, j_trg), mode='dis')
        self.loss_dis_adv = self.loss_dis_adv.mean()
        
        # step the optimizers only once every tag_num iterations, so that the
        # gradients from all tags are accumulated before each update
        if (iterations + 1) % self.tag_num == 0:
            nn.utils.clip_grad_norm_(this_model.gen.parameters(), 100)
            nn.utils.clip_grad_norm_(this_model.dis.parameters(), 100)
            self.gen_opt.step()
            self.dis_opt.step()
            self.gen_opt.zero_grad()
            self.dis_opt.zero_grad()

            update_average(this_model.gen_test, this_model.gen)  # refresh the averaged generator used at test time

        return self.loss_gen_adv.item(), \
               self.loss_gen_sty.item(), \
               self.loss_gen_rec.item(), \
               self.loss_dis_adv.item()

And you need to decrease the learning rate (maybe to lr / tag_num), since the accumulated gradients are summed rather than averaged.
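
For instance, a hedged sketch of that adjustment when building the optimizers (the hyperparameters key, the betas and tag_num below are illustrative assumptions, not necessarily the repository's exact values):

    import torch

    tag_num = 6                               # assumed number of tags trained together
    lr = hyperparameters['lr'] / tag_num      # assumed hyperparameter key for the base learning rate

    # Scale the learning rate down because tag_num backward passes are summed
    # into the gradients before each optimizer step.
    gen_opt = torch.optim.Adam(this_model.gen.parameters(), lr=lr, betas=(0, 0.99))
    dis_opt = torch.optim.Adam(this_model.dis.parameters(), lr=lr, betas=(0, 0.99))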
