
I found a small mistake in the Transformer model. #158

Open
rhksgud92 opened this issue Jul 3, 2020 · 2 comments


@rhksgud92

rhksgud92 commented Jul 3, 2020

https://github.com/Kyubyong/transformer/blob/master/model.py
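In lines 176 to 181 of that file, the code uses `==` on a tensor inside the TensorFlow graph, and that doesn't work. In TensorFlow 1.x, `==` on a `tf.Tensor` compares object identity rather than values, so the condition is always `False`: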
```python
for _ in tqdm(range(self.hp.maxlen2)):
    logits, y_hat, y, sents2 = self.decode(ys, memory, src_masks, False)
    if tf.reduce_sum(y_hat, 1) == self.token2idx["<pad>"]:
        break

    _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
    ys = (_decoder_inputs, y, y_seqlen, sents2)
```

As a result, decoding does not stop when the model outputs `<pad>`. It keeps iterating until `maxlen2` is reached.
This is a minor issue, but it makes the eval function slower.
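For reference, here is a minimal pure-Python sketch of the intended behavior: greedy decoding that stops as soon as every sequence in the batch emits `<pad>`. It is not the repository's code. `greedy_decode`, `step_fn`, and `PAD_ID = 0` are hypothetical names and assumptions, and `step_fn` stands in for one call to `self.decode`. Because the check runs on plain Python ints rather than graph tensors, the `break` actually fires.

```python
from typing import Callable, List

PAD_ID = 0  # assumption: "<pad>" maps to id 0


def greedy_decode(
    step_fn: Callable[[List[List[int]]], List[int]],
    batch_size: int,
    maxlen: int,
) -> List[List[int]]:
    """Greedy decoding that stops early once every sequence emits <pad>.

    `step_fn` is a hypothetical stand-in for one call to `self.decode`:
    it receives the tokens decoded so far (one list per example) and
    returns the next token id for each example in the batch.
    """
    outputs: List[List[int]] = [[] for _ in range(batch_size)]
    for _ in range(maxlen):
        next_tokens = step_fn(outputs)
        for seq, tok in zip(outputs, next_tokens):
            seq.append(tok)
        # Value comparison on plain ints, so this condition can become True.
        if all(tok == PAD_ID for tok in next_tokens):
            break
    return outputs


if __name__ == "__main__":
    # Toy step function: emits token 7 twice, then <pad>.
    def toy_step(prefixes: List[List[int]]) -> List[int]:
        return [PAD_ID if len(p) >= 2 else 7 for p in prefixes]

    print(greedy_decode(toy_step, batch_size=2, maxlen=10))
```

With the toy step function the loop runs three times instead of ten, which is the saving the issue is after.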

Something like this would make the eval function faster:

```python
logits, y_hat, y, sents2 = tf.cond(
    tf.equal(y_hat[0][-1], self.token2idx["<pad>"]),
    lambda: (logits, y_hat, y, sents2),  # already at <pad>: keep the previous outputs
    lambda: self.decode(ys, memory, src_masks, False),  # otherwise decode one more step
)
```

@bozhenhhu

I wonder why you use `y_hat[0][-1]`. The first dimension of `y_hat` is `self.hp.batch_size`, so why use only the first example to decide whether the whole batch has reached `<pad>`?

@rhksgud92
Author

Sorry, it was supposed to be `tf.reduce_sum(y_hat, 1)`, not `y_hat[0][-1]`. I used `tf.cond` because a Python `if` statement on a tensor doesn't work in TensorFlow 1.

The idea is to stop the decoding computation once the sum of all elements is 0 (i.e. everything is `<pad>`).
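To make the difference between these checks concrete, here is a small pure-Python sketch that mimics the tensor logic on nested lists. The helper names are hypothetical, and it assumes `<pad>` has id 0 and all token ids are non-negative. One point worth keeping in mind: `tf.cond` needs a single scalar predicate, so a per-example vector such as `tf.reduce_sum(y_hat, 1)` still has to be collapsed into one boolean (here with `all()`).

```python
from typing import List

PAD_ID = 0  # assumption: "<pad>" has id 0


def first_example_done(y_hat: List[List[int]]) -> bool:
    # Mirrors y_hat[0][-1] == pad: only example 0 of the batch is inspected.
    return y_hat[0][-1] == PAD_ID


def all_examples_done(y_hat: List[List[int]]) -> bool:
    # True only if every example's most recent token is <pad>.
    # A possible TF1 analogue (untested assumption):
    #   tf.reduce_all(tf.equal(y_hat[:, -1], pad_id))
    # which yields the scalar boolean tf.cond expects.
    return all(row[-1] == PAD_ID for row in y_hat)


def all_rows_sum_to_pad(y_hat: List[List[int]]) -> bool:
    # The tf.reduce_sum(y_hat, 1) idea: one sum per example, collapsed to a
    # single boolean. Means "everything is <pad>" only because PAD_ID is 0
    # and ids are non-negative.
    return all(sum(row) == PAD_ID for row in y_hat)


if __name__ == "__main__":
    batch = [[5, 0], [5, 9]]  # example 0 finished, example 1 still decoding
    print(first_example_done(batch), all_examples_done(batch))
```

On that batch the first-example check would already stop decoding even though example 1 is unfinished, which is the concern raised in the question above.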
