#!/usr/bin/env python3
#
# Copyright (c) 2017-18 Jonathan Weyn <jweyn@uw.edu>
#
# See the file LICENSE for your rights.
#
"""
Trains and tests an ensemble selection model using predictors generated by ens_sel_batch_process.py. Implements an
'online learning' scheme whereby chunks of the data are loaded dynamically and training occurs on these individual
chunks. Uses Keras fit_generator() method to do so.
"""
from ensemble_net.util import save_model, AdamLearningRateTracker
from ensemble_net.ensemble_selection.model import EnsembleSelector, DataGenerator
import numpy as np
import time
import xarray as xr
import os
import random
from shutil import copyfile
from keras.callbacks import TerminateOnNaN
#%% User parameters
# Paths to important files
root_data_dir = '/home/disk/wave2/jweyn/Data/ensemble-net'
predictor_file = '%s/predictors_201504-201703_28N40N100W78W_x4_no_c_fss.nc' % root_data_dir
model_file = '%s/selector_ncar_2yr_fss_conv' % root_data_dir
result_file = '%s/result_ncar_2yr_fss_conv.nc' % root_data_dir
convolved = False
# Copy file to scratch space
copy_file_to_scratch = True
# Tell the data generator which types of variables to use
obs_errors = 'none'
radar_fss = 'both'
# Optionally predict for only a subset of variables. Must be a list of integer indices, or 'all'
variables = 'all'
# Neural network configuration and options
batch_size = 8 # in model init dates
scaler_fit_size = 200
epochs = 50
impute_missing = True
scale_targets = False
val = 'random'
val_size = 71
# Use multiple GPUs
n_gpu = 1
# Seed the random validation set generator
random.seed(0)
# Print some results at the end
print_results = True
#%% End user configuration
# Parameter checks
if variables == 'all' or variables is None:
    ens_sel = {}
else:
    if not isinstance(variables, (list, tuple)):
        try:
            variables = [int(variables)]
        except (TypeError, ValueError):
            raise TypeError("'variables' must be a list of integers or 'all'")
    else:
        try:
            variables = [int(v) for v in variables]
        except (TypeError, ValueError):
            raise TypeError("indices in 'variables' must be integer types")
    ens_sel = {'obs_var': variables}
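# For example, variables = '3' normalizes to [3] and variables = (0, 2) to
# [0, 2]; ens_sel then restricts the dataset to those obs_var indices below.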
#%% Copy file; do the initial loading and assessing
# Copy the file to scratch if requested; fall back to reading in place when not
# running under SLURM (no job ID means no scratch directory is available)
try:
    job_id = os.environ['SLURM_JOB_ID']
except KeyError:
    copy_file_to_scratch = False
if copy_file_to_scratch:
    predictor_file_name = predictor_file.split('/')[-1]
    # Resulting path: /scratch/<user>/<job id>/<predictor file name>
    scratch_file = '/scratch/%s/%s/%s' % (os.environ['USER'], job_id, predictor_file_name)
    print('Copying predictor file to scratch space...')
    copyfile(predictor_file, scratch_file)
    predictor_file = scratch_file
# Load a Dataset with the predictors
print('Opening predictor dataset %s...' % predictor_file)
predictor_ds = xr.open_dataset(predictor_file, mask_and_scale=True)
num_dates = predictor_ds.dims['init_date']
num_members = predictor_ds.dims['member']
num_stations = predictor_ds.dims['station']
# Remove init dates that have any missing (NaN) FSS values
if radar_fss is not None:
    fss_pred = predictor_ds['FSS_PRED'].values
    missing = set(np.where(np.sum(np.isnan(fss_pred.reshape(num_dates, -1)), axis=-1) > 0)[0])
    valid_index = [i for i in range(num_dates) if i not in missing]
    num_dates = len(valid_index)
    predictor_ds = predictor_ds.isel(init_date=valid_index)
# Select the requested observation variables (integer positions, hence isel)
predictor_ds = predictor_ds.isel(**ens_sel)
num_variables = predictor_ds.dims['obs_var']
#%% Get indices for the training and validation sets
if val == 'first':
val_set = list(range(0, val_size))
train_set = list(range(val_size, num_dates))
elif val == 'last':
val_set = list(range(num_dates - val_size, num_dates))
train_set = list(range(0, num_dates - val_size))
elif val == 'random':
train_set = list(range(num_dates))
val_set = []
for j in range(val_size):
i = random.choice(train_set)
val_set.append(i)
train_set.remove(i)
val_set.sort()
else:
raise ValueError("'val' must be 'first', 'last', or 'random'")
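# Because random.seed(0) is set above, the 'random' option draws the same
# validation dates on every run, keeping experiments reproducible.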
#%% Create an EnsembleSelector and Generator. The selector and generator are intertwined for scaling and imputing.
print('Building an EnsembleSelector model...')
selector = EnsembleSelector(impute_missing=impute_missing, scale_targets=scale_targets)
# Make a DataGenerator for training
generator = DataGenerator(selector, predictor_ds.isel(init_date=train_set), batch_size,
convolved=convolved, obs_errors=obs_errors, radar_fss=radar_fss)
# Make a DataGenerator for validation
val_generator = DataGenerator(selector, predictor_ds.isel(init_date=val_set), batch_size,
convolved=convolved, obs_errors=obs_errors, radar_fss=radar_fss)
# Initialize the model's Imputer and Scaler with a larger set of data
print('Fitting the EnsembleSelector Imputer and Scaler...')
fit_set = train_set[:scaler_fit_size]
predictors, targets = generator.generate_data(fit_set, scale_and_impute=False)
input_shape = predictors.shape[1:]
num_outputs = targets.shape[1]
conv_shape = generator.spatial_shape
selector.init_fit(predictors, targets)
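# Release the large fitting arrays so their memory can be reclaimed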
predictors = None
targets = None
# Load the validation set, which will now also be scaled
print('Processing validation set...')
p_val, t_val = val_generator.generate_data([])
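# Note: an empty sample list appears to make generate_data() return the
# generator's entire dataset (here, the full validation set), already scaled.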
#%% Build and train the ensemble selection model
layers = (
('PartialConv2D', (16, 5), {
'strides': 3,
'conv_size': conv_shape,
'conv_first': True,
'activation': 'relu',
'input_shape': input_shape
}),
('Dense', (1024,), {
'activation': 'relu',
# 'input_shape': input_shape
}),
('Dropout', (0.25,), {}),
# ('Dense', (2*num_outputs,), {
# 'activation': 'relu'
# }),
# ('Dropout', (0.25,), {}),
('Dense', (num_outputs,), {
'activation': 'linear'
})
)
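# Each entry above is a (layer class name, positional args, keyword args)
# triple consumed by build_model(); e.g. ('Dense', (1024,), {'activation':
# 'relu'}) presumably maps to Dense(1024, activation='relu') in Keras.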
selector.build_model(layers=layers, gpus=n_gpu, loss='mse', optimizer='adam', metrics=['mae'])
# Train and evaluate the model
print('Training the EnsembleSelector model...')
start_time = time.time()
history = selector.fit_generator(generator, epochs=epochs, verbose=1, validation_data=(p_val, t_val),
use_multiprocessing=True, callbacks=[TerminateOnNaN(), AdamLearningRateTracker()])
end_time = time.time()
# Use model.evaluate() because p_val and t_val are already scaled
score = selector.model.evaluate(p_val, t_val, verbose=0)
print("\nTrain time -- %s seconds --" % (end_time - start_time))
print('Test loss:', score[0])
print('Test mean absolute error:', score[1])
# Save the model, if requested
if model_file is not None:
print('Saving model to disk...')
save_model(selector, model_file, history=history)
#%% Process the results
predicted = selector.predict(p_val)
# Reshape the prediction and the targets to meaningful dimensions
new_target_shape = (val_size, num_members, num_stations, num_variables)
predicted = predicted.reshape(new_target_shape)
if scale_targets:
t_test = selector.scaler_y.inverse_transform(t_val).reshape(new_target_shape)
else:
t_test = t_val.reshape(new_target_shape)
# Create a Dataset for the results
result = xr.Dataset(
coords={
'time': predictor_ds['init_date'].isel(init_date=val_set),
'member': predictor_ds.member,
        # obs_var was already subset above, so no further selection is needed
        'variable': predictor_ds.obs_var,
'station': range(num_stations)
}
)
result['prediction'] = (('time', 'member', 'station', 'variable'), predicted)
result['target'] = (('time', 'member', 'station', 'variable'), t_test)
result.to_netcdf(result_file)
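# The 'print_results' flag above is otherwise unused; a minimal sketch of a
# summary it could guard (an assumption, not part of the original output):
if print_results:
    print('Mean absolute error by variable (validation set):')
    print(np.nanmean(np.abs(predicted - t_test), axis=(0, 1, 2)))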