training_logits and targets dimension mismatch #34

Open
pengwei-iie opened this issue Jan 28, 2019 · 2 comments

Comments

@pengwei-iie

cost = tf.contrib.seq2seq.sequence_loss(training_logits, targets, masks)

tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [5632] vs. [6400]

My targets have shape (256, 25), but the training_logits output has shape (256, 22, 358), where 358 is the vocabulary size.
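The two sizes in the error follow from these shapes. The loss is computed per (batch, time) position, so once the first two dimensions are merged the logits cover 256 × 22 = 5632 positions and the targets cover 256 × 25 = 6400. A minimal pure-Python sketch of a pre-check that compares the leading dimensions before the loss is called; `flattened_positions` and `check_loss_shapes` are hypothetical helpers, not TensorFlow APIs:

```python
def flattened_positions(shape):
    """Count (batch, time) positions once the first two dims are merged."""
    return shape[0] * shape[1]


def check_loss_shapes(logits_shape, targets_shape):
    """Hypothetical pre-check: logits are (batch, time, vocab), targets are (batch, time)."""
    if tuple(logits_shape[:2]) != tuple(targets_shape[:2]):
        raise ValueError(
            f"logits cover {flattened_positions(logits_shape)} positions, "
            f"targets cover {flattened_positions(targets_shape)}"
        )


# Shapes reported in this issue.
logits_shape = (256, 22, 358)  # 358 = vocabulary size
targets_shape = (256, 25)

try:
    check_loss_shapes(logits_shape, targets_shape)
except ValueError as err:
    print(err)  # logits cover 5632 positions, targets cover 6400
```

The check only looks at batch and time, because the vocabulary dimension of the logits has no counterpart in the integer targets.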

I changed it as follows, and now it works:

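import numpy as np


def pad_batch_sentence(batch, max_length, pad_id):
    # Changed: pad to a fixed max_length instead of the longest sentence in the batch
    # max_length = max([len(sentence) for sentence in batch])
    return [sentence + [pad_id] * (max_length - len(sentence)) for sentence in batch]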


def get_batches(sources, targets, batch_size):

    for batch_i in range(0, len(sources) // batch_size):
        start_i = batch_i * batch_size

        # Slice the right amount for the batch
        sources_batch = sources[start_i:start_i + batch_size]
        targets_batch = targets[start_i:start_i + batch_size]

        pad_idx = source_vocab_to_int.get("<PAD>")
        sources_batch_pad = np.array(pad_batch_sentence(sources_batch, max_source_sentence_length, pad_idx))
        targets_batch_pad = np.array(pad_batch_sentence(targets_batch, max_target_sentence_length, pad_idx))
        # Need the lengths for the _lengths parameters
        # The lengths should not be computed from the padded batch, because they are all 25
        targets_lengths = []
        for target in targets_batch_pad:
            targets_lengths.append(len(target))

        source_lengths = []
        for source in sources_batch_pad:
            source_lengths.append(len(source))

        yield sources_batch_pad, targets_batch_pad, source_lengths, targets_lengths
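The comment in the loop points at the underlying problem: the lengths are measured after padding. A minimal sketch of the alternative, assuming the decoder (for example a TF 1.x `TrainingHelper` driven by `dynamic_decode`) produces as many time steps as the longest target length passed to it. The true lengths are recorded before padding, and each batch is padded only to its own longest sentence, so the padded width and `max(target_lengths)` agree. The function names and toy data are hypothetical; `pad_id` stands in for the `<PAD>` id looked up from the vocabulary, and the batches are plain lists (wrap them with `np.array` as in the code above if needed).

```python
from typing import Iterator, List, Sequence, Tuple

Batch = Tuple[List[List[int]], List[List[int]], List[int], List[int]]


def pad_to_batch_max(batch: Sequence[List[int]], pad_id: int) -> List[List[int]]:
    """Pad every sentence to the longest sentence in this batch (hypothetical helper)."""
    max_length = max(len(sentence) for sentence in batch)
    return [list(sentence) + [pad_id] * (max_length - len(sentence)) for sentence in batch]


def get_batches_true_lengths(sources, targets, batch_size: int, pad_id: int) -> Iterator[Batch]:
    """Variant of get_batches that measures lengths before padding (hypothetical name)."""
    for batch_i in range(len(sources) // batch_size):
        start_i = batch_i * batch_size
        sources_batch = sources[start_i:start_i + batch_size]
        targets_batch = targets[start_i:start_i + batch_size]

        # True (unpadded) lengths, recorded before any <PAD> tokens are added.
        source_lengths = [len(s) for s in sources_batch]
        target_lengths = [len(t) for t in targets_batch]

        yield (pad_to_batch_max(sources_batch, pad_id),
               pad_to_batch_max(targets_batch, pad_id),
               source_lengths,
               target_lengths)


# Toy token ids; 0 plays the role of the <PAD> id.
sources = [[4, 5], [6, 7, 8], [9], [10, 11, 12, 13]]
targets = [[1, 2, 3], [4], [5, 6], [7, 8, 9, 10, 11]]

for src, tgt, src_len, tgt_len in get_batches_true_lengths(sources, targets, 2, pad_id=0):
    # The padded width equals the longest true target length in the batch.
    print(tgt, tgt_len, len(tgt[0]) == max(tgt_len))
```

Under the same assumption, padding to a fixed global length while passing true lengths could make the two disagree again whenever a batch's longest target is shorter than that fixed length.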

But with this change, the source_lengths passed in are all (20, 20, 20, ...) and the targets_lengths are all (25, 25, 25, ...).
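Constant lengths matter because of the `masks` argument to `sequence_loss`. Assuming `masks` is built from the target lengths with `tf.sequence_mask(target_lengths, max_len, dtype=tf.float32)`, as is common in TF 1.x seq2seq code, lengths that are all 25 produce a mask of all ones, so `<PAD>` positions are counted in the loss. A pure-Python stand-in for that mask, with illustrative lengths for a batch of two sentences:

```python
from typing import List, Sequence


def sequence_mask(lengths: Sequence[int], maxlen: int) -> List[List[float]]:
    """Pure-Python stand-in for tf.sequence_mask(lengths, maxlen, dtype=tf.float32)."""
    return [[1.0 if t < length else 0.0 for t in range(maxlen)] for length in lengths]


# Lengths measured on the padded batch: every position counts, including <PAD>.
padded_lengths = [25, 25]
# True lengths measured before padding (hypothetical values).
true_lengths = [7, 12]

mask_padded = sequence_mask(padded_lengths, 25)
mask_true = sequence_mask(true_lengths, 25)

# Number of positions that contribute to the loss in each case.
print(sum(map(sum, mask_padded)), sum(map(sum, mask_true)))  # 50.0 19.0
```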
@sxlprince

I also think something is wrong here: this way, every source length is just the maximum length after padding...

@sxlprince

After making the change, it throws an error...
