-
Notifications
You must be signed in to change notification settings - Fork 10
/
model.py
372 lines (301 loc) · 15 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
import tensorflow as tf
import numpy as np
from base_model import BaseModel
class Multi_Label_Class(BaseModel):
    """ Multi-label image classifier: a VGG16-style CNN encoder followed by an
    attention LSTM that emits one set of sigmoid label probabilities per step.

    Relies on attributes presumably provided by ``BaseModel`` (not visible
    here): ``self.config``, ``self.is_train``, ``self.nn``,
    ``self.image_shape`` and ``self.global_step``.
    """

    def build(self):
        """ Build the whole TensorFlow graph by calling each builder in turn.

        The optimizer and the TensorBoard summaries are only built when
        ``self.is_train`` is true: they consume the losses that ``build_rnn``
        creates only in training mode.
        """
        self.build_vgg16()
        self.build_rnn()
        if self.is_train:
            self.build_optimizer()
            self.build_summary()

    def build_vgg16(self):
        """ Build the VGG16 convolutional encoder.

        Creates the ``self.images`` placeholder (shape ``self.image_shape``)
        and stores the conv5_3 feature map, reshaped to
        ``[batch_size, num_ctx, dim_ctx]`` = ``[batch_size, 45, 512]``, in
        ``self.conv_feats`` as the context set for the attention RNN.
        """
        print("Building CNN model (vgg16)..........")
        config = self.config

        # Input
        images = tf.placeholder(
            dtype = tf.float32,
            shape = self.image_shape)  # image_shape is defined in base_model.py

        # Spatial size after each conv / pool layer:
        #   height_out = (height_in - height_f + 2*padding) / stride + 1
        #   width_out  = (width_in  - width_f  + 2*padding) / stride + 1
        conv1_1_feats = self.nn.conv2d(images, 64, name = 'conv1_1')
        conv1_2_feats = self.nn.conv2d(conv1_1_feats, 64, name = 'conv1_2')
        pool1_feats = self.nn.max_pool2d(conv1_2_feats, name = 'pool1')

        conv2_1_feats = self.nn.conv2d(pool1_feats, 128, name = 'conv2_1')
        conv2_2_feats = self.nn.conv2d(conv2_1_feats, 128, name = 'conv2_2')
        pool2_feats = self.nn.max_pool2d(conv2_2_feats, name = 'pool2')

        conv3_1_feats = self.nn.conv2d(pool2_feats, 256, name = 'conv3_1')
        conv3_2_feats = self.nn.conv2d(conv3_1_feats, 256, name = 'conv3_2')
        conv3_3_feats = self.nn.conv2d(conv3_2_feats, 256, name = 'conv3_3')
        pool3_feats = self.nn.max_pool2d(conv3_3_feats, name = 'pool3')

        conv4_1_feats = self.nn.conv2d(pool3_feats, 512, name = 'conv4_1')
        conv4_2_feats = self.nn.conv2d(conv4_1_feats, 512, name = 'conv4_2')
        conv4_3_feats = self.nn.conv2d(conv4_2_feats, 512, name = 'conv4_3')
        pool4_feats = self.nn.max_pool2d(conv4_3_feats, name = 'pool4')

        conv5_1_feats = self.nn.conv2d(pool4_feats, 512, name = 'conv5_1')
        conv5_2_feats = self.nn.conv2d(conv5_1_feats, 512, name = 'conv5_2')
        conv5_3_feats = self.nn.conv2d(conv5_2_feats, 512, name = 'conv5_3')

        # NOTE(review): num_ctx must equal conv5_3's height * width for
        # self.image_shape, otherwise the reshape below fails at graph build
        # time — TODO confirm against base_model.py.
        self.num_ctx = 45
        self.dim_ctx = 512  # channel count of conv5_3
        reshaped_conv5_3_feats = tf.reshape(
            conv5_3_feats,
            [config.batch_size, self.num_ctx, self.dim_ctx])

        self.conv_feats = reshaped_conv5_3_feats  # CNN output (fed to the RNN)
        self.images = images

    def build_rnn(self):
        """ Build the attention LSTM that decodes multi-label predictions.

        Runs ``config.max_class_label_length`` steps. At each step the LSTM
        input is the attended context plus the previous step's sigmoid
        probabilities and the accumulated hard-label vector. A decoder turns
        the LSTM output into ``config.label_index_length`` logits.

        Attributes set:
        * Always: ``self.labels`` (placeholder,
          ``[batch_size, 1, label_index_length]``) and ``self.contexts``.
        * Training mode: ``total_loss``, ``cross_entropy_loss``,
          ``attention_loss``, ``reg_loss`` and ``attentions``.
        * Inference mode: ``final_prob_predict``, ``final_result_max_idx``
          and ``final_result_max_value`` (the latter two stacked over steps).
        """
        print("Building the RNN Model..........")
        config = self.config

        # Setup the placeholders
        contexts = self.conv_feats  # CNN output, [batch_size, num_ctx, dim_ctx]
        self.labels = tf.placeholder(
            dtype = tf.float32,
            shape = [config.batch_size,
                     1,
                     config.label_index_length])

        # LSTM cell; in training, dropout on its inputs, outputs and state.
        lstm = tf.nn.rnn_cell.LSTMCell(
            config.num_lstm_units,
            initializer = self.nn.fc_kernel_initializer)
        if self.is_train:
            keep_prob = 1.0 - config.lstm_drop_rate
            lstm = tf.nn.rnn_cell.DropoutWrapper(
                lstm,
                input_keep_prob = keep_prob,
                output_keep_prob = keep_prob,
                state_keep_prob = keep_prob)

        # Initialize the LSTM state (C, h) from the mean over all contexts.
        with tf.variable_scope("initialize"):
            context_mean = tf.reduce_mean(contexts, axis = 1)
            initial_memory, initial_output = self.initialize(context_mean)

        # Prepare to run the model
        num_steps = config.max_class_label_length
        probability = tf.zeros([config.batch_size, config.label_index_length], tf.float32)
        hard_label = tf.zeros([config.batch_size, config.label_index_length], tf.float32)
        last_output = initial_output
        last_state = initial_memory, initial_output  # (C, h)
        result_max_idx = []
        result_max_value = []
        if self.is_train:
            # The targets do not change between steps: flatten them once to
            # the logits' shape [batch_size, label_index_length].
            label_reshape = tf.reshape(
                self.labels,
                [config.batch_size, config.label_index_length])
            alphas = []  # per-step attention weights, used by the attention loss
            cross_entropies = []

        # Unroll the LSTM for max_class_label_length steps.
        for _ in range(num_steps):
            # Attention mechanism
            with tf.variable_scope("attend"):
                alpha = self.attend(contexts, last_output)  # softmax over num_ctx
                context = tf.reduce_sum(contexts * tf.expand_dims(alpha, 2),
                                        axis = 1)
                if self.is_train:
                    alphas.append(tf.reshape(alpha, [-1]))

            # One LSTM step on the attended context and previous predictions.
            with tf.variable_scope("lstm"):
                current_input = tf.concat([context,
                                           probability,
                                           hard_label], 1)
                output, state = lstm(current_input, last_state)  # state = (C, h)

            # Decode the expanded output of the LSTM into label logits.
            with tf.variable_scope("decode"):
                expanded_output = tf.concat([output,
                                             context,
                                             probability,
                                             hard_label],
                                            axis = 1)
                logits = self.decode(expanded_output)
                probs = tf.nn.sigmoid(logits)  # fed back as the next step's input

                # Hard labels: subtracting hard_label pushes already-selected
                # labels (value 1) down by 1, so the row maximum is normally
                # the strongest not-yet-selected label. Every entry equal to
                # the row maximum (ties included) is incremented.
                compare_probs = tf.subtract(probs, hard_label)
                max_value = tf.reduce_max(compare_probs, axis = 1)
                max_id = tf.argmax(compare_probs, axis = 1)
                hard_label = tf.add(hard_label,
                                    tf.cast(
                                        tf.equal(compare_probs,
                                                 tf.expand_dims(max_value, 1)),
                                        tf.float32))
                if not self.is_train:
                    result_max_value.append(max_value)
                    result_max_idx.append(max_id)

            # Compute the loss for this step, if necessary.
            if self.is_train:
                cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
                    labels = label_reshape,
                    logits = logits)
                cross_entropies.append(cross_entropy)  # loss of each step

            # Next step input
            probability = probs
            last_output = output
            last_state = state
            # Later steps share the variables created by the first step.
            tf.get_variable_scope().reuse_variables()

        if not self.is_train:
            self.final_prob_predict = tf.identity(probs, name='final_prob_predict')
            self.final_result_max_idx = tf.stack(result_max_idx)
            self.final_result_max_value = tf.stack(result_max_value)

        # Compute the final loss in the training process.
        if self.is_train:
            cross_entropies = tf.stack(cross_entropies, axis = 1)
            cross_entropy_loss = tf.reduce_mean(cross_entropies)

            # Attention regularizer: sum each position's attention over all
            # steps, then penalize the squared deviation of that sum from 1.
            alphas = tf.stack(alphas, axis = 1)
            alphas = tf.reshape(alphas, [config.batch_size, self.num_ctx, -1])
            attentions = tf.reduce_sum(alphas, axis = 2)
            diffs = tf.ones_like(attentions) - attentions
            attention_loss = (config.attention_loss_factor *
                              tf.nn.l2_loss(diffs) /
                              (config.batch_size * self.num_ctx))

            reg_loss = tf.losses.get_regularization_loss()

            total_loss = cross_entropy_loss + attention_loss + reg_loss

            self.total_loss = total_loss
            self.cross_entropy_loss = cross_entropy_loss
            self.attention_loss = attention_loss
            self.reg_loss = reg_loss
            self.attentions = attentions

        self.contexts = contexts
        print("RNN built...........................")

    def initialize(self, context_mean):
        """ Compute the initial LSTM state from the mean context.

        Two independent 2-layer dense branches (tanh hidden layer of
        ``config.dim_initalize_layer`` units, linear output of
        ``config.num_lstm_units`` units) produce the initial memory (C) and
        output (h). Returns ``(memory, output)``.
        """
        config = self.config
        context_mean = self.nn.dropout(context_mean)

        # Branch a: initial memory C
        temp1 = self.nn.dense(context_mean,
                              units = config.dim_initalize_layer,
                              activation = tf.tanh,
                              name = 'fc_a1')
        temp1 = self.nn.dropout(temp1)
        memory = self.nn.dense(temp1,
                               units = config.num_lstm_units,
                               activation = None,
                               name = 'fc_a2')

        # Branch b: initial output h
        temp2 = self.nn.dense(context_mean,
                              units = config.dim_initalize_layer,
                              activation = tf.tanh,
                              name = 'fc_b1')
        temp2 = self.nn.dropout(temp2)
        output = self.nn.dense(temp2,
                               units = config.num_lstm_units,
                               activation = None,
                               name = 'fc_b2')
        return memory, output

    def attend(self, contexts, output):
        """ Additive attention over the CNN contexts.

        Scores each of the ``self.num_ctx`` context vectors against the
        previous LSTM output using two tanh dense projections and a bias-free
        linear scorer. Returns ``alpha`` with shape ``[-1, self.num_ctx]``,
        softmax-normalized over the contexts.
        """
        config = self.config
        reshaped_contexts = tf.reshape(contexts, [-1, self.dim_ctx])
        reshaped_contexts = self.nn.dropout(reshaped_contexts)
        output = self.nn.dropout(output)

        # use 2 fc layers to attend
        temp1 = self.nn.dense(reshaped_contexts,
                              units = config.dim_attend_layer,
                              activation = tf.tanh,
                              name = 'fc_1a')
        temp2 = self.nn.dense(output,
                              units = config.dim_attend_layer,
                              activation = tf.tanh,
                              name = 'fc_1b')
        # Broadcast the output projection to every context position.
        temp2 = tf.tile(tf.expand_dims(temp2, 1), [1, self.num_ctx, 1])
        temp2 = tf.reshape(temp2, [-1, config.dim_attend_layer])
        temp = temp1 + temp2
        temp = self.nn.dropout(temp)
        logits = self.nn.dense(temp,
                               units = 1,
                               activation = None,
                               use_bias = False,
                               name = 'fc_2')
        logits = tf.reshape(logits, [-1, self.num_ctx])
        alpha = tf.nn.softmax(logits)
        return alpha

    def decode(self, expanded_output):
        """ Map the expanded LSTM output to label logits.

        Applies dropout, a tanh dense layer of ``config.dim_decode_layer``
        units, dropout again, then a linear layer of
        ``config.label_index_length`` units. Returns the (pre-sigmoid) logits.
        """
        config = self.config
        expanded_output = self.nn.dropout(expanded_output)

        # use 2 fc layers to decode
        temp = self.nn.dense(expanded_output,
                             units = config.dim_decode_layer,
                             activation = tf.tanh,
                             name = 'fc_1')
        temp = self.nn.dropout(temp)
        logits = self.nn.dense(temp,
                               units = config.label_index_length,
                               activation = None,
                               name = 'fc_2')
        return logits

    def build_optimizer(self):
        """ Set up the optimizer and the training operation ``self.opt_op``.

        ``tf.contrib.layers.optimize_loss`` applies ``learning_rate_decay_fn``
        to ``learning_rate`` and hands the resulting rate to the optimizer
        only when ``optimizer`` is a name, a class or a callable. An
        already-constructed Optimizer instance is used unchanged, which
        silently discarded the exponential decay schedule. The optimizer is
        therefore passed as a factory that receives the (possibly decayed)
        learning-rate tensor.
        """
        config = self.config

        learning_rate = tf.constant(config.initial_learning_rate)
        if config.learning_rate_decay_factor < 1.0:
            def _learning_rate_decay_fn(learning_rate, global_step):
                return tf.train.exponential_decay(
                    learning_rate,
                    global_step,
                    decay_steps = config.num_steps_per_decay,
                    decay_rate = config.learning_rate_decay_factor,
                    staircase = True)
            learning_rate_decay_fn = _learning_rate_decay_fn
        else:
            learning_rate_decay_fn = None

        def _make_optimizer(lr):
            # Unknown optimizer names fall back to plain gradient descent.
            if config.optimizer == 'Adam':
                return tf.train.AdamOptimizer(
                    learning_rate = lr,
                    beta1 = config.beta1,
                    beta2 = config.beta2,
                    epsilon = config.epsilon)
            if config.optimizer == 'RMSProp':
                return tf.train.RMSPropOptimizer(
                    learning_rate = lr,
                    decay = config.decay,
                    momentum = config.momentum,
                    centered = config.centered,
                    epsilon = config.epsilon)
            if config.optimizer == 'Momentum':
                return tf.train.MomentumOptimizer(
                    learning_rate = lr,
                    momentum = config.momentum,
                    use_nesterov = config.use_nesterov)
            return tf.train.GradientDescentOptimizer(learning_rate = lr)

        with tf.variable_scope('optimizer', reuse = tf.AUTO_REUSE):
            opt_op = tf.contrib.layers.optimize_loss(
                loss = self.total_loss,
                global_step = self.global_step,
                learning_rate = learning_rate,
                optimizer = _make_optimizer,
                clip_gradients = config.clip_gradients,
                learning_rate_decay_fn = learning_rate_decay_fn)

        self.opt_op = opt_op

    def build_summary(self):
        """ Build the summaries (for TensorBoard visualization).

        Records statistics for every trainable variable, the training loss
        scalars and the summed attention map, then merges all summaries
        into ``self.summary``. Requires the training-mode attributes set by
        ``build_rnn``.
        """
        with tf.name_scope("variables"):
            for var in tf.trainable_variables():
                # Scope by the variable name without its ":0" output suffix.
                with tf.name_scope(var.name[:var.name.find(":")]):
                    self.variable_summary(var)

        with tf.name_scope("metrics"):
            tf.summary.scalar("cross_entropy_loss", self.cross_entropy_loss)
            tf.summary.scalar("attention_loss", self.attention_loss)
            tf.summary.scalar("reg_loss", self.reg_loss)
            tf.summary.scalar("total_loss", self.total_loss)

        with tf.name_scope("attentions"):
            self.variable_summary(self.attentions)

        self.summary = tf.summary.merge_all()

    def variable_summary(self, var):
        """ Add mean, stddev, max, min scalars and a histogram for ``var``. """
        mean = tf.reduce_mean(var)
        tf.summary.scalar('mean', mean)
        stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        tf.summary.scalar('stddev', stddev)
        tf.summary.scalar('max', tf.reduce_max(var))
        tf.summary.scalar('min', tf.reduce_min(var))
        tf.summary.histogram('histogram', var)