-
Notifications
You must be signed in to change notification settings - Fork 1
/
data_process.py
77 lines (75 loc) · 2.68 KB
/
data_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Data loading for the ESIM model: every input line is encoded into
# fixed-length token-id sequences (tokens -> ids, then truncated/padded).
# NOTE(review): the column layout (query text in field 1, title text in
# field 3, integer label in field 4) is inferred from the indices used here;
# fields 0 and 2 are ignored -- confirm against the data producer.


def _encode_and_pad(text, char_to_id, max_len):
    """Encode the whitespace-separated tokens of ``text`` as exactly ``max_len`` ids.

    Tokens absent from ``char_to_id`` map to 0. 0 is also the padding value,
    so unknown tokens and padding are indistinguishable downstream. Sequences
    longer than ``max_len`` keep only their first ``max_len`` ids.
    """
    # Membership test + indexing (rather than ``.get``) keeps this usable with
    # any mapping-like vocabulary that supports ``in`` and ``[]``.
    ids = [char_to_id[token] if token in char_to_id else 0
           for token in text.split()]
    ids = ids[:max_len]
    return ids + [0] * (max_len - len(ids))


def _iter_esim_examples(lines, char_to_id, q_max_len, t_max_len):
    """Yield ``(query_ids, title_ids, label)`` for every non-blank line in ``lines``.

    Each line is stripped and split on ``","``. No CSV quoting is honoured,
    so text fields must not contain commas. Blank lines, such as a trailing
    newline at end of file, are skipped rather than raising ``IndexError``.
    Lines with fewer than five fields raise ``IndexError``. A non-integer
    label raises ``ValueError``.
    """
    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            continue
        fields = line.split(",")
        label = int(fields[4])
        query_ids = _encode_and_pad(fields[1], char_to_id, q_max_len)
        title_ids = _encode_and_pad(fields[3], char_to_id, t_max_len)
        yield query_ids, title_ids, label


def load_esim_data_and_labels(data_path, char_to_id, q_max_len=22, t_max_len=22):
    """Load and encode every line of ``data_path`` into fixed-length id lists.

    Args:
        data_path: path to a comma-separated text file, opened in text mode
            with the platform default encoding.
        char_to_id: mapping from token to integer id; unknown tokens become 0.
        q_max_len: length every query id list is truncated/padded to.
        t_max_len: length every title id list is truncated/padded to.

    Returns:
        A dict with keys ``"q_feat_index"``, ``"t_feat_index"`` and ``"label"``:
        - ``"q_feat_index"``: one query id list per line.
        - ``"t_feat_index"``: one title id list per line.
        - ``"label"``: one single-element ``[int]`` list per line.
        All three lists are in file order. Blank lines are skipped.

    Raises:
        IndexError / ValueError: on a non-blank line that is malformed.
    """
    q_feat_index = []
    t_feat_index = []
    y = []
    with open(data_path, 'r') as f:
        for query_ids, title_ids, label in _iter_esim_examples(
                f, char_to_id, q_max_len, t_max_len):
            q_feat_index.append(query_ids)
            t_feat_index.append(title_ids)
            y.append([label])
    return {"q_feat_index": q_feat_index, "t_feat_index": t_feat_index, "label": y}


def yield_esim_data_and_labels(data_path, char_to_id, batch_size, q_max_len=22, t_max_len=22):
    """Lazily load and encode ``data_path``, yielding batches of ``batch_size`` examples.

    Each yielded dict has the same keys and per-row layout as the result of
    ``load_esim_data_and_labels``. Every batch holds exactly ``batch_size``
    examples, except possibly the last, which holds the remainder. Nothing is
    yielded when the file has no non-blank lines.

    The file is opened in a ``with`` block, so it is closed even if the
    consumer stops iterating early (``break`` or ``generator.close()``).

    Raises:
        ValueError: if ``batch_size`` is smaller than 1. Because this is a
            generator, the error surfaces on the first ``next()``.
        IndexError / ValueError: on a non-blank line that is malformed.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
    q_feat_index = []
    t_feat_index = []
    y = []
    with open(data_path, 'r') as f:
        for query_ids, title_ids, label in _iter_esim_examples(
                f, char_to_id, q_max_len, t_max_len):
            q_feat_index.append(query_ids)
            t_feat_index.append(title_ids)
            y.append([label])
            if len(y) == batch_size:
                yield {"q_feat_index": q_feat_index, "t_feat_index": t_feat_index, "label": y}
                # Rebind to fresh lists: the consumer keeps the yielded ones.
                q_feat_index = []
                t_feat_index = []
                y = []
    # Emit the remainder after the file has been closed.
    if y:
        yield {"q_feat_index": q_feat_index, "t_feat_index": t_feat_index, "label": y}