-
Notifications
You must be signed in to change notification settings - Fork 35
/
train_models.py
72 lines (55 loc) · 2.08 KB
/
train_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
## train_models.py -- train the neural network models for attacking
##
## Copyright (C) 2016, Nicholas Carlini <nicholas@carlini.com>.
##
## This program is licenced under the BSD 2-Clause licence,
## contained in the LICENCE file in this directory.
## Modified for the needs of MagNet.
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.optimizers import SGD
import tensorflow as tf
from setup_mnist import MNIST
import os
def train(data, file_name, params, num_epochs=50, batch_size=128):
    """
    Standard neural network training procedure.

    Builds a Sequential network made of two groups of two 3x3 ReLU
    convolutions, each group followed by 2x2 max pooling, then two ReLU
    dense layers (with dropout of 0.5 after the first one) and a final
    10-unit dense layer with no activation.  The network therefore
    outputs logits, and the loss applies the softmax itself.

    Args:
        data: Dataset object exposing ``train_data``, ``train_labels``,
            ``validation_data`` and ``validation_labels``.  The input
            shape of the network is taken from ``train_data.shape[1:]``.
        file_name: Path handed to ``model.save``.  When ``None`` the
            trained model is not written to disk.
        params: Indexable of six layer sizes: entries 0-3 are the filter
            counts of the four convolutions (in order), entries 4-5 are
            the unit counts of the two hidden dense layers.
        num_epochs: Number of epochs passed to ``model.fit``.
        batch_size: Mini-batch size passed to ``model.fit``.

    Returns:
        The trained Keras ``Sequential`` model.
    """
    model = Sequential()

    # First convolutional group; only the first layer needs the input shape.
    model.add(Conv2D(params[0], (3, 3), input_shape=data.train_data.shape[1:]))
    model.add(Activation('relu'))
    model.add(Conv2D(params[1], (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    # Second convolutional group.
    model.add(Conv2D(params[2], (3, 3)))
    model.add(Activation('relu'))
    model.add(Conv2D(params[3], (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    # Fully connected head.  The last layer has no softmax on purpose:
    # the loss below expects raw logits.
    model.add(Flatten())
    model.add(Dense(params[4]))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(params[5]))
    model.add(Activation('relu'))
    model.add(Dense(10))

    def fn(correct, predicted):
        # Softmax cross-entropy computed directly on the logits.
        return tf.nn.softmax_cross_entropy_with_logits(labels=correct,
                                                       logits=predicted)

    sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)

    model.compile(loss=fn,
                  optimizer=sgd,
                  metrics=['accuracy'])

    # Keras 2 (whose layer API, e.g. Conv2D, is used above) names this
    # argument `epochs`; the Keras 1 spelling `nb_epoch` is deprecated
    # and rejected by later releases.
    model.fit(data.train_data, data.train_labels,
              batch_size=batch_size,
              validation_data=(data.validation_data, data.validation_labels),
              epochs=num_epochs,
              shuffle=True)

    if file_name is not None:
        model.save(file_name)

    return model
# Make sure the output directory exists before the trained model is saved.
if not os.path.isdir('models'):
    os.makedirs('models')

# Train the example MNIST classifier and store it under models/.
train(
    MNIST(),
    "models/example_classifier",
    params=[32, 32, 64, 64, 200, 200],
    num_epochs=50,
)