# setup.py -- one-shot setup script for this codebase.
#
# Creates the data/ and results/ folders, downloads the encoded WikiSum and
# QMDSCNN datasets (skippable with --ignore_datasets), clones the pyrouge
# repository and registers its bundled ROUGE-1.5.5 path.
import os
import logging
import argparse
import gdown
import zipfile
import git
import subprocess
# Enable INFO-level logging so the progress messages emitted by main()
# (directory creation, downloads, extraction) are shown when the script runs.
logging.basicConfig(level = logging.INFO)
# Encoded dataset archives fetched by main(), in download order. Each entry is
# (Google Drive URL, local .zip path, extraction directory, log-message name).
_DATASET_ARCHIVES = (
    (
        "https://drive.google.com/uc?id=1AnqeUpLkO9MR3PH0V8q32A6PEPDEZ0td&export=download",
        "data/wikisum/ranked_wiki_b40.zip",
        "data/wikisum",
        "WikiSum",
    ),
    (
        "https://drive.google.com/uc?id=1RdX-t3pznnyaGyrswFubAfoo9S9w9K5d&export=download",
        "data/wikisum/ranked_wiki_b40_query.zip",
        "data/wikisum",
        "WikiSum query",
    ),
    (
        "https://drive.google.com/uc?id=1KXsvfnK6s6cnYQzD8ZOkXPdA6r5-quPK&export=download",
        "data/qmdscnn/pytorch_qmdscnn.zip",
        "data/qmdscnn",
        "QMDSCNN",
    ),
    (
        "https://drive.google.com/uc?id=12i_3dikeJLsOj-SQGPmc4w9Is7fB-hT-&export=download",
        "data/qmdscnn/pytorch_qmdscnn_query.zip",
        "data/qmdscnn",
        "QMDSCNN query",
    ),
)


def _download_and_extract(url, output_path, extract_dir, description):
    """Download one zip archive with gdown and unpack it into *extract_dir*.

    Does nothing when *output_path* already exists, so re-running the setup
    does not fetch an archive again. An archive is extracted only right after
    it has been downloaded; delete the ``.zip`` to force a fresh download and
    extraction.

    Args:
        url: Google Drive download URL passed to ``gdown.download``.
        output_path: where the downloaded ``.zip`` file is written.
        extract_dir: directory the archive's members are extracted into.
        description: dataset name used in the progress log message.

    Raises:
        RuntimeError: if no file exists at *output_path* after the download.
        zipfile.BadZipFile: if the downloaded file is not a valid zip archive
            (the broken file is removed first, so the next run retries).
    """
    if os.path.exists(output_path):
        return
    logging.info("Downloading the encoded %s dataset...", description)
    gdown.download(url, output_path, quiet=False)
    # Fail with a clear message instead of a confusing error from ZipFile
    # when the download produced no file.
    if not os.path.exists(output_path):
        raise RuntimeError(
            f"Downloading the {description} dataset from {url} did not "
            f"produce {output_path}"
        )
    logging.info("Unzipping the data...")
    try:
        with zipfile.ZipFile(output_path, "r") as zip_ref:
            zip_ref.extractall(extract_dir)
    except zipfile.BadZipFile:
        # A truncated or corrupt archive would otherwise be skipped forever
        # by the existence check above; remove it so the next run re-downloads.
        os.remove(output_path)
        raise


def main(args):
    """Set up the codebase in the current working directory.

    Steps, in order:
      1. create ``data/``, ``data/wikisum/``, ``data/qmdscnn/`` and
         ``results/`` if they are missing;
      2. unless ``args.ignore_datasets`` is true, download and unpack every
         archive listed in ``_DATASET_ARCHIVES`` that is not already on disk;
      3. clone the pyrouge repository into ``./pyrouge`` if it is missing;
      4. run ``pyrouge_set_rouge_path`` with the bundled ROUGE-1.5.5 path.
         If the command is missing or exits non-zero, an error is logged.

    Args:
        args: parsed command-line namespace; only ``ignore_datasets`` is read.
    """
    # Create the data and results folders.
    for path, message in (
        ("data", "creating 'data' directory..."),
        ("data/wikisum", "creating 'data/wikisum' directory..."),
        ("data/qmdscnn", "creating 'data/qmdscnn' directory..."),
        ("results", "creating the 'results' directory..."),
    ):
        if not os.path.exists(path):
            logging.info(message)
            os.makedirs(path)

    # Download and unzip the WikiSum and QMDSCNN datasets.
    if not args.ignore_datasets:
        for url, output_path, extract_dir, description in _DATASET_ARCHIVES:
            _download_and_extract(url, output_path, extract_dir, description)

    # Download the pyrouge git repo.
    if not os.path.exists("pyrouge"):
        repo_url = "https://github.com/andersjo/pyrouge.git"
        logging.info("Downloading repo: %s", repo_url)
        git.Git(".").clone(repo_url)

    # Set the ROUGE path. The outcome is checked so that a failure is reported
    # rather than silently ignored.
    rouge_path = os.path.join(os.getcwd(), "pyrouge", "tools", "ROUGE-1.5.5")
    try:
        result = subprocess.run(["pyrouge_set_rouge_path", rouge_path])
    except FileNotFoundError:
        logging.error(
            "'pyrouge_set_rouge_path' was not found; install pyrouge and run "
            "it with %s to set the ROUGE path",
            rouge_path,
        )
        return
    if result.returncode != 0:
        logging.error(
            "pyrouge_set_rouge_path exited with status %d", result.returncode
        )
def parse_args(argv=None):
    """Parse the command-line options of this setup script.

    Args:
        argv: argument list to parse; ``None`` makes argparse read the
            process's command line.

    Returns:
        argparse.Namespace with a boolean ``ignore_datasets`` attribute
        (``False`` unless ``--ignore_datasets`` is given).
    """
    parser = argparse.ArgumentParser(
        description="Create the project folders, download the datasets and "
        "set up pyrouge/ROUGE."
    )
    # store_true already defaults to False, so no explicit default is needed.
    parser.add_argument(
        "--ignore_datasets",
        action="store_true",
        help="skip downloading and extracting the WikiSum and QMDSCNN datasets",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    main(parse_args())