transformers 为了适应非常多种模型结构,结构变得非常复杂。
在参考 transformers 、 bert4pytorch 、 Read_Bert_Code 的代码基础上,对结构进行了一些调整,提高了代码的易读性,并和 transformers 的结果完全一致。
关于 lfs,可以参考 git lfs 。
模型名称 | git clone | 自行下载 |
---|---|---|
bert-base-chinese | git clone git@e.coding.net:mmmwhy/file/bert-base-chinese.git |
coding 链接 |
chinese-roberta-wwm-ext | git clone git@e.coding.net:mmmwhy/file/chinese-roberta-wwm-ext.git |
coding 链接 |
chinese-roberta-wwm-ext-large | git lfs clone git@e.coding.net:mmmwhy/file/chinese-roberta-wwm-ext-large.git |
coding 链接 |
ernie 1.0 | git clone git@e.coding.net:mmmwhy/file/ernie-1.0.git |
coding 链接 |
速度还是比较可观的,
1、安装本仓库 pip install pure_attention==0.0.24
2、下载预训练模型
3、开始使用
from pure_attention.common.nlp_tokenization import Tokenizer
from pure_attention.backbone_bert.bert_model import BertModel
bert_model_path = "/data/pretrain_modal/bert-base-chinese"
test_query = "结果一致性验证"
tokenizer = Tokenizer(bert_model_path + "/vocab.txt")
bert = BertModel(bert_model_path)
tokenizer_output = tokenizer.encode(test_query, max_len=64)
our_bert_pooler_output = bert(
input_ids=tokenizer_output.input_ids,
token_type_ids=tokenizer_output.token_type_ids,
attention_mask=tokenizer_output.attention_mask).pooler_output
bert_last_hidden_state = bert(
input_ids=tokenizer_output.input_ids,
token_type_ids=tokenizer_output.token_type_ids,
attention_mask=tokenizer_output.attention_mask).last_hidden_state
分别在下边四个常用中文 bert 上进行测试,结果与 transformers 完全一致。
校验代码,截图时的代码可能比较老,以新代码为准。
import torch
from transformers import BertModel
from transformers import BertTokenizer
bert_model_path = "/data/pretrain_modal/chinese-roberta-wwm-ext-large"
test_query = "结果一致性验证"
text_tokenizer = BertTokenizer.from_pretrained(bert_model_path, do_lower_case=True)
bert_model = BertModel.from_pretrained(bert_model_path)
tensor_caption = text_tokenizer(test_query, return_tensors="pt", padding='max_length', truncation=True,
max_length=64)
origin_bert_pooler_output = bert_model(
input_ids=tensor_caption.input_ids,
attention_mask=tensor_caption.attention_mask,
token_type_ids=tensor_caption.token_type_ids).pooler_output
# 我们简化重构后的代码
from pure_attention.common.nlp_tokenization import Tokenizer as LocalTokenizer
from pure_attention.backbone_bert.bert_model import BertModel as OurBertModel
tokenizer = LocalTokenizer(bert_model_path + "/vocab.txt")
bert = OurBertModel(bert_model_path)
tokenizer_output = tokenizer.encode(test_query, max_len=64)
our_bert_pooler_output = bert(
input_ids=tokenizer_output.input_ids,
token_type_ids=tokenizer_output.token_type_ids,
attention_mask=tokenizer_output.attention_mask).pooler_output
print("check result:", torch.cosine_similarity(origin_bert_pooler_output, our_bert_pooler_output))