/
run_starclass.py
169 lines (145 loc) · 6.74 KB
/
run_starclass.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Command-line interface for running classifications.
.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
"""
import os
import argparse
import logging
from tqdm import tqdm
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import starclass
#--------------------------------------------------------------------------------------------------
def main():
    """
    Run stellar classifiers on the tasks of a TODO-file from the command line.

    The function parses the command-line arguments, configures the ``starclass``
    logger, locates the TODO-file and then repeatedly pulls chunks of tasks from a
    :class:`starclass.TaskManager`, classifies each task and hands the results back
    to the TaskManager to be saved. If all classifiers (or the ``meta`` classifier)
    were requested, final classes are assigned once all tasks have been processed.

    The input is taken from the positional ``input_folder`` argument or, if that is
    not given, from the ``STARCLASS_INPUT`` environment variable. It can either be
    a directory (in which case ``todo.sqlite`` inside it is used) or the path to
    the TODO-file itself.

    Raises:
        SystemExit: Through ``parser.error`` if ``--clear-cache`` is used without
            ``--overwrite``, if no input was specified, or if the input does not exist.
    """
    # Parse command line arguments:
    parser = argparse.ArgumentParser(description='Command-line interface for running stellar classifiers.')
    parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true')
    parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true')
    parser.add_argument('-o', '--overwrite', help='Overwrite existing results.', action='store_true')
    parser.add_argument('--chunks', type=int, default=10, help="Number of tasks sent to each worker at a time.")
    # Stored as "in_memory", which is True unless the flag is given, so the value can
    # be passed on directly without a confusing double negative.
    parser.add_argument('--no-in-memory', dest='in_memory', action='store_false', help="Do not run TaskManager completely in-memory.")
    parser.add_argument('--clear-cache', help='Clear existing features cache tables before running. Can only be used together with --overwrite.', action='store_true')
    # Option to select which classifier to run:
    parser.add_argument('-c', '--classifier',
        default=None,
        choices=starclass.classifier_list,
        metavar='{CLASSIFIER}',
        help='Classifier to run. Default is to run all classifiers. Choises are ' + ", ".join(starclass.classifier_list) + '.')
    # Option to select training set:
    parser.add_argument('-t', '--trainingset',
        default='keplerq9v3',
        choices=starclass.trainingset_list,
        metavar='{TSET}',
        help='Train classifier using this training-set. Choises are ' + ", ".join(starclass.trainingset_list) + '.')
    parser.add_argument('-l', '--level', help='Classification level.', default='L1', choices=('L1', 'L2'))
    parser.add_argument('--linfit', help='Enable linfit in training set.', action='store_true')
    #parser.add_argument('--datalevel', help="", default='corr', choices=('raw', 'corr')) # TODO: Come up with better name than "datalevel"?
    #parser.add_argument('--starid', type=int, help='TIC identifier of target.', nargs='?', default=None)
    # Lightcurve truncate override switch.
    # The default of None (neither flag given) is passed on unchanged to the classifier:
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--truncate', dest='truncate', action='store_true', help='Force light curve truncation.')
    group.add_argument('--no-truncate', dest='truncate', action='store_false', help='Force no light curve truncation.')
    parser.set_defaults(truncate=None)
    # Data directory:
    parser.add_argument('--datadir', type=str, default=None, help='Directory where trained models and diagnostics will be loaded. Default is to load from the programs data directory.')
    # Input todo-file/directory:
    parser.add_argument('input_folder', type=str, nargs='?', default=None, help='Input directory to run classification on.')
    args = parser.parse_args()

    # Cache tables (MOAT) should not be cleared unless results tables are also cleared.
    # Otherwise we could end up with non-complete MOAT tables.
    if args.clear_cache and not args.overwrite:
        parser.error("--clear-cache can not be used without --overwrite")

    # Set logging level (--quiet takes precedence over --debug):
    logging_level = logging.INFO
    fmtstr = '%(asctime)s - %(levelname)s - %(message)s'
    if args.quiet:
        logging_level = logging.WARNING
    elif args.debug:
        logging_level = logging.DEBUG
        fmtstr = '%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s'

    # Setup logging:
    formatter = logging.Formatter(fmtstr)
    logger = logging.getLogger('starclass')
    # Only attach a handler if none is configured yet, to avoid duplicated messages:
    if not logger.hasHandlers():
        console = starclass.utilities.TqdmLoggingHandler()
        console.setFormatter(formatter)
        logger.addHandler(console)
    logger.setLevel(logging_level)

    # Settings for tqdm.
    # The progress bar is forced off when INFO messages are not shown,
    # otherwise (None) the decision is left to tqdm:
    tqdm_settings = {'disable': None if logger.isEnabledFor(logging.INFO) else True}

    # Get input from the command line, falling back to the environment variable:
    input_folder = args.input_folder
    if input_folder is None:
        input_folder = os.environ.get('STARCLASS_INPUT')
    if input_folder is None:
        parser.error("No input folder specified")
    if not os.path.exists(input_folder):
        parser.error("INPUT_FOLDER does not exist")

    # The input can either be a directory containing the TODO-file or the file itself:
    if os.path.isdir(input_folder):
        todo_file = os.path.join(input_folder, 'todo.sqlite')
    else:
        todo_file = os.path.abspath(input_folder)

    # Choose which classifier to use:
    # If nothing was specified, run all classifiers, and automatically switch between them:
    if args.classifier is None:
        current_classifier = starclass.classifier_list[0]
        change_classifier = True
    else:
        current_classifier = args.classifier
        change_classifier = False

    # Make sure we have turned plotting to non-interactive:
    starclass.plots.plots_noninteractive()

    # Initialize training set:
    tsetclass = starclass.get_trainingset(args.trainingset)
    tset = tsetclass(level=args.level, linfit=args.linfit)

    # Running:
    # When simply running the classifier on new stars:
    stcl = None
    try:
        with starclass.TaskManager(todo_file, overwrite=args.overwrite, classes=tset.StellarClasses,
            load_into_memory=args.in_memory) as tm:

            # If we were asked to do so, start by clearing the existing MOAT tables:
            if args.overwrite and args.clear_cache:
                tm.moat_clear()

            # Get number of tasks:
            numtasks = tm.get_number_tasks(classifier=args.classifier)
            logger.info("%d tasks to be run", numtasks)

            with tqdm(total=numtasks, **tqdm_settings) as pbar:
                while True:
                    tasks = tm.get_task(
                        classifier=current_classifier,
                        change_classifier=change_classifier,
                        chunk=args.chunks)
                    # Stop when there is nothing more to do. An empty chunk is treated
                    # like None, since indexing it below would otherwise fail:
                    if not tasks:
                        break
                    logger.debug(tasks)
                    tm.start_task(tasks)

                    # ----------------- This code would run on each worker ------------------------
                    if stcl is None or tasks[0]['classifier'] != current_classifier:
                        current_classifier = tasks[0]['classifier']
                        if stcl is not None:
                            stcl.close()
                            # Reset, so a failure while creating the next classifier
                            # does not lead to the old one being closed twice:
                            stcl = None
                        classifier_class = starclass.get_classifier(current_classifier)
                        stcl = classifier_class(tset=tset, features_cache=None, truncate_lightcurves=args.truncate, data_dir=args.datadir)

                    results = [stcl.classify(task) for task in tasks]
                    # ----------------- This code would run on each worker ------------------------

                    # Return to TaskManager to be saved:
                    tm.save_results(results)

                    # Update progressbar. The total is counted in tasks,
                    # so advance by the number of tasks in this chunk:
                    pbar.update(len(tasks))

            # Assign final classes:
            if args.classifier is None or args.classifier == 'meta':
                try:
                    tm.assign_final_class(tset, data_dir=args.datadir)
                except starclass.exceptions.DiagnosticsNotAvailableError:
                    logger.error("Could not assign final classes due to missing diagnostics information.")
    finally:
        # Always release the last classifier and the training set, also on errors:
        try:
            if stcl is not None:
                stcl.close()
        finally:
            tset.close()

    logger.info("Done.")
#--------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    # Run the command-line interface only when executed as a script, not when imported.
    main()