# prep.py: downloads, generates and prepares data for the Dask tutorial.
# (Repository generated from jrbourbeau/dask-binder-template.)
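#
# Example invocations, taken from the argument parser below (run from the repository root):
#
#   python prep.py                        # prepare all datasets (the default)
#   python prep.py --dataset flights      # prepare only the flights data
#   python prep.py --small                # use the smaller example datasets
#   DASK_TUTORIAL_SMALL=1 python prep.py
#   python prep.py --no-ssl-verify        # disable SSL verification for the download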
import time
import sys
import argparse
import os
from glob import glob
import tarfile
import urllib.request

import pandas as pd

DATASETS = ["flights", "all"]

here = os.path.dirname(__file__)
data_dir = os.path.abspath(os.path.join(here, "data"))


def parse_args(args=None):
    parser = argparse.ArgumentParser(
        description="Downloads, generates and prepares data for the Dask tutorial."
    )
    parser.add_argument(
        "--no-ssl-verify",
        dest="no_ssl_verify",
        action="store_true",
        default=False,
        help="Disables SSL verification.",
    )
    parser.add_argument(
        "--small",
        action="store_true",
        default=None,
        help="Whether to use smaller example datasets. Checks DASK_TUTORIAL_SMALL environment variable if not specified.",
    )
    parser.add_argument(
        "-d", "--dataset", choices=DATASETS, help="Datasets to generate.", default="all"
    )
    return parser.parse_args(args)


# Fail fast if the repository's data/ directory is missing.
if not os.path.exists(data_dir):
    raise OSError(
        "data/ directory not found, aborting data preparation. "
        'Restore it with "git checkout data" from the base '
        "directory."
    )


def flights(small=None):
    """Download the NYC Flights dataset and prepare CSV and JSON versions of it."""
    start = time.time()
    flights_raw = os.path.join(data_dir, "nycflights.tar.gz")
    flightdir = os.path.join(data_dir, "nycflights")
    jsondir = os.path.join(data_dir, "flightjson")

    if small is None:
        # Any non-empty value of DASK_TUTORIAL_SMALL selects the small dataset.
        small = bool(os.environ.get("DASK_TUTORIAL_SMALL", False))

    # Number of rows read from each CSV when generating the JSON files.
    if small:
        N = 500
    else:
        N = 10_000
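
    # Download the raw tarball only if it is not already cached in data/.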
    if not os.path.exists(flights_raw):
        print("- Downloading NYC Flights dataset... ", end="", flush=True)
        url = "https://storage.googleapis.com/dask-tutorial-data/nycflights.tar.gz"
        urllib.request.urlretrieve(url, flights_raw)
        print("done", flush=True)
    if not os.path.exists(flightdir):
        print("- Extracting flight data... ", end="", flush=True)
        # Extract into the data directory so the script works from any working directory.
        with tarfile.open(flights_raw, mode="r:gz") as tar:
            tar.extractall(data_dir)
        if small:
            # Keep only the first 1000 lines of each CSV for the small dataset.
            for path in glob(os.path.join(data_dir, "nycflights", "*.csv")):
                with open(path, "r") as f:
                    lines = f.readlines()[:1000]
                with open(path, "w") as f:
                    f.writelines(lines)
        print("done", flush=True)
    if not os.path.exists(jsondir):
        print("- Creating json data... ", end="", flush=True)
        os.mkdir(jsondir)
        for path in glob(os.path.join(data_dir, "nycflights", "*.csv")):
            prefix = os.path.splitext(os.path.basename(path))[0]
            df = pd.read_csv(path, nrows=N)
            df.to_json(
                os.path.join(data_dir, "flightjson", prefix + ".json"),
                orient="records",
                lines=True,
            )
        print("done", flush=True)
    else:
        # The JSON data already exists, so there is nothing left to do.
        return

    end = time.time()
    print("** Created flights dataset in {:0.2f}s **".format(end - start))


def main(args=None):
    args = parse_args(args)

    if args.no_ssl_verify:
        print("- Disabling SSL Verification... ", end="", flush=True)
        import ssl

        ssl._create_default_https_context = ssl._create_unverified_context
        print("done", flush=True)

    if args.dataset == "flights" or args.dataset == "all":
        flights(args.small)


if __name__ == "__main__":
    sys.exit(main())