# generate_dataset.py
import argparse
import logging
import os
import time
# set tensorflow logging level before importing things that contain tensorflow
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 0 = INFO, 1 = WARN, 2 = ERROR, 3 = FATAL
logging.getLogger('tensorflow').setLevel(logging.ERROR)
import tensorflow as tf
from utilities.utils import load_param_file, set_logger, get_study_dirs
from utilities.input_functions import InputFunctions
# generates a TF dataset for each split and writes it to disk
def generate_dataset(params):
    """Build the train/val/eval tf.data datasets and save each one to disk.

    Args:
        params: loaded parameter object. Must provide ``dataset_dir`` and
            ``use_gzip_compression``, plus whatever ``get_study_dirs`` and
            ``InputFunctions`` require.

    Side effects:
        Writes one saved dataset per non-empty split under
        ``params.dataset_dir/{eval,val,train}`` and logs progress/timing.
    """
    # logging
    dataset_logger = logging.getLogger()
    # generate directories - returns a dict of "train", "val", and "eval"
    study_dirs = get_study_dirs(params)
    # generate dataset objects for model inputs
    input_fn = InputFunctions(params)
    # handle compression - tf.data save accepts 'GZIP' or None
    comp = 'GZIP' if params.use_gzip_compression else None

    def _save_split(split, label):
        # Build and persist a single split; no-op (with a log line) if the
        # split has no study directories.
        if not study_dirs[split]:
            dataset_logger.info(f"No {label} data found")
            return
        start = time.time()
        dataset_logger.info(f"Saving {label} dataset...")
        split_dataset_dir = os.path.join(params.dataset_dir, split)
        split_inputs = input_fn.get_dataset(data_dirs=study_dirs[split], mode=split)
        tf.data.experimental.save(split_inputs, split_dataset_dir, compression=comp)
        end = time.time()
        dataset_logger.info(f"- elapsed time is {(end - start) / 60:0.2f} minutes")

    # preserve original processing order: eval, then val, then train
    _save_split("eval", "evaluation")
    _save_split("val", "validation")
    _save_split("train", "training")
# executed as script
if __name__ == '__main__':
    # parse input arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--param_file', default=None, type=str,
                        help="Path to params.json")
    parser.add_argument('-l', '--logging', default=2, type=int, choices=[1, 2, 3, 4, 5],
                        help="Set logging level: 1=DEBUG, 2=INFO, 3=WARN, 4=ERROR, 5=CRITICAL")
    parser.add_argument('-x', '--overwrite', default=False,
                        help="Overwrite existing data.",
                        action='store_true')
    parser.add_argument('-c', '--compression', default=False,
                        help="Save data using GZIP compression [slow].",
                        action='store_true')
    # Load the parameters from the experiment params.json file in model_dir
    args = parser.parse_args()
    # validate CLI input explicitly - assert statements are stripped under `python -O`
    if not args.param_file:
        parser.error("Must specify a parameter file using --param_file")
    if not os.path.isfile(args.param_file):
        parser.error("No json configuration file found at {}".format(args.param_file))
    # load params from param file
    my_params = load_param_file(args.param_file)
    # set global random seed for tensorflow operations
    tf.random.set_seed(my_params.random_state)
    # determine dataset directory and create it if it doesn't exist, if it does exist check overwrite argument
    my_params.dataset_dir = os.path.join(my_params.model_dir, 'dataset')
    if not os.path.isdir(my_params.dataset_dir):
        # makedirs also creates any missing parent directories of model_dir
        os.makedirs(my_params.dataset_dir)
    elif not args.overwrite:
        raise FileExistsError("Dataset directory already exists and overwrite argument is false!")
    # handle logging argument - start from a fresh log file when overwriting
    log_path = os.path.join(my_params.dataset_dir, 'dataset.log')
    if os.path.isfile(log_path) and args.overwrite:
        os.remove(log_path)
    # CLI levels 1-5 map onto logging's 10-50 (DEBUG..CRITICAL)
    logger = set_logger(log_path, level=args.logging * 10)
    logger.info(f"Using dataset directory {my_params.dataset_dir}")
    logger.info(f"Using TensorFlow version {tf.__version__}")
    # handle compression argument
    my_params.use_gzip_compression = args.compression
    # do work
    generate_dataset(my_params)