# Cardiomegaly_model_ResNet.py
# Fine-tunes a ResNet50V2 image classifier for cardiomegaly detection.
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras import regularizers
from PIL import ImageFile
import pickle
import os
import logging
import pickle
import mlflow
from tensorflow.keras.callbacks import EarlyStopping
# Allow PIL to load partially-written/truncated image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Root-logger setup so INFO-level training progress is visible on stderr.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Training hyperparameters.
LR = 6e-4  # Adam learning rate
BATCH_SIZE = 32
EPOCHS = 50
IMG_SIZE = 244 # Updated image size  # NOTE(review): 244 differs from ResNet's canonical 224 — confirm this is intentional, not a typo
def get_train_generator():
    """Build the augmented training-data iterator.

    Applies per-sample centering, 1/255 rescaling, and random
    augmentations (shifts, shear, zoom, brightness, horizontal flip)
    to images read from the training directory.
    """
    augmenter = ImageDataGenerator(
        samplewise_center=True,
        rescale=1.0 / 255,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.1,
        zoom_range=0.15,
        brightness_range=[0.5, 1.5],
        horizontal_flip=True,
        fill_mode='reflect',
    )
    # Yields (images, one-hot labels) batches; class names come from
    # the subdirectory structure.
    return augmenter.flow_from_directory(
        "Cardiomegaly_detection_dataset/train/",
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
    )
def get_valid_generator():
    """Build the validation-data iterator.

    Only rescales pixel values to [0, 1]; no augmentation, so
    validation metrics reflect the unmodified images.
    """
    rescaler = ImageDataGenerator(rescale=1.0 / 255)
    return rescaler.flow_from_directory(
        "Cardiomegaly_detection_dataset/valid/",
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
    )
def train():
    """Fine-tune a ResNet50V2 classifier for cardiomegaly detection.

    Loads ImageNet weights, freezes the backbone except its last 15
    layers, trains a small dense classification head, then saves the
    model and the class-index mapping under ./models/.
    """
    logger.info("Training Model.")
    resnet_body = tf.keras.applications.ResNet50V2(
        weights="imagenet",
        include_top=False,
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
    )
    # Freeze the whole backbone, then re-enable only the last layers
    # so fine-tuning adapts high-level features without destroying the
    # pretrained low-level filters.
    resnet_body.trainable = False
    unfreeze_layers = 15  # number of trailing layers to fine-tune
    for layer in resnet_body.layers[-unfreeze_layers:]:
        layer.trainable = True
    resnet_model = tf.keras.Sequential([
        resnet_body,
        GlobalAveragePooling2D(),
        Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.001)),
        Dropout(0.5),  # regularize the head against overfitting
        Dense(2, activation='softmax'),
    ])
    resnet_model.compile(
        optimizer=tf.optimizers.Adam(learning_rate=LR),
        loss=tf.losses.categorical_crossentropy,
        metrics=["accuracy"],
    )
    train_generator = get_train_generator()
    valid_generator = get_valid_generator()
    # Model.summary() prints to stdout and returns None, so the original
    # logging.info(model.summary()) logged literally "None". Route the
    # summary text through the logger instead.
    resnet_body.summary(print_fn=logger.info)
    resnet_model.summary(print_fn=logger.info)
    mlflow.autolog()
    resnet_model.fit(
        train_generator, epochs=EPOCHS, validation_data=valid_generator
    )
    # Mapping of class name -> integer index, needed at inference time.
    labels = train_generator.class_indices
    logger.info("Dump models.")
    os.makedirs("./models", exist_ok=True)  # ensure output dir exists
    resnet_model.save("./models/Cardiomegaly_resnet/1")
    with open("./models/Cardiomegaly_labels.pickle", "wb") as handle:
        pickle.dump(labels, handle)
    logger.info("Finished training.")
# Script entry point: run training only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    train()