-
Notifications
You must be signed in to change notification settings - Fork 1
/
Training.py
130 lines (115 loc) · 6.13 KB
/
Training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
Author:
Jay Lago, NIWC/SDSU, 2021
"""
import os
import pickle
import time
import datetime as dt
import numpy as np
import tensorflow as tf
import HelperFuns as hf
def train_model(hyp_params, train_data, val_set, model, loss):
# Dictionary to store all relevant training parameters and losses
train_params = dict()
train_params['start_time'] = time.time()
train_params['train_loss_results'] = []
train_params['val_loss_results'] = []
train_params['val_loss_comps_avgs'] = []
# Variables to control adaptive number of observables routine below
max_win_stp = 1
num_time_steps = int(hyp_params['num_time_steps'])
# Set optimizer
if hyp_params['optimizer'] == 'adam':
myoptimizer = tf.keras.optimizers.Adam(hyp_params['lr'])
if hyp_params['optimizer'] == 'sgd':
myoptimizer = tf.keras.optimizers.SGD(learning_rate=hyp_params['lr'], momentum=0.9)
# Begin primary training loop
for epoch in range(1, hyp_params['max_epochs'] + 1):
epoch_start_time = dt.datetime.now()
epoch_time = time.time()
epoch_loss_avg_train = tf.keras.metrics.Mean()
# Shuffle, batch, and prefetch training data to the GPU
train_set = train_data.shuffle(hyp_params['num_train_init_conds']) \
.batch(hyp_params['batch_size'], drop_remainder=True)
try:
train_set = train_set.prefetch(tf.data.AUTOTUNE)
except:
train_set = train_set.prefetch(tf.data.experimental.AUTOTUNE)
# Begin batch training
with tf.device(hyp_params['device']):
for train_batch in train_set:
with tf.GradientTape() as tape:
train_pred = model(train_batch, training=True)
train_loss = loss(train_pred, train_batch)
gradients = tape.gradient(train_loss, model.trainable_weights)
myoptimizer.apply_gradients([(grad, var) for (grad, var) in zip(gradients, model.trainable_weights)
if grad is not None])
myoptimizer.apply_gradients(zip(gradients, model.trainable_weights))
epoch_loss_avg_train.update_state(train_loss)
# Batch validation
init_num_obsvs = model.num_observables
min_loss = 1e6
lrecon = tf.keras.metrics.Mean()
lpred = tf.keras.metrics.Mean()
ldmd = tf.keras.metrics.Mean()
# Use batch validation to determine optimal number of observables for this epoch
for num_obsvs in range(init_num_obsvs - max_win_stp, init_num_obsvs + max_win_stp + 1):
model.num_observables = num_obsvs
model.window = num_time_steps - (num_obsvs - 1)
loss.num_observables = num_obsvs
loss.window = num_time_steps - (num_obsvs - 1)
epoch_loss_avg_val = tf.keras.metrics.Mean()
for val_batch in val_set:
val_pred = model(val_batch)
val_loss = loss(val_pred, val_batch)
epoch_loss_avg_val.update_state(val_loss)
avg_val_loss = epoch_loss_avg_val.result()
if avg_val_loss.numpy() < min_loss:
min_loss = avg_val_loss.numpy()
num_obsvs_opt = num_obsvs
min_avg_val_loss = avg_val_loss
model.num_observables = num_obsvs_opt
model.window = num_time_steps - (num_obsvs_opt - 1)
loss.num_observables = num_obsvs_opt
loss.window = num_time_steps - (num_obsvs_opt - 1)
# Save loss components for diagnostic plotting. Note, needs to be tweaked due to changing
# number of observables, but will train regardless.
lrecon.update_state(np.log10(loss.loss_recon))
lpred.update_state(np.log10(loss.loss_pred))
ldmd.update_state(np.log10(loss.loss_dmd))
train_params['val_loss_comps_avgs'].append([lrecon.result(), lpred.result(), ldmd.result()])
# Report training status
train_params['train_loss_results'].append(np.ma.log10(epoch_loss_avg_train.result()))
train_params['val_loss_results'].append(np.ma.log10(min_avg_val_loss))
print("Epoch {epoch} of {max_epoch} / Train {train:3.7f} / Val {test:3.7f} / LR {lr:2.7f} / {time:4.2f} seconds"
.format(epoch=epoch, max_epoch=hyp_params['max_epochs'],
train=train_params['train_loss_results'][-1],
test=train_params['val_loss_results'][-1],
lr=hyp_params['lr'],
time=time.time() - epoch_time))
# Save training diagnostic plots
if epoch == 1 or epoch % hyp_params['plot_every'] == 0:
if not os.path.exists(hyp_params['plot_path']):
os.makedirs(hyp_params['plot_path'])
this_plot = hyp_params['plot_path'] + '/' + epoch_start_time.strftime("%Y%m%d%H%M%S") + '.png'
hf.diagnostic_plot(val_pred, val_batch, hyp_params, epoch,
this_plot, train_params['val_loss_comps_avgs'],
train_params['val_loss_results'])
# Save model
if epoch % hyp_params['save_every'] == 0 or epoch == hyp_params['max_epochs']:
if not os.path.exists(hyp_params['model_path']):
os.makedirs(hyp_params['model_path'])
model_path = hyp_params['model_path'] + '/epoch_{epoch}_loss_{loss:2.3}' \
.format(epoch=epoch, loss=train_params['val_loss_results'][-1])
model.save_weights(model_path + '.h5')
pickle.dump(hyp_params, open(model_path + '.pkl', 'wb'))
print("\nTotal training time: %4.2f minutes" % ((time.time() - train_params['start_time']) / 60.0))
print("Final train loss: %2.7f" % (train_params['train_loss_results'][-1]))
print("Final validation loss: %2.7f" % (train_params['val_loss_results'][-1]))
results = dict()
results['model'] = model
results['loss'] = loss
results['val_loss_history'] = train_params['val_loss_results']
results['val_loss_comps'] = train_params['val_loss_comps_avgs']
return results