Commit 77fc3b9

feat: fix tokenizer, add average pooling to bert
1 parent 661b79f commit 77fc3b9

6 files changed (+109, -7 lines)

examples/demo_bert.cpp

Lines changed: 13 additions & 4 deletions
@@ -8,16 +8,25 @@
 string vocab_file = "./vocab/gte_vocab.mllm";
 string model_file = "./models/gte-small-fp32.mllm";
 
+/*
+ * intended to support the gte-small BertModel for text embedding
+ * the current implementation is only a very basic example with a simple WordPiece tokenizer and a simple BertModel
+ * batch embedding is not supported
+ * */
+
 int main(int argc, char *argv[]) {
-    BertTokenizer tokenizer(vocab_file, false);
-    string text = "Hello, my dog is cute.";
+    BertTokenizer tokenizer(vocab_file, true);
+    string text = "Help me set an alarm at 21:30";
     auto [token_ids, type_ids, position_ids] = tokenizer.process(text);
     // token_ids.printData<float>();
 
     auto config = BertConfig();
     auto model = BertModel(config);
     model.load(model_file);
 
-    auto res = model({token_ids, type_ids, position_ids});
-    res[0].printData<float>();
+    auto res = model({token_ids, type_ids, position_ids})[0];
+
+    res.printData<float>();
+
+    return 0;
 }

src/models/bert/configuration_bert.hpp

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,7 @@ struct BertConfig : public TransformerConfig {
     bos_token_id = 151643;
     eos_token_id = 151645;
     hidden_act = "GELU";
+    pooling_type = "mean";
     hidden_size = 384;
     initializer_range = 0.02;
     intermediate_size = 1536;
@@ -73,6 +74,7 @@ struct BertConfig : public TransformerConfig {
     int bos_token_id = 151643;
     int eos_token_id = 151643;
     std::string hidden_act = "GELU";
+    std::string pooling_type = "mean";
     int hidden_size = 1024;
     float initializer_range = 0.02;
     int intermediate_size = 2816;

src/models/bert/modeling_bert.hpp

Lines changed: 23 additions & 0 deletions
@@ -74,11 +74,30 @@ class BertLayer : public Module {
     Layer attn_norm, ff_norm;
 };
 
+
+class AvgPooler : public Module {
+public:
+    AvgPooler() = default;
+    std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
+        auto x = inputs[0];
+        x = x.mean(SEQUENCE);
+        return {x};
+    }
+};
+
 class BertModel : public Module {
 public:
     BertModel(BertConfig &config) {
         embeddings = BertEmbeddings(config.vocab_size, config.hidden_size, config.type_vocab_size, config.max_position_embeddings, config.layer_norm_eps, config.names_config);
         layers = List<BertLayer>(config.num_hidden_layers, config, "encoder.layer.");
+
+        if (config.pooling_type == "mean") {
+            pooler = make_unique<AvgPooler>();
+        } else {
+            // unsupported pooling type: report and exit
+            std::cout << "Unsupported pooling type: " << config.pooling_type << std::endl;
+            exit(0);
+        }
     }
 
     std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
@@ -88,12 +107,16 @@ class BertModel : public Module {
             x = layer({x})[0];
         }
 
+        x = (*pooler)({x})[0];
+
         return {x};
     }
 
 private:
     BertEmbeddings embeddings;
     std::vector<BertLayer> layers;
+
+    unique_ptr<Module> pooler;
 };
 
 #endif //! MODELING_BERT_HPP
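
For reference, a minimal standalone sketch of the reduction AvgPooler performs: averaging the hidden states over the sequence dimension to get one fixed-size sentence embedding. The layout and names below are illustrative and independent of mllm's Tensor API.

#include <vector>

// Illustrative only: mean pooling over the sequence dimension,
// with hidden states laid out as [seq_len][hidden_size].
std::vector<float> mean_pool(const std::vector<std::vector<float>>& hidden_states) {
    std::vector<float> pooled(hidden_states.empty() ? 0 : hidden_states[0].size(), 0.0f);
    for (const auto& token_state : hidden_states) {
        for (size_t i = 0; i < pooled.size(); ++i) {
            pooled[i] += token_state[i];
        }
    }
    for (auto& v : pooled) {
        v /= static_cast<float>(hidden_states.size());
    }
    return pooled;
}

With pooling_type == "mean", the demo's output should therefore be a single hidden_size-length embedding for the whole input rather than one hidden state per token.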

src/models/bert/tokenization_bert.hpp

Lines changed: 15 additions & 2 deletions
@@ -16,13 +16,23 @@ using namespace mllm;
 
 class BertTokenizer final : public WordPieceTokenizer {
 public:
-    explicit BertTokenizer(const std::string &vocab_file, bool bos = true) :
+    explicit BertTokenizer(const std::string &vocab_file, bool add_special_tokens = true) :
         WordPieceTokenizer(vocab_file) {
         Module::initBackend(MLLM_CPU);
+        _add_special_tokens = add_special_tokens;
+        this->add_special_tokens({"[PAD]", "[CLS]", "[SEP]", "[MASK]"});
     }
-    std::tuple<Tensor, Tensor, Tensor> process(std::string &text) {
+    std::tuple<Tensor, Tensor, Tensor> process(std::string text) {
+        if (_add_special_tokens) {
+            text = "[CLS] " + text + " [SEP]";
+        }
         auto tokens_id = vector<token_id_t>();
         WordPieceTokenizer::tokenize(text, tokens_id, false);
+        // printf("token: ");
+        // for (auto &token_id : tokens_id) {
+        //     printf("%d ", token_id);
+        // }
+        // printf("\n");
         auto tokens_type = vector<token_id_t>(tokens_id.size(), 0);
         auto position_ids = vector<token_id_t>(tokens_id.size());
         for (size_t i = 0; i < tokens_id.size(); i++) {
@@ -34,6 +44,9 @@ class BertTokenizer final : public WordPieceTokenizer {
             tokens2Input(position_ids, "input_position_ids")
         };
     }
+
+private:
+    bool _add_special_tokens;
 };
 
 #endif //! TOKENIZATION_BERT_HPP
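
A minimal usage sketch of the updated tokenizer, mirroring examples/demo_bert.cpp; the include path and vocab path are assumptions for illustration only.

#include "models/bert/tokenization_bert.hpp"  // include path is an assumption

using namespace mllm;

int main() {
    // with add_special_tokens = true, process() wraps the input as
    // "[CLS] Help me set an alarm at 21:30 [SEP]" before WordPiece tokenization
    BertTokenizer tokenizer("./vocab/gte_vocab.mllm", /*add_special_tokens=*/true);
    auto [token_ids, type_ids, position_ids] = tokenizer.process("Help me set an alarm at 21:30");
    return 0;
}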

src/tokenizers/WordPiece/WordPiece.cpp

Lines changed: 52 additions & 1 deletion
@@ -104,11 +104,52 @@ bool mllm::BasicTokenizer::is_chinese_char(wchar_t cp) {
            (cp >= 0x20000 && cp <= 0x2A6DF);
 }
 
+std::vector<std::wstring> splitBySet(const std::wstring& text, const std::unordered_set<std::wstring>& words) {
+    std::vector<std::wstring> result;
+    size_t pos = 0;
+
+    while (pos < text.length()) {
+        size_t minPos = std::wstring::npos;
+        std::wstring foundWord;
+
+        // find the earliest match among the given words
+        for (const auto& word : words) {
+            size_t found = text.find(word, pos);
+            if (found != std::wstring::npos && (found < minPos)) {
+                minPos = found;
+                foundWord = word;
+            }
+        }
+
+        // if a match was found, emit the text before it and the match itself
+        if (minPos != std::wstring::npos) {
+            if (minPos > pos) {
+                // add the text before the match
+                result.push_back(text.substr(pos, minPos - pos));
+            }
+            // add the match
+            result.push_back(foundWord);
+            pos = minPos + foundWord.size();
+        } else {
+            // no more matches; add all remaining text
+            result.push_back(text.substr(pos));
+            break;
+        }
+    }
+
+    return result;
+}
+
 std::vector<std::wstring> mllm::BasicTokenizer::tokenize(const std::wstring& text) {
     std::wstring cleaned = clean_text(text);
     if (_tokenize_chinese_chars)
         cleaned = tokenize_chinese_chars(cleaned);
-    std::vector<std::wstring> split_tokens = whitespace_tokenize(cleaned);
+    std::vector<std::wstring> white_space_splited_tokens = whitespace_tokenize(cleaned);
+    std::vector<std::wstring> split_tokens;
+    for (const auto& token : white_space_splited_tokens) {
+        auto sub_tokens = splitBySet(token, never_split);
+        split_tokens.insert(split_tokens.end(), sub_tokens.begin(), sub_tokens.end());
+    }
     std::vector<std::wstring> output_tokens;
 
     for (auto& token : split_tokens) {
@@ -138,6 +179,9 @@ std::vector<std::wstring> mllm::BasicTokenizer::tokenize(const std::wstring& text) {
 
     return output_tokens;
 }
+void mllm::BasicTokenizer::add_never_split(const std::wstring &token) {
+    never_split.insert(token);
+}
 
 void mllm::WordPieceTokenizer::tokenize(const string &text, vector<token_id_t> &tokens, bool bos) {
     auto wstr = utf8_to_wstring(text);
@@ -167,6 +211,7 @@ void mllm::WordPieceTokenizer::tokenize(const string &text, vector<token_id_t> &tokens, bool bos) {
                 break;
             } else {
                 token_strs.push_back(str);
+                // printf("word: %s\n", str.c_str());
             }
             start = end;
         }
@@ -177,3 +222,9 @@ void mllm::WordPieceTokenizer::tokenize(const string &text, vector<token_id_t> &tokens, bool bos) {
         tokens.push_back(vocab_map_[token_str]);
     }
 }
+void mllm::WordPieceTokenizer::add_special_tokens(const vector<std::string> &special_tokens) {
+    // add never-split tokens to the basic tokenizer
+    for (const auto& token : special_tokens) {
+        basic_tokenizer.add_never_split(utf8_to_wstring(token));
+    }
+}
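
To make the never-split behavior concrete, the self-contained sketch below mirrors the splitBySet logic added in this diff (split_by_set is a local copy with an illustrative name, since the original is file-local): a token that contains a special marker is cut around the marker, and the marker survives as one piece.

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// local mirror of the diff's splitBySet, for illustration only
static std::vector<std::wstring> split_by_set(const std::wstring& text,
                                              const std::unordered_set<std::wstring>& words) {
    std::vector<std::wstring> result;
    size_t pos = 0;
    while (pos < text.length()) {
        size_t minPos = std::wstring::npos;
        std::wstring foundWord;
        for (const auto& word : words) {
            size_t found = text.find(word, pos);
            if (found != std::wstring::npos && found < minPos) {
                minPos = found;
                foundWord = word;
            }
        }
        if (minPos != std::wstring::npos) {
            if (minPos > pos) result.push_back(text.substr(pos, minPos - pos));
            result.push_back(foundWord);
            pos = minPos + foundWord.size();
        } else {
            result.push_back(text.substr(pos));
            break;
        }
    }
    return result;
}

int main() {
    std::unordered_set<std::wstring> never_split = {L"[CLS]", L"[SEP]"};
    // expected pieces, in order: "[CLS]", "hello", "[SEP]"
    for (const auto& piece : split_by_set(L"[CLS]hello[SEP]", never_split)) {
        std::wcout << piece << L"\n";
    }
    return 0;
}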

src/tokenizers/WordPiece/WordPiece.hpp

Lines changed: 4 additions & 0 deletions
@@ -33,6 +33,8 @@ class BasicTokenizer {
 
     std::vector<std::wstring> tokenize(const std::wstring& text);
 
+    void add_never_split(const std::wstring& token);
+
 private:
     bool do_lower_case;
     bool _tokenize_chinese_chars;
@@ -55,6 +57,8 @@ class WordPieceTokenizer: public Tokenizer {
 
     WordPieceTokenizer(const std::string &vocab_file);
     void tokenize(const std::string &text, std::vector<token_id_t> &tokens, bool bos) override;
+
+    void add_special_tokens(const std::vector<std::string> &special_tokens);
 };
 
 }
