About Load HuggingFace Bert #205

Open
xiezipeng-ML opened this issue Mar 22, 2022 · 10 comments

xiezipeng-ML (Contributor) commented Mar 22, 2022

Some issues found while loading huggingface weights into LiBai's Bert and aligning the outputs; after the modifications below, the output matches huggingface.

Parameter structure comparison (the full Bert parameter listings of both libraries are at the bottom):

  • LiBai's embedding part matches huggingface's; no problem there.
  • Next, the LayerNorm layers: LiBai places LayerNorm at the input of each block, while huggingface places it at the output of each block. This is also fine; when loading huggingface weights, simply load the LayerNorm of the preceding huggingface block into LiBai's input LayerNorm.
  • For the qkv part, huggingface defines q, k, v separately while LiBai uses a fused qkv, so just load huggingface's q, k, v and concatenate them.
  • Finally, when loading weights, every Linear layer's weight needs a permute(1, 0) (a sketch of this mapping follows this list).
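
Below is a minimal sketch of these mapping rules for the first encoder layer, assuming the initial approach in this comment (where LiBai's qkv computation is also modified to match huggingface; later comments describe the extra row re-ordering needed when the fused qkv computation is kept unchanged). The function name and the hard-coded layer-0 keys are illustrative only, not the actual loader:

import torch

def map_layer0_weights(hf_state: dict, libai_state: dict) -> dict:
    # Illustrative only: maps huggingface layer-0 keys onto the LiBai names listed
    # further below; a real loader would loop over all layers and all parameters.

    # LayerNorm moves from the output of a huggingface block to the input of the
    # next LiBai block, so the embedding LayerNorm feeds encoders.0.input_layernorm.
    libai_state["encoders.0.input_layernorm.weight"] = hf_state["bert.embeddings.LayerNorm.gamma"]
    libai_state["encoders.0.input_layernorm.bias"] = hf_state["bert.embeddings.LayerNorm.beta"]

    # huggingface stores q, k, v separately; LiBai uses one fused query_key_value.
    q = hf_state["bert.encoder.layer.0.attention.self.query.weight"]
    k = hf_state["bert.encoder.layer.0.attention.self.key.weight"]
    v = hf_state["bert.encoder.layer.0.attention.self.value.weight"]
    qkv = torch.cat([q, k, v], dim=0)                       # [3*hidden, hidden]

    # every Linear weight is transposed (permute(1, 0)) to match LiBai's [in, out] layout.
    libai_state["encoders.0.self_attention.query_key_value.weight"] = qkv.permute(1, 0)
    libai_state["encoders.0.self_attention.dense.weight"] = hf_state[
        "bert.encoder.layer.0.attention.output.dense.weight"
    ].permute(1, 0)
    return libai_state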

Differences in the internal computation between LiBai's Bert and huggingface's Bert that cause the outputs to diverge:

  • Two lines in LiBai's MultiheadAttention make this part's output impossible to align with huggingface; the two computation orders below produce different q, k, v:
# Original code:
query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size)
query_key_value = query_key_value.permute(0, 2, 1, 3)
query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1)

# After my modification, the resulting q, k, v align with huggingface's:
query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1)
query = query.view(query.size(0), query.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
key = key.view(key.size(0), key.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
value = value.view(value.size(0), value.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
  • Also, parts of the internal computation of LiBai's TransformerLayer differ from huggingface's, which likewise prevents LiBai's output from aligning with huggingface:
# This difference makes all subsequent outputs inconsistent; e.g. the MLP layer then receives a different input
# Original code:
# https://github.com/Oneflow-Inc/libai/blob/main/libai/layers/transformer_layer.py#L176
hidden_states = hidden_states + attention_output

# My modification:
hidden_states = layernorm_output + attention_output

That is, LiBai's hidden_states is computed by adding attention_output (the result of the self-attention layer) to the input of the TransformerLayer (Bert has 12 TransformerLayers; the first TransformerLayer's input is the output of the Embedding layer). In huggingface, however, hidden_states is computed by adding attention_output to the TransformerLayer input after it has passed through a LayerNorm. In other words, in LiBai the input is added into the residual without going through LayerNorm, which looks unreasonable to me (see the toy sketch after this list).
  • The last problem, also in LiBai's TransformerLayer, is another computation difference that makes the output inconsistent:
# Original code:
# https://github.com/Oneflow-Inc/libai/blob/main/libai/layers/transformer_layer.py#L200
output = hidden_states + mlp_output

# My modification:
output = layernorm_output + mlp_output

That is, in the original LiBai code the final output of the TransformerLayer is the sum of mlp_output and hidden_states, whereas to match huggingface the residual here must be computed with layernorm_output.
  • After fixing the issues above, set bias_gelu_fusion, bias_dropout_fusion and apply_query_key_layer_scaling in LiBai's Bert to False. I then wrote a function that loads the huggingface pretrained model; after loading, LiBai's Bert with huggingface weights produces the same output as huggingface's Bert (using the same sentence as input).
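
To make the residual ordering concrete, here is a toy sketch of a single attention sub-block (plain PyTorch; the module and function names are hypothetical and only illustrate the two lines being compared, not the actual LiBai or huggingface code):

import torch
import torch.nn as nn

hidden_size = 768
ln = nn.LayerNorm(hidden_size)
attn = nn.Linear(hidden_size, hidden_size)   # stand-in for the whole self-attention sub-layer

def libai_original(hidden_states):
    # LiBai / Megatron style: normalize the input, attend, then add the *un-normalized* input back
    layernorm_output = ln(hidden_states)
    attention_output = attn(layernorm_output)
    return hidden_states + attention_output        # original transformer_layer.py line

def huggingface_aligned(hidden_states):
    # modification proposed in this issue: the residual adds the normalized input instead,
    # which reproduces huggingface numerics once huggingface's LayerNorms are loaded into
    # the following LiBai layer's input_layernorm
    layernorm_output = ln(hidden_states)
    attention_output = attn(layernorm_output)
    return layernorm_output + attention_output     # modified line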

First, the Bert parameter structure in LiBai:

embeddings.vocab_embeddings.weight oneflow.Size([30522, 768])
embeddings.position_embeddings.weight oneflow.Size([512, 768])
embeddings.tokentype_embeddings.weight oneflow.Size([2, 768])

encoders.0.input_layernorm.weight oneflow.Size([768])
encoders.0.input_layernorm.bias oneflow.Size([768])

encoders.0.self_attention.query_key_value.weight oneflow.Size([768, 2304])
encoders.0.self_attention.query_key_value.bias oneflow.Size([2304])
encoders.0.self_attention.dense.weight oneflow.Size([768, 768])
encoders.0.self_attention.dense.bias oneflow.Size([768])

encoders.0.post_attention_layernorm.weight oneflow.Size([768])
encoders.0.post_attention_layernorm.bias oneflow.Size([768])

encoders.0.mlp.dense_h_to_4h.weight oneflow.Size([768, 3072])
encoders.0.mlp.dense_h_to_4h.bias oneflow.Size([3072])


encoders.0.mlp.dense_4h_to_h.weight oneflow.Size([3072, 768])
encoders.0.mlp.dense_4h_to_h.bias oneflow.Size([768])

encoders.1.input_layernorm.weight oneflow.Size([768])
encoders.1.input_layernorm.bias oneflow.Size([768])

encoders.1.self_attention.query_key_value.weight oneflow.Size([768, 2304])
encoders.1.self_attention.query_key_value.bias oneflow.Size([2304])
encoders.1.self_attention.dense.weight oneflow.Size([768, 768])
encoders.1.self_attention.dense.bias oneflow.Size([768])
encoders.1.post_attention_layernorm.weight oneflow.Size([768])
encoders.1.post_attention_layernorm.bias oneflow.Size([768])
encoders.1.mlp.dense_h_to_4h.weight oneflow.Size([768, 3072])
encoders.1.mlp.dense_h_to_4h.bias oneflow.Size([3072])
encoders.1.mlp.dense_4h_to_h.weight oneflow.Size([3072, 768])
encoders.1.mlp.dense_4h_to_h.bias oneflow.Size([768])

final_layernorm.weight oneflow.Size([768])
final_layernorm.bias oneflow.Size([768])
pooler.dense.weight oneflow.Size([768, 768])
pooler.dense.bias oneflow.Size([768])

Now the huggingface parameter structure:

bert.embeddings.word_embeddings.weight torch.Size([30522, 768])
bert.embeddings.position_embeddings.weight torch.Size([512, 768])
bert.embeddings.token_type_embeddings.weight torch.Size([2, 768])
bert.embeddings.LayerNorm.gamma torch.Size([768])
bert.embeddings.LayerNorm.beta torch.Size([768])

bert.encoder.layer.0.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.query.bias torch.Size([768])
bert.encoder.layer.0.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.key.bias torch.Size([768])
bert.encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.value.bias torch.Size([768])
bert.encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.gamma torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.beta torch.Size([768])

bert.encoder.layer.0.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.0.intermediate.dense.bias torch.Size([3072])


bert.encoder.layer.0.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.0.output.dense.bias torch.Size([768])
bert.encoder.layer.0.output.LayerNorm.gamma torch.Size([768])
bert.encoder.layer.0.output.LayerNorm.beta torch.Size([768])

bert.encoder.layer.1.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.query.bias torch.Size([768])
bert.encoder.layer.1.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.key.bias torch.Size([768])
bert.encoder.layer.1.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.value.bias torch.Size([768])
bert.encoder.layer.1.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.1.attention.output.LayerNorm.gamma torch.Size([768])
bert.encoder.layer.1.attention.output.LayerNorm.beta torch.Size([768])
bert.encoder.layer.1.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.1.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.1.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.1.output.dense.bias torch.Size([768])
bert.encoder.layer.1.output.LayerNorm.gamma torch.Size([768])
bert.encoder.layer.1.output.LayerNorm.beta torch.Size([768])

bert.pooler.dense.weight torch.Size([768, 768])
bert.pooler.dense.bias torch.Size([768])

L1aoXingyu (Collaborator) commented

A quick reply on the layernorm placement question:

We followed the Megatron implementation. On the placement of the residual connection, the Megatron-LM paper says:

We further investigated this behavior and empirically demonstrated that rearranging the order of the layer normalization and the residual connections as shown in Figure 7 is critical to enable the scaling of the BERT-style models beyond BERT-Large. The architecture (b) in Figure 7 eliminates instabilities observed using the original BERT architecture in (a) and also has a lower training loss.

[Figure 7 from the Megatron-LM paper: (a) the original BERT architecture, (b) the rearranged layer normalization and residual connections]

So the ordering inside LiBai's TransformerLayer is different from the original Bert. @xiezipeng-ML

xiezipeng-ML (Contributor, Author) commented


Right, the LayerNorm placement is indeed fine; it computes correctly.

L1aoXingyu (Collaborator) commented

On the qkv computation: we also followed the Megatron code linked below, except that LiBai does not move the sequence dimension to the front. The corresponding flows are:

libai:
                     view                  permute                split
[b, sq, (np*3*hn)] --> [b, sq, np, 3*hn] --> [b, np, sq, 3*hn] --> [b, np, sq, hn]

huggingface:
                     split                view               permute
[b, sq, (np*3*hn)] --> [b, sq, (np*hn)] --> [b, sq, np, hn] --> [b, np, sq, hn]

https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/megatron/model/transformer.py#L221-L232
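
A minimal shape trace of the two orderings above (plain PyTorch, toy sizes; illustrative only, not the library code):

import torch

b, sq, np_, hn = 2, 5, 12, 64
x = torch.randn(b, sq, np_ * 3 * hn)

# libai / megatron ordering: view -> permute -> split
t = x.view(b, sq, np_, 3 * hn).permute(0, 2, 1, 3)          # [b, np, sq, 3*hn]
q1, k1, v1 = torch.chunk(t, chunks=3, dim=-1)               # each [b, np, sq, hn]

# huggingface ordering: split -> view -> permute
q2, k2, v2 = [s.view(b, sq, np_, hn).permute(0, 2, 1, 3)    # each [b, np, sq, hn]
              for s in torch.chunk(x, chunks=3, dim=-1)]

print(q1.shape, q2.shape)   # same shapes, but different values: the two orderings
                            # assume different layouts of the 3*np*hn projection axis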

rentainhe (Contributor) commented

This is a really good issue; it could be organized into a standalone FAQ section, or into the Advanced Tutorials.

xiezipeng-ML (Contributor, Author) commented


Wrote up some notes on how the two qkv computation methods lead to different sbp signatures. @CPFLAME

L1aoXingyu (Collaborator) commented

We worked through the derivation and found that aligning with the huggingface style breaks the sbp we derived earlier: huggingface does the chunk first, and the chunk dimension is exactly the sbp.split dimension, so the split implies an extra hidden communication. We think this could introduce further problems, so could you try the approach Kaijie proposed earlier?

#146 (comment)

xiezipeng-ML (Contributor, Author) commented

OK, Xingyu. I'll look at how to load the weights correctly so that the outputs align.

xiezipeng-ML (Contributor, Author) commented Mar 27, 2022

Without changing the model, changing how the weights are loaded can produce the same result.

Since we have already established that LiBai's qkv computation is correct (switching to huggingface's qkv computation breaks the sbp under model parallelism; the only workaround found so far is a direct to_global, and it is unclear whether that causes other problems; in other words, the sbp scheme of the whole LiBai model is configured as a set, and changing the computation order creates problems), the approach here is to change how the weights are loaded instead.

The two qkv computation orders:

# qkv computation in LiBai:
# query_key_value:[batch_size, seq_len, 3*hidden_size]
query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size)  #(a) 
query_key_value = query_key_value.permute(0, 2, 1, 3)                                #(b)
query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1)                    #(c)

# qkv computation in huggingface:
query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1)
query = query.view(query.size(0), query.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
key = key.view(key.size(0), key.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
value = value.view(value.size(0), value.size(1), self.num_heads, -1).permute(0, 2, 1, 3)

First, why LiBai's MultiheadAttention qkv computation cannot produce the same result after loading the huggingface weights:

  • Suppose query, key and value are each 768-dimensional. Because the weights are loaded by concatenating huggingface's query, key and value, the last dimension (3*hidden_size) of the query_key_value tensor in the LiBai snippet above is effectively laid out as [q_value, k_value, v_value]. Line #(a) views this last dimension [3*hidden_size] as [self.num_heads, 3 * self.head_size], which interleaves the elements of q_value, k_value and v_value, so the chunk over the last dimension in #(c) can no longer recover the correct query, key and value (a quick check follows the snippet below).
# q, k, v overlap after the view
flow.arange(1, 2305).view(12, 3*64)
tensor([[   1,    2,    3,  ...,  190,  191,  192],
        [ 193,  194,  195,  ...,  382,  383,  384],
        [ 385,  386,  387,  ...,  574,  575,  576],
        ...,
        [1729, 1730, 1731,  ..., 1918, 1919, 1920],
        [1921, 1922, 1923,  ..., 2110, 2111, 2112],
        [2113, 2114, 2115,  ..., 2302, 2303, 2304]])
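
As a quick check of the statement above: chunking the viewed tensor along its last dimension takes the first 64 columns of every row, so the slice LiBai would treat as q gathers indices from what huggingface lays out as the q, k and v regions (1-768, 769-1536, 1537-2304):

import oneflow as flow

t = flow.arange(1, 2305).view(12, 3 * 64)            # same toy tensor as above
q_blk, k_blk, v_blk = flow.chunk(t, chunks=3, dim=-1)
print(q_blk[:, 0])   # 1, 193, 385, 577, 769, 961, 1153, 1345, 1537, 1729, 1921, 2113
                     # entries above 768 already come from the k and v regions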

Idea for a fix:

  • So, without changing the qkv computation, load the q, k, v weights in a different arrangement to fix the interleaving of q_value, k_value and v_value caused by the view in #(a); that is, re-order the rows of the q, k, v weights so that after #(a) the chunk still yields the correct q_value, k_value and v_value.
  • The code below is my solution (it works in the case without bias). weight is the concatenated qkv_weight with shape=[hidden_size*3, hidden_size]; the first dimension is hidden_size*3.
import torch
import torch.nn.functional as F

bsz = 32
seq_len = 5
num_heads = 12
head_size = 64
hidden_size = num_heads*head_size

x = torch.rand(bsz, seq_len, hidden_size)
weight = torch.rand(hidden_size*3, hidden_size)
# bias = torch.rand(2304)

# my method for weight------------------------------
weight1 = weight.view([num_heads, 3, head_size, hidden_size])
weight_q, weight_k, weight_v = weight1.chunk(chunks=3, dim=0)            # [4, 3, head_size, hidden_size]
weight_q = weight_q.view(-1, head_size, hidden_size)                     # [12, head_size, hidden_size]
weight_k = weight_k.view(-1, head_size, hidden_size)
weight_v = weight_v.view(-1, head_size, hidden_size)

weight_q = weight_q.unsqueeze(1)
weight_k = weight_k.unsqueeze(1)
weight_v = weight_v.unsqueeze(1)

weight1 = torch.cat([weight_q, weight_k, weight_v], dim=1)     # [num_heads, 3, head_size, hidden_size]
weight1 = weight1.view(-1, hidden_size)
# my method for weight end-----------------------------------------------------

weight2 = weight
qkv1 = F.linear(x, weight1, bias=None)
qkv2 = F.linear(x, weight2, bias=None)

# libai------------------------------------------
qkv1 = qkv1.view(bsz, seq_len, num_heads, 3*head_size)
qkv1 = qkv1.permute(0, 2, 1, 3)
q1, k1, v1 = torch.chunk(qkv1, chunks=3, dim=-1)

# huggingface------------------------------------------
q2, k2, v2 = torch.chunk(qkv2, chunks=3, dim=-1)
q2 = q2.view(q2.size(0), q2.size(1), num_heads, -1).transpose(1,2)
k2 = k2.view(k2.size(0), k2.size(1), num_heads, -1).transpose(1,2)
v2 = v2.view(v2.size(0), v2.size(1), num_heads, -1).transpose(1,2)

print((q1==q2).all())     # tensor(True)
print((k1==k2).all())     # tensor(True)
print((v1==v2).all())     # tensor(True)
  • But the qkv_bias seems impossible to handle this way: qkv_bias is a [2304] vector, and once the rows of qkv_weight have been re-ordered, the bias entries no longer line up with the corresponding positions, so.... I've thought about this for a long time; does anyone have an idea?

rentainhe (Contributor) commented Mar 28, 2022

Bias solution

  • The bias is just the weight with the trailing hidden_size dimension dropped, so it only needs the same re-ordering as the weight. See the code below.
import torch
import torch.nn.functional as F
import pdb

bsz = 32
seq_len = 5
num_heads = 12
head_size = 64
hidden_size = num_heads*head_size

x = torch.rand(bsz, seq_len, hidden_size)
weight = torch.rand(hidden_size*3, hidden_size)
bias = torch.rand(2304)

# my method for weight------------------------------
weight1 = weight.view([num_heads, 3, head_size, hidden_size])
weight_temp = weight1

weight_q, weight_k, weight_v = weight1.chunk(chunks=3, dim=0)            # [4, 3, head_size, hidden_size]
weight_q = weight_q.view(-1, head_size, hidden_size)                     # [12, head_size, hidden_size]
weight_k = weight_k.view(-1, head_size, hidden_size)
weight_v = weight_v.view(-1, head_size, hidden_size)


weight_q = weight_q.unsqueeze(1)
weight_k = weight_k.unsqueeze(1)
weight_v = weight_v.unsqueeze(1)

weight1 = torch.cat([weight_q, weight_k, weight_v], dim=1)     # [num_heads, 3, head_size, hidden_size]
weight1 = weight1.view(-1, hidden_size)
# my method for weight end-----------------------------------------------------

weight2 = weight

# --------------convert bias-------------------------------
bias_ = bias.view(num_heads, 3, head_size)
bias_q, bias_k, bias_v = bias_.chunk(3, dim=0)
bias_q = bias_q.view(-1, head_size).unsqueeze(1)
bias_k = bias_k.view(-1, head_size).unsqueeze(1)
bias_v = bias_v.view(-1, head_size).unsqueeze(1)
bias1 = torch.cat([bias_q, bias_k, bias_v], dim=1).view(-1)
# -----------------------------------------------------------


qkv1 = F.linear(x, weight1, bias=bias1)  # weight1: [2304, 768]
qkv2 = F.linear(x, weight2, bias=bias)
# pdb.set_trace()

# libai------------------------------------------
qkv1 = qkv1.view(bsz, seq_len, num_heads, 3*head_size)
qkv1 = qkv1.permute(0, 2, 1, 3)
q1, k1, v1 = torch.chunk(qkv1, chunks=3, dim=-1)

# huggingface------------------------------------------
q2, k2, v2 = torch.chunk(qkv2, chunks=3, dim=-1)
q2 = q2.view(q2.size(0), q2.size(1), num_heads, -1).transpose(1,2)
k2 = k2.view(k2.size(0), k2.size(1), num_heads, -1).transpose(1,2)
v2 = v2.view(v2.size(0), v2.size(1), num_heads, -1).transpose(1,2)

print((q1==q2).all())     # tensor(True)
print((k1==k2).all())     # tensor(True)
print((v1==v2).all())     # tensor(True)

@xiezipeng-ML

Cleaned-up code:

import torch
import torch.nn.functional as F

bsz = 32
seq_len = 5
num_heads = 12
head_size = 64
hidden_size = num_heads*head_size

x = torch.rand(bsz, seq_len, hidden_size)
weight = torch.rand(hidden_size*3, hidden_size)
bias = torch.rand(2304)

# convert weight and bias
weight1 = weight.view([num_heads, 3, head_size, hidden_size])
weight_q, weight_k, weight_v = weight1.chunk(chunks=3, dim=0)
weight_q = weight_q.view(-1, head_size, hidden_size).unsqueeze(1)
weight_k = weight_k.view(-1, head_size, hidden_size).unsqueeze(1)
weight_v = weight_v.view(-1, head_size, hidden_size).unsqueeze(1)
weight1 = torch.cat([weight_q, weight_k, weight_v], dim=1).view(-1, hidden_size)

bias_ = bias.view(num_heads, 3, head_size)
bias_q, bias_k, bias_v = bias_.chunk(3, dim=0)
bias_q = bias_q.view(-1, head_size).unsqueeze(1)
bias_k = bias_k.view(-1, head_size).unsqueeze(1)
bias_v = bias_v.view(-1, head_size).unsqueeze(1)
bias1 = torch.cat([bias_q, bias_k, bias_v], dim=1).view(-1)


weight2 = weight
bias2 = bias


qkv1 = F.linear(x, weight1, bias=bias1)
qkv2 = F.linear(x, weight2, bias=bias2)


# libai------------------------------------------
qkv1 = qkv1.view(bsz, seq_len, num_heads, 3*head_size)
qkv1 = qkv1.permute(0, 2, 1, 3)
q1, k1, v1 = torch.chunk(qkv1, chunks=3, dim=-1)

# huggingface------------------------------------------
q2, k2, v2 = torch.chunk(qkv2, chunks=3, dim=-1)
q2 = q2.view(q2.size(0), q2.size(1), num_heads, -1).transpose(1,2)
k2 = k2.view(k2.size(0), k2.size(1), num_heads, -1).transpose(1,2)
v2 = v2.view(v2.size(0), v2.size(1), num_heads, -1).transpose(1,2)

print((q1==q2).all())     # tensor(True)
print((k1==k2).all())     # tensor(True)
print((v1==v2).all())     # tensor(True)

xiezipeng-ML (Contributor, Author) commented

After Bert's load_pretrain_weight, the outputs are aligned:

import oneflow as flow
import libai
from libai.models import build_model
from libai.config import LazyCall
from load_huggingface_weight import load_huggingface_bert
from libai.utils import distributed as dist
import transformers
import torch
import numpy as np


input_ids = [[101, 1962, 2110, 739, 999, 1, 2, 3, 102]]
mask = [[1] * len(input_ids[0])]  # one attention-mask entry per token

# libai result
cfg = dict(
    vocab_size=21128,
    hidden_size=768,
    hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    num_tokentypes=2,
    add_pooling_layer=True,
    initializer_range=0.02,
    layernorm_eps=1e-12,
    bias_gelu_fusion=False,               # must be False to match huggingface
    bias_dropout_fusion=False,            # must be False to match huggingface
    scale_mask_softmax_fusion=False,
    apply_query_key_layer_scaling=False,  # must be False to match huggingface
    add_binary_head=True,
    amp_enabled=False,
    apply_residual_post_layernorm=True
)
bert_lib = build_model(LazyCall(libai.models.BertModel)(cfg=cfg))
load_huggingface_bert(bert_lib, './pretrain/pytorch_model.bin', cfg['hidden_size'], cfg['num_attention_heads'])
input_of = flow.tensor(input_ids, dtype=flow.long, sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast]), placement=flow.placement("cuda" if flow.cuda.is_available() else "cpu", [0]),)
mask_of = flow.tensor(mask, dtype=flow.long, sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast]), placement=flow.placement("cuda" if flow.cuda.is_available() else "cpu", [0]),)
bert_lib.eval()
last_hidden_state_of, pooler_output_of = bert_lib(input_of, mask_of)


# huggingface result
bert_hug = transformers.BertModel.from_pretrained('./pretrain')
bert_hug.eval()
input_pt = torch.tensor(input_ids)
mask_pt = torch.tensor(mask)
last_hidden_state_pt = bert_hug(input_pt, mask_pt).last_hidden_state 


res1 = last_hidden_state_of.detach().numpy()
res2 = last_hidden_state_pt.detach().numpy()
print(res1.sum())
print(res2.sum())
