
Question about how to handle inputs_embeds #28

Open
forest1988 opened this issue Oct 26, 2020 · 2 comments
@forest1988

Hi,

Thank you for sharing your great work!

I have a question about how to handle inputs_embeds in the PPLM code.

When I looked at run_pplm.py, I found something whose intention I cannot understand.
https://github.com/uber-research/PPLM/blob/master/run_pplm.py#L220

        if loss_type == PPLM_DISCRIM or loss_type == PPLM_BOW_DISCRIM:
            ce_loss = torch.nn.CrossEntropyLoss()
            # TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
            curr_unpert_past = unpert_past
            curr_probs = torch.unsqueeze(probs, dim=1)
            wte = model.resize_token_embeddings()
            for _ in range(horizon_length):
                inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
                _, curr_unpert_past, curr_all_hidden = model(
                    past=curr_unpert_past,
                    inputs_embeds=inputs_embeds
                )
                curr_hidden = curr_all_hidden[-1]
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(
                    curr_hidden, dim=1)

inputs_embeds is recalculated on every iteration of the for loop, but curr_probs and wte.weight.data do not appear to change inside the loop, so the same value seems to be computed each time.

Could you please tell me why inputs_embeds is calculated inside the for loop?

Thank you in advance!

@dathath
Contributor

dathath commented Oct 27, 2020

I think that is actually a bug, and it might also explain why our experiments with horizon_length > 1 did not work so well (we use horizon_length = 1 in all of our experiments). If you're running with horizon_length = 1 it shouldn't matter, but it is a bug: we do need to update curr_probs inside the loop here.
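For reference, the intended control flow might be sketched as follows. This is a hypothetical toy reproduction, not the actual PPLM code: the stand-in toy_model and the tensor shapes are assumptions made purely for illustration. The point of the fix is to re-derive curr_probs from the fresh logits on every horizon step, so each iteration feeds the expected embedding of the previous step's predicted distribution back into the model:

```python
import torch

# Toy stand-in for the language model (assumption, for illustration only):
# it maps input embeddings to vocabulary logits via the embedding matrix.
vocab_size, hidden_size, horizon_length = 10, 4, 3
wte_weight = torch.randn(vocab_size, hidden_size)

def toy_model(inputs_embeds):
    # Pretend the hidden state equals the input embedding, and project
    # back to vocabulary logits with the (tied) embedding matrix.
    hidden = inputs_embeds
    logits = hidden @ wte_weight.t()
    return logits, hidden

# Initial next-token distribution, shape (1, 1, vocab_size).
probs = torch.softmax(torch.randn(1, vocab_size), dim=-1)
curr_probs = probs.unsqueeze(1)
accumulated_hidden = torch.zeros(hidden_size)

for _ in range(horizon_length):
    # Expected embedding under the current distribution.
    inputs_embeds = torch.matmul(curr_probs, wte_weight)
    curr_logits, curr_hidden = toy_model(inputs_embeds)
    # The fix: recompute curr_probs from the new logits each step,
    # instead of reusing the initial probs for every horizon step.
    curr_probs = torch.softmax(curr_logits[:, -1, :], dim=-1).unsqueeze(1)
    accumulated_hidden = accumulated_hidden + curr_hidden.sum(dim=1).squeeze(0)
```

With this change inputs_embeds differs across iterations; in the original code it stayed constant, so horizon_length > 1 added no new information beyond the first step.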

@forest1988
Author

I apologize that my reply was so delayed; I missed the notification that you had kindly responded to this issue.

Thank you for the detailed information. I now understand that what I pointed out is a bug, and that you used horizon_length = 1 in all your experiments, so it did not affect them.

I also understand that curr_probs needs to be updated inside the loop. I'm not a strong programmer, but if I can be of any help, please let me know.

Again, I apologize for my late reply.
