Problem upgrading bert4torch from version 0.2.8 to 0.3.4 #157

Open
kakaxzc opened this issue Nov 7, 2023 · 2 comments

Comments

@kakaxzc

kakaxzc commented Nov 7, 2023

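Hello. I was originally on bert4torch 0.2.8, where task_seq2seq_autotitle_csl_mt5 and similar examples ran without problems, but after upgrading to 0.3.4 they fail.
In the method below, outputs now contains 2 values: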
class CrossEntropyLoss(nn.CrossEntropyLoss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, outputs, y_true):
        _, _, y_pred = outputs
        y_pred = y_pred.reshape(-1, y_pred.shape[-1])
        return super().forward(y_pred, y_true)

If I drop one of the placeholders from the unpacking, the return statement in the section below raises an error. How can I fix this?
class AutoTitle(AutoRegressiveDecoder):
    """seq2seq decoder
    """
    @AutoRegressiveDecoder.wraps(default_rtype='logits')
    def predict(self, inputs, output_ids, states):
        # inputs contains [decoder_ids, encoder_hidden_state, encoder_attention_mask]
        # keep only the last position
        return model.decoder.predict([output_ids] + inputs)[-1][:, -1, :]

@Tongjilibo
Owner

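Hello. The example was not updated when the library was last revised. Changing it as shown below should fix the problem. Alternatively, you can upgrade to the latest version, 0.3.7, which no longer requires converting the weights; the model can be loaded using only bert4torch_config.json.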

def forward(self, outputs, y_true):
    y_pred = outputs[-1]
    y_pred = y_pred.reshape(-1, y_pred.shape[-1])
    return super().forward(y_pred, y_true)

@AutoRegressiveDecoder.wraps(default_rtype='logits')
def predict(self, inputs, output_ids, states):
    res = model.decoder.predict([output_ids] + inputs)
    return res[-1][:, -1, :] if isinstance(res, list) else res[:, -1, :]  # keep only the last position
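
The idea behind both changes is to stop depending on how many values the model returns and simply take the last one. A minimal sketch of that idea follows, assuming the logits are always the last element whenever a list or tuple is returned. The helper name pick_last_output is hypothetical and not part of bert4torch, and plain strings stand in for tensors so the snippet runs without any dependencies.

```python
def pick_last_output(outputs):
    """Return the last element of a list or tuple, or `outputs` itself otherwise.

    Hypothetical helper for illustration only; it is not a bert4torch API.
    """
    if isinstance(outputs, (list, tuple)):
        return outputs[-1]
    return outputs


# Strings stand in for tensors here; only the container shape matters.
three_outputs = ["hidden_a", "hidden_b", "logits"]  # what the 3-way unpacking expected
two_outputs = ("hidden", "logits")                  # 2 values, as reported after the upgrade
bare_output = "logits"                              # a single value, not wrapped

for candidate in (three_outputs, two_outputs, bare_output):
    print(pick_last_output(candidate))
```

With real tensors, the result of such a helper could then be sliced with [:, -1, :] in predict or reshaped in forward, as in the code above. Checking for tuple as well as list is an extra precaution in this sketch, not something the thread states is required.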

@kakaxzc
Author

kakaxzc commented Nov 8, 2023

The problem is fixed, thank you!
