Skip to content

Commit c171915

Browse files
committed
Added lock to multiprocessing
1 parent 45b59b5 commit c171915

File tree

2 files changed

+40
-23
lines changed

2 files changed

+40
-23
lines changed

scrape.py

Lines changed: 32 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -9,10 +9,11 @@
99
import sys
1010
import re
1111
import csv
12-
from multiprocessing import Pool
12+
import glob
13+
from multiprocessing import Pool, Lock
1314

1415

15-
def save_document(doc_id, output_dir, csv_writer):
16+
def save_document(doc_id, output_dir):
1617
'''Gather document details'''
1718

1819
# Get document details
@@ -53,7 +54,7 @@ def save_document(doc_id, output_dir, csv_writer):
5354
# Save content
5455
open(os.path.join(output_dir, str(doc_id) + ".txt"), 'w').write(doc["content"])
5556
del doc["content"]
56-
csv_writer.writerow(doc)
57+
return doc
5758

5859

5960
def drawLoadingBar(val, maximum):
@@ -106,7 +107,10 @@ def parse_args(args):
106107

107108
return parser.parse_args(args)
108109

110+
111+
109112
def main(args):
113+
global _process_line
110114
# Make data dir
111115
try:
112116
os.mkdir(args.output_dir)
@@ -141,26 +145,38 @@ def main(args):
141145
break
142146
drawLoadingBar(currentCount, args.max_per_lang)
143147
currentCount += 1
148+
writer.writerow(save_document(d["id"], args.output_dir))
144149

145-
save_document(d["id"], args.output_dir, writer)
146150
elif args.command == "recreate":
147-
writers = {}
148151
lines = list(args.id_file.readlines()[1:])
152+
153+
for split in ["train", "test", "dev"]:
154+
with open(os.path.join(args.output_dir, "labels." + split + ".csv"), 'a') as f:
155+
writer = csv.DictWriter(f, fieldnames=["document_id", "author_id", "L1", "english_proficiency"])
156+
writer.writeheader()
157+
149158
indexes = []
150159
def _process_line(tup):
151160
index, line = tup
152161
doc_id, split = line.strip().split(",")
153-
if split not in writers.keys():
154-
f = open(os.path.join(args.output_dir, "labels." + split + ".csv"), 'w')
155-
writers[split] = csv.DictWriter(f, fieldnames=["document_id", "author_id", "L1", "english_proficiency"])
156-
writers[split].writeheader()
157-
158-
save_document(doc_id, args.output_dir, writers[split])
159-
indexes.append(index)
160-
drawLoadingBar(len(indexes), len(lines))
161-
162-
with Pool(processes=args.agents) as pool:
163-
result = pool.map(_process_line, enumerate(lines), 1)
162+
doc = save_document(doc_id, args.output_dir)
163+
if doc is not None:
164+
l.acquire()
165+
with open(os.path.join(args.output_dir, "labels." + split + ".csv"), 'a') as f:
166+
writer = csv.DictWriter(f, fieldnames=["document_id", "author_id", "L1", "english_proficiency"])
167+
writer.writerow(doc)
168+
indexes.append(index)
169+
drawLoadingBar(len(glob.glob(os.path.join(args.output_dir, "*.txt"))), len(lines))
170+
l.release()
171+
172+
def init(l):
173+
global lock
174+
lock = l
175+
176+
l = Lock()
177+
print("start pool")
178+
with Pool(processes=args.num_agents, initializer=init, initargs=(l,)) as pool:
179+
_ = pool.map(_process_line, enumerate(lines), 1)
164180

165181
if __name__ == "__main__":
166182
main(parse_args(sys.argv[1:]))

tests/test_scrape.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -83,12 +83,12 @@ def test_save_document_normal():
8383
with tempfile.TemporaryDirectory() as tmpdirname:
8484
csv_filename = os.path.join(tmpdirname, "test.csv")
8585
with open(csv_filename, 'w') as f:
86-
writer = csv.DictWriter(f, fieldnames=["document_id", "author_id", "L1", "english_proficiency"])
86+
# writer = csv.DictWriter(f, fieldnames=["document_id", "author_id", "L1", "english_proficiency"])
8787
test_document, test_author = make_mocks("1234", "1234", "hindi")
8888
m.get("https://www.italki.com/api/notebook/1234", text=json.dumps(test_document))
8989
m.get("https://www.italki.com/api/user/1234", text=json.dumps(test_author))
90-
save_document("1234", tmpdirname, writer)
91-
assert open(csv_filename).read() == "1234,1234,hindi,0\n"
90+
doc = save_document("1234", tmpdirname)
91+
assert doc == {"document_id": test_document["data"]["id"], "author_id": test_author["data"]["id"], "L1": "hindi", "english_proficiency": 0}
9292
assert open(os.path.join(tmpdirname, "1234.txt")).read() == test_document["data"]["content"]
9393

9494

@@ -100,8 +100,8 @@ def test_save_document_404():
100100
with open(csv_filename, 'w') as f:
101101
writer = csv.DictWriter(f, fieldnames=["document_id", "author_id", "L1", "english_proficiency"])
102102
m.get("https://www.italki.com/api/notebook/1234", status_code=404)
103-
save_document("1234", tmpdirname, writer)
104-
assert open(csv_filename).read() == ""
103+
doc = save_document("1234", tmpdirname)
104+
assert doc is None
105105
assert not os.path.isfile(os.path.join(tmpdirname, "1234.txt"))
106106

107107

@@ -135,15 +135,16 @@ def test_recreate():
135135
m.get("https://www.italki.com/api/notebook/4", text=json.dumps(test_document4))
136136
main(SimpleNamespace(
137137
command="recreate",
138-
agents=1,
138+
num_agents=1,
139139
output_dir=os.path.join(tmpdirname, "output"),
140140
id_file=open(os.path.join(tmpdirname, "test_ids.txt"))
141141
))
142142
assert open(os.path.join(tmpdirname, "output", "1.txt")).read() == test_document1["data"]["content"]
143143
assert open(os.path.join(tmpdirname, "output", "2.txt")).read() == test_document2["data"]["content"]
144144
assert open(os.path.join(tmpdirname, "output", "3.txt")).read() == test_document3["data"]["content"]
145145
assert open(os.path.join(tmpdirname, "output", "4.txt")).read() == test_document4["data"]["content"]
146-
assert open(os.path.join(tmpdirname, "output", "labels.train.csv")).read() == "document_id,author_id,L1,english_proficiency\n1,1234,hindi,0\n4,12344,french,5\n"
146+
print(open(os.path.join(tmpdirname, "output", "labels.train.csv")).readlines())
147+
assert set(open(os.path.join(tmpdirname, "output", "labels.train.csv")).readlines()) == set(["document_id,author_id,L1,english_proficiency\n", "1,1234,hindi,0\n", "4,12344,french,5\n"])
147148
assert open(os.path.join(tmpdirname, "output", "labels.test.csv")).read() == "document_id,author_id,L1,english_proficiency\n2,1234,hindi,0\n"
148149
assert open(os.path.join(tmpdirname, "output", "labels.dev.csv")).read() == "document_id,author_id,L1,english_proficiency\n3,1234,hindi,0\n"
149150

0 commit comments

Comments
 (0)