-
Notifications
You must be signed in to change notification settings - Fork 3
/
TianGong-ST.py
361 lines (333 loc) · 16.9 KB
/
TianGong-ST.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
# !/usr/bin/python
# coding: utf8
from xml.dom.minidom import parse
import xml.dom.minidom
import time
import pprint
import string
import sys
sys.path.append("..")
import argparse
import re
import os
import numpy as np
import torch
import torch.nn as nn
from utils import *
from math import log
import random
def xml_clean(args):
    """Copy the raw TianGong-ST XML dump, dropping every removable line.

    Reads ``<args.input>/<args.dataset>`` and writes each line for which
    ``xml_line_removable`` (imported from ``utils``) is false to
    ``<args.input>/clean-<args.dataset>``. The smaller file is what
    ``generate_dict_list`` later parses.

    Args:
        args: parsed CLI namespace; ``args.input`` (directory) and
            ``args.dataset`` (file name) are read.

    Raises:
        AssertionError: if the number of removed lines differs from the
            count expected from the number of ``<interaction num=`` lines.
    """
    src_path = os.path.join(args.input, args.dataset)
    dst_path = os.path.join(args.input, 'clean-' + args.dataset)
    read_line_count = 0
    removed_line_count = 0
    interaction_count = 0
    # Context managers close both handles (and flush the writer) even when an
    # exception is raised part-way through; the original leaked them.
    with open(src_path, 'r') as xml_reader, open(dst_path, 'w') as xml_writer:
        print(' - {}'.format('start reading from xml file...'))
        xml_lines = xml_reader.readlines()
        print(' - {}'.format('read {} lines'.format(len(xml_lines))))
        print(' - {}'.format('start removing useless lines...'))
        for xml_line in xml_lines:
            read_line_count += 1
            if '<interaction num=' in xml_line:
                interaction_count += 1
            if xml_line_removable(xml_line):
                # A line that should be removed
                removed_line_count += 1
                if removed_line_count % 1000000 == 0:
                    print(' - {}'.format('remove {} lines...'.format(removed_line_count)))
            else:
                xml_writer.write(xml_line)
    # Sanity check, relying on every query having 10 docs: the expected count is
    # 1 removed line per interaction plus 10 docs x (1 + 1 + 2 + 6) removed
    # lines per doc. Which lines those are is decided by xml_line_removable in
    # utils -- NOTE(review): confirm the breakdown there.
    assert read_line_count == len(xml_lines)
    assert removed_line_count == interaction_count + interaction_count * 10 * (1 + 1 + 2 + 6)
    print(' - {}'.format('read {} lines'.format(read_line_count)))
    print(' - {}'.format('totally {} iteractions'.format(interaction_count)))
    print(' - {}'.format('totally remove {} lines'.format(removed_line_count)))
def _tg_register_id(mapping, key):
    """Return the id of *key* in *mapping*, assigning ``len(mapping)`` on first sight."""
    if key not in mapping:
        mapping[key] = len(mapping)
    return mapping[key]


def _tg_parse_interaction(interaction, info_per_session, query_qid, url_uid, vtype_vid):
    """Convert one ``<interaction>`` DOM element into an interaction_info dict.

    The query id is registered in *query_qid* before the doc count is checked,
    so it is recorded even for interactions that are then discarded.

    Returns:
        The interaction_info dict, or None (after printing a warning) when the
        ``<results>`` element does not hold exactly 10 ``<result>`` docs.
    """
    query_id = interaction.getElementsByTagName('query_id')[0].childNodes[0].data
    interaction_info = {
        'query': query_id,
        'qid': _tg_register_id(query_qid, query_id),
        'session': info_per_session['session_number'],
        'sid': info_per_session['sid'],
    }
    docs = interaction.getElementsByTagName('results')[0].getElementsByTagName('result')
    # Only interactions with exactly 10 docs are kept; every other size is
    # reported and counted as junk by the caller.
    if len(docs) == 0:
        print(' - {}'.format('WARNING: find a query with no docs: {}'.format(query_id)))
        return None
    if len(docs) > 10:
        print(' - {}'.format('WARNING: find a query with more than 10 docs: {}'.format(query_id)))
        return None
    if len(docs) < 10:
        print(' - {}'.format('WARNING: find a query with less than 10 docs: {}'.format(query_id)))
        return None

    doc_infos = []
    for doc in docs:
        # The rank comes from the result's `rank` attribute and must lie in 1..10.
        doc_rank = int(doc.getAttribute('rank'))
        assert 1 <= doc_rank <= 10
        doc_id = doc.getElementsByTagName('docid')[0].childNodes[0].data
        vtype = doc.getElementsByTagName('vtype')[0].childNodes[0].data
        doc_infos.append({
            'rank': doc_rank,
            'url': doc_id,
            'uid': _tg_register_id(url_uid, doc_id),
            'vtype': vtype,
            'vid': _tg_register_id(vtype_vid, vtype),
            'click': 0,
        })

    # A <clicked> element may be absent when nothing was clicked.
    clicked = interaction.getElementsByTagName('clicked')
    if clicked:
        for click in clicked[0].getElementsByTagName('click'):
            clicked_doc_rank = int(click.getElementsByTagName('rank')[0].childNodes[0].data)
            # Mark the first doc whose rank matches the clicked rank.
            for item in doc_infos:
                if item['rank'] == clicked_doc_rank:
                    item['click'] = 1
                    break

    interaction_info['docs'] = doc_infos
    interaction_info['uids'] = [doc['uid'] for doc in doc_infos]
    interaction_info['vids'] = [doc['vid'] for doc in doc_infos]
    interaction_info['clicks'] = [doc['click'] for doc in doc_infos]
    return interaction_info


def _tg_save_and_check_list(output_dir, filename, items):
    """Save *items* with ``save_list`` and assert ``load_list`` reads back an equal list."""
    save_list(output_dir, filename, items)
    reloaded = load_list(output_dir, filename)
    assert len(items) == len(reloaded)
    for idx, item in enumerate(items):
        assert item == reloaded[idx]


def generate_dict_list(args):
    """Parse the cleaned XML and build the id dictionaries and per-session/per-query lists.

    Parses ``<args.input>/clean-<args.dataset>`` with ``xml.dom.minidom``.
    Sessions, queries, docs and vtypes get dense integer ids in order of
    first appearance. Interactions without exactly 10 docs are dropped.

    The following files are written to ``args.output`` through the ``utils``
    save helpers, and each is reloaded and compared against the in-memory
    value:
    ``infos_per_session.list``, ``infos_per_query.list``,
    ``session_sid.dict``, ``query_qid.dict``, ``url_uid.dict`` and
    ``vtype_vid.dict``.

    Args:
        args: parsed CLI namespace; ``input``, ``dataset`` and ``output`` are read.

    Raises:
        AssertionError: on out-of-range doc ranks, on a save/reload mismatch,
            or on a dictionary key with an unexpected form.
    """
    session_sid = {}
    query_qid = {}
    url_uid = {}
    vtype_vid = {}
    print(' - {}'.format('start parsing xml file...'))
    dom_tree = xml.dom.minidom.parse(os.path.join(args.input, 'clean-' + args.dataset))
    sessions = dom_tree.documentElement.getElementsByTagName('session')

    # generate infos_per_session
    print(' - {}'.format('generating infos_per_session...'))
    infos_per_session = []
    junk_interation_num = 0
    for session in sessions:
        session_number = int(session.getAttribute('num'))
        info_per_session = {
            'session_number': session_number,
            'sid': _tg_register_id(session_sid, session_number),
        }
        interaction_infos = []
        for interaction in session.getElementsByTagName('interaction'):
            interaction_info = _tg_parse_interaction(
                interaction, info_per_session, query_qid, url_uid, vtype_vid)
            if interaction_info is None:
                junk_interation_num += 1
            else:
                interaction_infos.append(interaction_info)
        # A session is kept even if all of its interactions were junk.
        info_per_session['interactions'] = interaction_infos
        infos_per_session.append(info_per_session)
    print(' - {}'.format('abandon {} junk interactions'.format(junk_interation_num)))

    # generate infos_per_query: the sessions' interactions, flattened in order
    print(' - {}'.format('generating infos_per_query...'))
    infos_per_query = [
        interaction_info
        for info_per_session in infos_per_session
        for interaction_info in info_per_session['interactions']
    ]

    # save and check both lists
    print(' - {}'.format('save and check infos_per_session...'))
    print(' - {}'.format('length of infos_per_session: {}'.format(len(infos_per_session))))
    _tg_save_and_check_list(args.output, 'infos_per_session.list', infos_per_session)
    print(' - {}'.format('save and check infos_per_query...'))
    print(' - {}'.format('length of infos_per_query: {}'.format(len(infos_per_query))))
    _tg_save_and_check_list(args.output, 'infos_per_query.list', infos_per_query)

    # save and check dictionaries
    print(' - {}'.format('save and check session_sid, query_qid, url_uid...'))
    print(' - {}'.format('unique session number: {}'.format(len(session_sid))))
    print(' - {}'.format('unique query number: {}'.format(len(query_qid))))
    print(' - {}'.format('unique doc number: {}'.format(len(url_uid))))
    print(' - {}'.format('unique vtype number: {}'.format(len(vtype_vid))))
    save_dict(args.output, 'session_sid.dict', session_sid)
    save_dict(args.output, 'query_qid.dict', query_qid)
    save_dict(args.output, 'url_uid.dict', url_uid)
    save_dict(args.output, 'vtype_vid.dict', vtype_vid)
    dict1 = load_dict(args.output, 'session_sid.dict')
    dict2 = load_dict(args.output, 'query_qid.dict')
    dict3 = load_dict(args.output, 'url_uid.dict')
    dict4 = load_dict(args.output, 'vtype_vid.dict')
    assert len(session_sid) == len(dict1)
    assert len(query_qid) == len(dict2)
    assert len(url_uid) == len(dict3)
    assert len(vtype_vid) == len(dict4)
    for key in dict1:
        assert dict1[key] == session_sid[key]
        assert key > 0
    for key in dict2:
        # query ids look like 'q<digits...>'
        assert dict2[key] == query_qid[key]
        assert key[0] == 'q'
        assert key[1:] != ''
    for key in dict3:
        # doc ids look like 'd<...>'
        assert dict3[key] == url_uid[key]
        assert key[0] == 'd'
        assert key[1:] != ''
    for key in dict4:
        assert dict4[key] == vtype_vid[key]
        assert isinstance(key, str)
        assert key[1:] != ''
    print(' - {}'.format('Done'))
def generate_ncm_txt(args):
    """Split sessions into train/dev/test NCM text files and remap the human labels.

    Loads ``infos_per_session.list`` and the id dictionaries that
    ``generate_dict_list`` saved in ``args.output``. It then shuffles the
    sessions with a time-based seed, so each run gives a different split,
    and splits them into fixed-size train (117431) and dev (13154) sets. The
    remaining sessions form the test set.

    Each split's interactions are written by ``generate_data_per_query`` to
    ``train_per_query.txt``, ``dev_per_query.txt`` and ``test_per_query.txt``.

    Finally, ``<args.input>/../human_label/sogou_st_human_labels.txt`` is
    rewritten to ``<args.output>/human_label.txt``. Columns 1-3 (session,
    query, doc) are replaced with their integer ids, and fields are
    tab-separated with a trailing tab before the newline.

    Args:
        args: parsed CLI namespace; ``input`` and ``output`` are read.

    Raises:
        AssertionError: if the split sizes do not add up, or the number of
            written label lines is not a multiple of 10.
        KeyError: if a label line references an unknown session/query/doc.
    """
    # load session_sid & query_qid & url_uid
    print(' - {}'.format('loading session_sid & query_qid & url_uid...'))
    session_sid = load_dict(args.output, 'session_sid.dict')
    query_qid = load_dict(args.output, 'query_qid.dict')
    url_uid = load_dict(args.output, 'url_uid.dict')

    # write train.txt & dev.txt & test.txt per query
    print(' - {}'.format('loading the infos_per_session...'))
    infos_per_session = load_list(args.output, 'infos_per_session.list')

    # Separate all sessions into train : dev : test using fixed session counts.
    # NOTE(review): the --trainset_ratio / --devset_ratio CLI options are not used here.
    session_num = len(infos_per_session)
    train_session_num = 117431
    dev_session_num = 13154
    test_session_num = session_num - train_session_num - dev_session_num
    train_dev_split = train_session_num
    dev_test_split = train_session_num + dev_session_num
    print(' - {}'.format('train sessions: {}'.format(train_session_num)))
    print(' - {}'.format('dev sessions: {}'.format(dev_session_num)))
    print(' - {}'.format('test sessions: {}'.format(test_session_num)))
    print(' - {}'.format('total sessions: {}'.format(session_num)))

    # generate train & dev & test sessions
    print(' - {}'.format('generating train & dev & test data per session...'))
    random.seed(time.time())
    # A single shuffle already yields a uniformly random permutation; repeating
    # it (as the original did 10x) adds no extra randomness.
    random.shuffle(infos_per_session)
    train_sessions = infos_per_session[:train_dev_split]
    dev_sessions = infos_per_session[train_dev_split:dev_test_split]
    test_sessions = infos_per_session[dev_test_split:]
    assert train_session_num == len(train_sessions), 'train_session_num: {}, len(train_sessions): {}'.format(train_session_num, len(train_sessions))
    assert dev_session_num == len(dev_sessions), 'dev_session_num: {}, len(dev_sessions): {}'.format(dev_session_num, len(dev_sessions))
    assert test_session_num == len(test_sessions), 'test_session_num: {}, len(test_sessions): {}'.format(test_session_num, len(test_sessions))
    assert session_num == len(train_sessions) + len(dev_sessions) + len(test_sessions), 'session_num: {}, len(train_sessions) + len(dev_sessions) + len(test_sessions): {}'.format(session_num, len(train_sessions) + len(dev_sessions) + len(test_sessions))

    # flatten each split's sessions into its interactions (queries), in order
    print(' - {}'.format('generating train & dev & test data per queries...'))
    train_queries = [info for session in train_sessions for info in session['interactions']]
    dev_queries = [info for session in dev_sessions for info in session['interactions']]
    test_queries = [info for session in test_sessions for info in session['interactions']]
    print(' - {}'.format('train queries: {}'.format(len(train_queries))))
    print(' - {}'.format('dev queries: {}'.format(len(dev_queries))))
    print(' - {}'.format('test queries: {}'.format(len(test_queries))))
    print(' - {}'.format('total queries: {}'.format(len(train_queries) + len(dev_queries) + len(test_queries))))

    # Write back to txt files
    print(' - {}'.format('writint back to txt files...'))
    print(' - {}'.format('writing into {}/train_per_query.txt'.format(args.output)))
    generate_data_per_query(train_queries, np.arange(0, len(train_queries)), args.output, 'train_per_query.txt')
    print(' - {}'.format('writing into {}/dev_per_query.txt'.format(args.output)))
    generate_data_per_query(dev_queries, np.arange(0, len(dev_queries)), args.output, 'dev_per_query.txt')
    print(' - {}'.format('writing into {}/test_per_query.txt'.format(args.output)))
    generate_data_per_query(test_queries, np.arange(0, len(test_queries)), args.output, 'test_per_query.txt')

    # transfer human labels
    print('===> {}'.format('processing human label txt...'))
    # Join path components instead of concatenating strings, so an --input
    # without a trailing slash still resolves to <input>/../human_label/.
    label_path = os.path.join(args.input, '..', 'human_label', 'sogou_st_human_labels.txt')
    read_line_count = 0
    write_line_count = 0
    with open(label_path, 'r') as label_reader, \
            open(os.path.join(args.output, 'human_label.txt'), 'w') as label_writer:
        print(' - {}'.format('start reading from human_label.txt...'))
        lines = label_reader.readlines()
        print(' - {}'.format('read {} lines'.format(len(lines))))
        print(' - {}'.format('start transferring...'))
        for line in lines:
            read_line_count += 1
            # The file mixes ' ' and '\t' separators; split() handles both.
            line_entry = line.split()
            if not line_entry:
                # Skip blank lines (e.g. a trailing empty line) instead of
                # failing with IndexError below.
                continue
            line_entry[1] = str(session_sid[int(line_entry[1])])
            line_entry[2] = str(query_qid[line_entry[2]])
            line_entry[3] = str(url_uid[line_entry[3]])
            write_line_count += 1
            # Same layout as before: tab-joined fields followed by '\t\n'.
            label_writer.write('\t'.join(line_entry) + '\t\n')
    assert read_line_count == len(lines)
    assert write_line_count % 10 == 0
    print(' - {}'.format('write {} lines'.format(write_line_count)))
    print(' - {}'.format('finish reading from human_label.txt...'))
def main(argv=None):
    """Parse command-line options and run the requested preprocessing stages.

    Each stage runs only when its flag is given, always in this order:
    ``--xml_clean``, then ``--dict_list``, then ``--ncm_txt``. A single
    invocation can therefore run the whole pipeline.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour.
    """
    parser = argparse.ArgumentParser('TianGong-ST')
    parser.add_argument('--dataset', default='sogousessiontrack2020.xml',
                        help='dataset name')
    parser.add_argument('--input', default='../dataset/TianGong-ST/data/',
                        help='input path')
    parser.add_argument('--output', default='./data/version6',
                        help='output path')
    parser.add_argument('--xml_clean', action='store_true',
                        help='remove useless lines in xml files, to reduce the size of xml file')
    parser.add_argument('--dict_list', action='store_true',
                        help='generate dicts and lists for info_per_session/info_per_query')
    parser.add_argument('--ncm_txt', action='store_true',
                        help='generate NCM data txt')
    # type=float so values passed on the command line become numbers instead of
    # strings. NOTE(review): generate_ncm_txt currently splits sessions by fixed
    # counts and does not read these two ratios.
    parser.add_argument('--trainset_ratio', type=float, default=0.8,
                        help='ratio of the train session/query according to the total number of sessions/queries')
    parser.add_argument('--devset_ratio', type=float, default=0.1,
                        help='ratio of the dev session/query according to the total number of sessions/queries')
    args = parser.parse_args(argv)
    if args.xml_clean:
        # remove useless lines in xml files, to reduce the size of xml file
        print('===> {}'.format('cleaning xml file...'))
        xml_clean(args)
    if args.dict_list:
        # generate info_per_session & info_per_query
        print('===> {}'.format('generating dicts and lists...'))
        generate_dict_list(args)
    if args.ncm_txt:
        # load lists saved by generate_dict_list() and generate train/dev/test txt files
        print('===> {}'.format('generating train & dev & test data txt...'))
        generate_ncm_txt(args)
    print('===> {}'.format('Done.'))


if __name__ == '__main__':
    main()