/
build_lm_train_test.py
79 lines (68 loc) · 2.67 KB
/
build_lm_train_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import argparse
import os
import random
import re
import sys
def filterLine(line, lang, targets):
    """Clean or filter a single corpus line according to *lang*.

    - ``'latin'``: every digit and every ``'#'`` character is removed; the
      line is always kept.
    - ``'english'``: *targets* are tagged target words whose last three
      characters are a tag to strip (e.g. ``'attack_nn'`` -> ``'attack'``).
      Every whitespace-delimited token equal to a target is replaced by the
      target without that tag, and the rewritten line is returned. If no
      tagged target occurs but the bare form of some target occurs as a
      token, ``None`` is returned so the caller drops the line. Otherwise
      the line is returned unchanged.
    - any other language: the line is returned unchanged.

    Matching is on whole tokens (not substrings), so e.g. ``'lane_nn'`` does
    not match inside ``'plane_nn'``. Original whitespace, including the
    trailing newline, is preserved.

    Args:
        line: one line read from a corpus file.
        lang: language name selecting the filtering rule.
        targets: iterable of tagged target words (only used for
            ``'english'``); ``None`` or empty means "no targets".

    Returns:
        The (possibly rewritten) line, or ``None`` if it should be discarded.
    """
    if lang == 'latin':
        return ''.join(ch for ch in line if not (ch.isdigit() or ch == '#'))
    if lang != 'english' or not targets:
        return line

    tagged = set(targets)
    # NOTE(review): assumes every target ends in a 3-character tag such as "_nn".
    bare = {target[:-3] for target in tagged}

    # Split on whitespace but keep the separators (capturing group) so the
    # line can be rebuilt byte-for-byte apart from the replaced tokens.
    pieces = re.split(r'(\s+)', line)
    found_tagged = False
    for i, piece in enumerate(pieces):
        if piece in tagged:
            pieces[i] = piece[:-3]
            found_tagged = True
    if found_tagged:
        return ''.join(pieces)

    # Only the untagged form occurs: the word is used in a non-target sense.
    if any(token in bare for token in line.split()):
        return None
    return line
def _read_targets(target_path):
    """Return the non-empty, whitespace-stripped lines of *target_path*."""
    with open(target_path, 'r', encoding='utf8') as f:
        return [target for target in (line.strip() for line in f) if target]


def _load_sentences(corpora, lang, targets):
    """Read every corpus file and return the lines that filterLine keeps.

    Each returned line ends with a newline: the last line of a corpus file
    may lack one, and after shuffling it would otherwise be concatenated
    with whichever sentence happens to be written after it.
    """
    data = []
    for corpus in corpora:
        with open(corpus, 'r', encoding='utf8') as f:
            for line in f:
                line = filterLine(line, lang, targets)
                if line is None:
                    continue
                if not line.endswith('\n'):
                    line += '\n'
                data.append(line)
    return data


def _write_split(data, output_folder, train_fraction):
    """Write data[:k] to train.txt and data[k:] to test.txt in *output_folder*.

    k is int(train_fraction * len(data)). The folder is created if missing.
    """
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)
    split_index = int(train_fraction * len(data))
    with open(os.path.join(output_folder, 'train.txt'), 'w', encoding='utf8') as train_file, \
            open(os.path.join(output_folder, 'test.txt'), 'w', encoding='utf8') as test_file:
        train_file.writelines(data[:split_index])
        test_file.writelines(data[split_index:])


def main():
    """Build shuffled language-model train/test files from corpus time slices."""
    languages = ['english', 'latin', 'swedish', 'german']
    parser = argparse.ArgumentParser()
    parser.add_argument("--corpus_paths", default='data/english/english_1.txt;data/english/english_2.txt', type=str,
                        help="Paths to all corpus time slices separated by ';'.")
    parser.add_argument("--target_path", default='data/english/targets.txt', type=str,
                        help="Path to target files")
    # default='english' matches the other English defaults; without it an
    # omitted flag produced None and the script quit with exit status 0.
    parser.add_argument("--language", default='english', const='english', nargs='?',
                        help="Choose a language", choices=languages)
    parser.add_argument("--lm_train_test_folder", default='data/english',
                        help="Path to folder that contains output language model train and test sets")
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for the shuffle; omit for a non-reproducible shuffle.")
    parser.add_argument("--train_fraction", type=float, default=0.9,
                        help="Fraction of sentences written to train.txt; the rest go to test.txt.")
    args = parser.parse_args()
    if not 0.0 <= args.train_fraction <= 1.0:
        parser.error("--train_fraction must be between 0 and 1")

    lang = args.language
    # Ignore empty entries, e.g. from a trailing ';'.
    corpora = [path for path in args.corpus_paths.split(';') if path]
    # Only the English filter uses targets.
    targets = _read_targets(args.target_path) if lang == 'english' else None

    data = _load_sentences(corpora, lang, targets)
    if args.seed is not None:
        random.seed(args.seed)
    random.shuffle(data)
    _write_split(data, args.lm_train_test_folder, args.train_fraction)


if __name__ == '__main__':
    main()