-
Notifications
You must be signed in to change notification settings - Fork 0
/
training.py
92 lines (80 loc) · 4.68 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import voc_reader
from functools import partial
import math
from keras.callbacks import ModelCheckpoint,CSVLogger,LearningRateScheduler,ReduceLROnPlateau,EarlyStopping,TensorBoard
from keras.models import load_model
from metrics import (dice_coefficient, dice_coefficient_loss)
def train_generator_data(voc_reader):
    """Endlessly yield training batches pulled from *voc_reader*.

    Each step calls ``voc_reader.next_train_batch()``, unpacks its result into
    exactly two values and yields them as an ``(inputs, targets)`` tuple. The
    generator never terminates on its own; the consumer decides how many
    batches to draw.
    """
    while True:
        batch_inputs, batch_targets = voc_reader.next_train_batch()
        yield batch_inputs, batch_targets
def val_generator_data(voc_reader):
    """Endlessly yield validation batches pulled from *voc_reader*.

    Each step calls ``voc_reader.next_val_batch()``, unpacks its result into
    exactly two values and yields them as an ``(inputs, targets)`` tuple. The
    generator never terminates on its own; the consumer decides how many
    batches to draw.
    """
    while True:
        batch_inputs, batch_targets = voc_reader.next_val_batch()
        yield batch_inputs, batch_targets
def step_decay(epoch, initial_lrate, drop, epochs_drop):
    """Return the step-decayed learning rate for *epoch*.

    The rate starts at ``initial_lrate`` and is multiplied by ``drop`` once
    every ``epochs_drop`` epochs::

        lr = initial_lrate * drop ** floor((1 + epoch) / epochs_drop)

    :param epoch: Epoch index; the ``1 +`` below treats it as zero-based.
    :param initial_lrate: Learning rate before any drop has happened.
    :param drop: Multiplicative factor applied at each drop (e.g. 0.5 halves it).
    :param epochs_drop: Number of epochs between consecutive drops.
    :return: The learning rate to use for *epoch* (a float).
    :raises ZeroDivisionError: If ``epochs_drop`` is 0.
    """
    # The previous version multiplied by the floor term directly and ignored
    # `drop`, yielding lr == 0 for the first epochs_drop-1 epochs and a
    # linearly *growing* rate afterwards. The floor term is the exponent.
    n_drops = math.floor((1 + epoch) / float(epochs_drop))
    return initial_lrate * math.pow(drop, n_drops)
def get_callbacks(model_file, initial_learning_rate=0.0001, learning_rate_drop=0.5, learning_rate_epochs=None,
                  learning_rate_patience=50, logging_file="training.log", verbosity=1, early_stopping_patience=None):
    """Build the list of Keras callbacks used during training.

    The list always starts with, in order:

    * ``ModelCheckpoint(model_file, save_best_only=True)``
    * ``CSVLogger(logging_file, append=True)``
    * ``TensorBoard()`` (default arguments)

    One learning-rate callback follows:

    * If ``learning_rate_epochs`` is truthy: a ``LearningRateScheduler`` driven
      by ``step_decay``. It uses ``initial_learning_rate``, ``learning_rate_drop``
      and ``learning_rate_epochs``.
    * Otherwise: ``ReduceLROnPlateau`` with ``factor=learning_rate_drop`` and
      ``patience=learning_rate_patience``.

    An ``EarlyStopping`` callback is appended last when
    ``early_stopping_patience`` is truthy.

    :param model_file: Checkpoint path handed to ``ModelCheckpoint``. An example
        filename pattern is weights.{epoch:02d}-{val_loss:.2f}.hdf5.
    :param initial_learning_rate: Starting rate; used only by the step schedule.
    :param learning_rate_drop: Decay factor for either learning-rate callback.
    :param learning_rate_epochs: Epochs between drops; enables the step schedule.
    :param learning_rate_patience: Patience for ``ReduceLROnPlateau``.
    :param logging_file: CSV log path (appended to, not overwritten).
    :param verbosity: ``verbose`` level for ``ReduceLROnPlateau``/``EarlyStopping``.
    :param early_stopping_patience: If set, patience for ``EarlyStopping``.
    :return: A list of Keras callback instances.
    """
    callbacks = [
        ModelCheckpoint(model_file, save_best_only=True),
        CSVLogger(logging_file, append=True),
        TensorBoard(),
    ]

    if not learning_rate_epochs:
        lr_callback = ReduceLROnPlateau(factor=learning_rate_drop, patience=learning_rate_patience,
                                        verbose=verbosity)
    else:
        # Bind everything except the epoch so Keras can call schedule(epoch).
        schedule = partial(step_decay,
                           initial_lrate=initial_learning_rate,
                           drop=learning_rate_drop,
                           epochs_drop=learning_rate_epochs)
        lr_callback = LearningRateScheduler(schedule)
    callbacks.append(lr_callback)

    if early_stopping_patience:
        callbacks.append(EarlyStopping(verbose=verbosity, patience=early_stopping_patience))
    return callbacks
# def load_old_model(model_file):
# print("Loading pre-trained model")
# custom_objects = {'dice_coefficient_loss': dice_coefficient_loss, 'dice_coefficient': dice_coefficient}
# try:
# from keras_contrib.layers import InstanceNormalization
# custom_objects["InstanceNormalization"] = InstanceNormalization
# except ImportError:
# pass
# try:
# return load_model(model_file, custom_objects=custom_objects)
# except ValueError as error:
# if 'InstanceNormalization' in str(error):
# raise ValueError(str(error) + "\n\nPlease install keras-contrib to use InstanceNormalization:\n"
# "'pip install git+https://www.github.com/keras-team/keras-contrib.git'")
# else:
# raise error
def train_model(model, model_file, training_generator, validation_generator, steps_per_epoch, validation_steps,
                initial_learning_rate=0.01, learning_rate_drop=0.5, learning_rate_epochs=None, n_epochs=80,
                learning_rate_patience=20, early_stopping_patience=None):
    """
    Train a Keras model from generators, using the callbacks built by ``get_callbacks``.

    :param model: Keras model that will be trained.
    :param model_file: Where to save the Keras model (passed to the checkpoint callback).
    :param training_generator: Generator that iterates through the training data.
    :param validation_generator: Generator that iterates through the validation data.
    :param steps_per_epoch: Number of batches that the training generator will provide during a given epoch.
    :param validation_steps: Number of batches that the validation generator will provide during a given epoch.
    :param initial_learning_rate: Learning rate at the beginning of training (used by the step schedule).
    :param learning_rate_drop: Multiplicative factor applied to the learning rate each time it is reduced.
    :param learning_rate_epochs: Number of epochs after which the learning rate will drop. If not set,
        the learning rate is instead reduced when the validation loss plateaus.
    :param n_epochs: Total number of epochs to train the model.
    :param learning_rate_patience: If learning_rate_epochs is not set, the learning rate will decrease if the
        validation loss does not improve after the specified number of epochs. (default is 20)
    :param early_stopping_patience: If set, training will end early if the validation loss does not improve
        after the specified number of epochs.
    :return: The object returned by ``model.fit_generator``. For Keras this is a ``History``
        whose ``.history`` dict records per-epoch losses and metrics. It was previously discarded.
    """
    callbacks = get_callbacks(model_file,
                              initial_learning_rate=initial_learning_rate,
                              learning_rate_drop=learning_rate_drop,
                              learning_rate_epochs=learning_rate_epochs,
                              learning_rate_patience=learning_rate_patience,
                              early_stopping_patience=early_stopping_patience)
    return model.fit_generator(generator=training_generator,
                               steps_per_epoch=steps_per_epoch,
                               epochs=n_epochs,
                               validation_data=validation_generator,
                               validation_steps=validation_steps,
                               callbacks=callbacks)