-
Notifications
You must be signed in to change notification settings - Fork 4
/
reader.py
48 lines (46 loc) · 1.67 KB
/
reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import re
# Reserved token ids used by ``reader``: EOS_ID is appended to the end of
# every response sequence, and UNK_ID stands in for any word that is not in
# the vocabulary.
# NOTE(review): ids are vocabulary line numbers (0-based), so presumably the
# vocab file lists the EOS/UNK tokens on lines 1 and 2 — confirm.
EOS_ID, UNK_ID = 1, 2
class reader():
    """Stream batches of (post, response) token-id pairs from parallel files.

    The vocabulary file holds one token per line; a token's id is its 0-based
    line number (a token listed twice keeps the later id). The post and
    response files are read in lockstep, one example per line, with tokens
    separated by single spaces. All files are opened in binary mode, so tokens
    are byte strings. Written to run unchanged on Python 2 and Python 3.

    Call ``close()`` (or use the reader in a ``with`` statement) to release
    the post/response file handles.
    """

    def __init__(self, file_name_post, file_name_resp, file_name_word):
        """Load the vocabulary and open the post/response files.

        Args:
            file_name_post: path of the post file, one example per line.
            file_name_resp: path of the response file, aligned line-by-line
                with the post file.
            file_name_word: path of the vocabulary file, one token per line.
        """
        self.d = {}       # token -> id (its line number in the vocab file)
        self.symbol = []  # id -> token
        with open(file_name_word, 'rb') as file_word:
            for num, line in enumerate(file_word):
                # Strip only a trailing newline so a final line without one
                # keeps its last character.
                line = line.rstrip(b'\n')
                self.symbol.append(line)
                self.d[line] = num
        self.file_name_post = file_name_post
        self.file_name_resp = file_name_resp
        self._open_data()
        self.epoch = 0  # number of completed passes over the data
        self.k = 0      # examples handed out so far in the current epoch

    def _open_data(self):
        """Open the post and response files, closing post if resp fails."""
        self.post = open(self.file_name_post, 'rb')
        try:
            self.resp = open(self.file_name_resp, 'rb')
        except Exception:
            self.post.close()
            raise

    def _to_ids(self, line):
        """Map one line of space-separated tokens to vocabulary ids."""
        words = line.rstrip(b'\n').split(b' ')
        return [self.d.get(word, UNK_ID) for word in words]

    def _read_examples(self, batch_size):
        """Read ``batch_size`` examples, or return None if a file hits EOF first."""
        result = []
        for _ in range(batch_size):
            post = self.post.readline()
            resp = self.resp.readline()
            # End the pass when either file runs out, so a shorter response
            # file cannot pair posts with empty responses.
            if not post or not resp:
                return None
            result.append((self._to_ids(post), self._to_ids(resp) + [EOS_ID]))
        return result

    def get_batch(self, batch_size):
        """Return the next ``batch_size`` (post_ids, resp_ids) pairs.

        Unknown tokens map to UNK_ID and every response gets EOS_ID appended.
        When either file runs out before the batch is full, the partial batch
        is discarded, both files are rewound, ``epoch`` is incremented, ``k``
        is reset, and the batch is read from the start of the data.

        Raises:
            ValueError: if the data holds fewer than ``batch_size`` aligned
                lines, so no full batch can ever be read.
        """
        rewound = False
        while True:
            result = self._read_examples(batch_size)
            if result is not None:
                self.k += batch_size
                return result
            if rewound:
                raise ValueError(
                    'data holds fewer than %d aligned post/resp lines'
                    % batch_size)
            self.restore()
            self.epoch += 1
            self.k = 0
            # Same text as Python 2's ``print 'epoch: ', n`` (two spaces),
            # written so it runs on both Python 2 and Python 3.
            print('epoch:  %d' % self.epoch)
            rewound = True

    def restore(self):
        """Rewind both data files to their beginning by reopening them."""
        self.close()
        self._open_data()

    def close(self):
        """Close the post and response files (safe to call more than once)."""
        self.post.close()
        self.resp.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False