vocab.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import numpy as np
import copy
curdir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(curdir))
class Vocab(object):
"""
Vocabulary preprocessor replacement for tensorflow.contrib.learn.preprocessing.VocabularyProcessor
"""
def __init__(self, filename=None, initial_tokens=None, lower=False):
self.id2token = {}
self.token2id = {}
self.token_cnt = {}
self.pos2id = {'pad': 0}
self.lower = lower
self.embed_dim = None
self.embeddings = None
self.pad_token = '<blank>'
self.unk_token = '<unk>'
        self.initial_tokens = list(initial_tokens) if initial_tokens is not None else []
self.initial_tokens.extend([self.pad_token, self.unk_token])
for token in self.initial_tokens:
self.add(token)
if filename is not None:
self.load_from_file(filename)
def size(self):
"""
        Gets the size of the vocabulary.
        :return: int The number of tokens in the vocabulary.
"""
return len(self.id2token)
def load_from_file(self, file_path):
"""
Loads the vocab from file_path
:param file_path: str a file with a word in each line.
:return: None
"""
for line in open(file_path, 'r'):
token = line.strip('\n')
self.add(token)
def get_id(self, token):
"""
Gets the id of a token, returns the id of unk token if token id not in vocab.
:param token: str a string indicating the word
:return: Int An integer
"""
token = token.lower() if self.lower else token
try:
return self.token2id[token]
except KeyError as e:
# print('Unknown token {}'.format(token))
return self.token2id[self.unk_token]
    def get_tf(self, token):
        """
        Gets the stored count of a token; returns 1 if the token is not in the vocab.
        """
        token = token.lower() if self.lower else token
        try:
            return self.token_cnt[token]
        except KeyError:
            return 1
def get_token(self, idx):
"""
Gets the token corresponding to idx, returns unk token id idx is not in vocab
:param idx: int an integer
:return: token: str a token string
"""
try:
return self.id2token[idx]
except KeyError as e:
# print("Unknown index")
return self.unk_token
def add(self, token, cnt=1):
"""
Adds the token to vocab.
:param token: str
:param cnt: int a num indicating the count of the token to add, default is 1
:return: idx: int
"""
token = token.lower() if self.lower else token
if token in self.token2id:
# Token exist. Gets the idx straightforwardly.
idx = self.token2id[token]
else:
# Token doesn't exist. Push into the id2token and token2id dict.
idx = len(self.id2token)
self.id2token[idx] = token
self.token2id[token] = idx
if cnt > 0:
if token in self.token_cnt:
self.token_cnt[token] += cnt
else:
self.token_cnt[token] = cnt
return idx
    def add_pos2id(self, tag):
        """
        Adds a POS tag to the pos2id map if it is not already present.
        """
        if tag not in self.pos2id:
            self.pos2id[tag] = len(self.pos2id)
def filter_tokens_by_cnt(self, min_cnt):
"""
Filter the tokens in vocab by their count.
:param min_cnt: int tokens with frequency less than min_cnt is filtered
:return: None
"""
filtered_tokens = [token for token in self.token2id if self.token_cnt[token] >= min_cnt]
self.token2id = {}
self.id2token = {}
for token in self.initial_tokens:
self.add(token, cnt=0)
for token in filtered_tokens:
self.add(token, cnt=0)
def randomly_init_embeddings(self, embed_dim):
"""
Randomly initializes the embeddings for each token.
:param embed_dim: int the size of the embedding for each token.
:return: None
"""
self.embed_dim = embed_dim
self.embeddings = np.random.rand(self.size(), embed_dim)
        # The pad and unk tokens are assigned all-zero embeddings.
        for token in [self.pad_token, self.unk_token]:
self.embeddings[self.get_id(token)] = np.zeros([self.embed_dim])
def load_pretrained_embeddings(self, embedding_path):
"""
Load the pre-trained word embeddings from embedding_path.
Reconstructed the token2id and id2token dict. Tokens not in pre-trained embeddings will be filtered.
:param embedding_path: str
:return: None
"""
trained_embeddings = {}
with open(embedding_path, 'r') as fin:
while True:
line = fin.readline()
if not line:
print("Pre-trained embeddings load successfully!")
break
contents = line.strip().split(' ')
token = contents[0]
if token not in self.token2id:
continue
trained_embeddings[token] = list(map(float, contents[1:]))
if self.embed_dim is None:
self.embed_dim = len(contents) - 1
filtered_tokens = trained_embeddings.keys()
# rebuild the token x id map
self.token2id = {}
self.id2token = {}
for token in self.initial_tokens:
self.add(token, cnt=0)
for token in filtered_tokens:
self.add(token, cnt=0)
# load embeddings
self.embeddings = np.zeros([self.size(), self.embed_dim])
for token in self.token2id.keys():
if token in trained_embeddings:
self.embeddings[self.get_id(token)] = trained_embeddings[token]
def convert2ids(self, tokens):
"""
Convert a list of tokens to ids, use unk_token if the token is not in vocab.
:param tokens: list A list of tokens
:return: list A list of ids
"""
        vec = [self.get_id(token) for token in tokens]
return vec
    def convert2tfs(self, tokens):
        """
        Converts a list of tokens to their counts; uses 1 for tokens not in the vocab.
        """
        tfs = [self.get_tf(token) for token in tokens]
return tfs
def recover_from_ids(self, ids, stop_id=None):
"""
Convert a list of ids to tokens, stop converting if the stop_id is encountered.
:param ids: list A list of ids to convert.
:param stop_id: int The stop id, default is None.
        :return: list A list of tokens.
"""
tokens = []
for i in ids:
tokens.append(self.get_token(i))
if stop_id is not None and i == stop_id:
break
return tokens
def desensitization(self):
"""
Desensitization for the data set by removing the transformation between id and token.
"""
# id2token to id2id
for tmp_id in self.id2token.keys():
self.id2token[tmp_id] = tmp_id
# token cnt to id cnt
token_cnt_copy = copy.deepcopy(self.token_cnt)
for tmp_token in token_cnt_copy.keys():
if tmp_token in self.token2id:
token_id = self.token2id[tmp_token]
self.token_cnt[token_id] = self.token_cnt.pop(tmp_token)
else:
self.token_cnt.pop(tmp_token)
# token2id to id2id
token2id_copy = copy.deepcopy(self.token2id)
for tmp_token in token2id_copy.keys():
self.token2id[token2id_copy[tmp_token]] = self.token2id.pop(tmp_token)
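

if __name__ == '__main__':
    # Minimal usage sketch: build a small vocab from an illustrative corpus,
    # filter rare tokens, and convert a token sequence to ids. The corpus and
    # the embed_dim value below are made-up examples, not fixed by this module.
    vocab = Vocab(lower=True)
    corpus = [['the', 'cat', 'sat'], ['the', 'dog', 'sat'], ['a', 'cat']]
    for sentence in corpus:
        for word in sentence:
            vocab.add(word)
    # Drop tokens seen fewer than 2 times; the initial <blank>/<unk> tokens are kept.
    vocab.filter_tokens_by_cnt(min_cnt=2)
    print('vocab size:', vocab.size())
    # Tokens missing from the vocab map to the id of <unk>.
    print(vocab.convert2ids(['the', 'cat', 'unseen']))
    # Random embeddings; load_pretrained_embeddings(path) could be used instead,
    # given a file with one token per line followed by its space-separated vector.
    vocab.randomly_init_embeddings(embed_dim=8)
    print(vocab.embeddings.shape)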