-
Notifications
You must be signed in to change notification settings - Fork 0
/
deepNeuralNetwork.py
100 lines (91 loc) · 4.08 KB
/
deepNeuralNetwork.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# ------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow import keras
import numpy as np
#from tensorflow.examples.tutorials.mnist import input_data
# ------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
class dnNetwork():
    """ Deep neural network with two hidden layers, built on the Keras
    functional API.

    Construct with ``dnNetwork(layers=[n_in, n_h1, n_h2, n_out], saveFile=...)``
    where ``layers`` lists the sizes of the input layer, the two hidden
    layers and the output layer, in that order.
    """
    def __init__(self, *args, **kwds):
        """ Build and compile the model.

        Keyword args:
            layers   -- list of layer sizes; first item is the input layer
                        size, last item is the output layer size.
            saveFile -- optional path used by saveModel/loadModel
                        (defaults to './my_model').
        """
        #currently the code is only for 2 hidden layers, apart from in and out
        # saveFile was previously stored but never used; it now drives
        # saveModel/loadModel, defaulting to the old hard-coded path.
        self._saveFile = kwds.get('saveFile') or './my_model'
        self._layerSizes = kwds.get('layers', [])
        self._layer1 = keras.layers.Dense(self._layerSizes[1],activation='relu')
        self._layer2 = keras.layers.Dense(self._layerSizes[2],activation='relu')
        self._outLayer = keras.layers.Dense(self._layerSizes[-1],activation='softmax')
        self._inputs = keras.Input(shape=(self._layerSizes[0],)) #symbolic input tensor
        x = self._layer1(self._inputs)
        x = self._layer2(x)
        self._outputs = self._outLayer(x)
        self._model = keras.Model(inputs=self._inputs,outputs=self._outputs)
        # keras.optimizers.Adam replaces the TF1-only tf.train.AdamOptimizer;
        # it works with tf.keras under both TF 1.x and TF 2.x.
        self._model.compile(optimizer=keras.optimizers.Adam(0.001),
                            loss=self.loss,
                            metrics=['accuracy'])
    def loss(self,yTrue,yPred):
        """ AlphaZero-style loss: squared error on the last element (value
        head) minus the dot product of the remaining targets with the log
        predictions (cross-entropy on the policy head).

        BUG FIX: the original wrote the cross-entropy term as a separate
        statement (a dangling unary-minus expression), so it was evaluated
        and silently discarded — only the squared-error term was returned.
        The two terms are now combined into a single expression.

        NOTE(review): yTrue[-1] / yTrue[0:-1] slice the *first* (batch)
        axis of the tensors, not the last output unit of each sample —
        confirm this slicing is what the training setup intends.
        """
        return (keras.backend.square(yTrue[-1] - yPred[-1])
                - keras.backend.dot(keras.backend.transpose(yTrue[0:-1]),
                                    keras.backend.log(yPred[0:-1])))
    def loadModel(self):
        """ Load the network weights from self._saveFile.

        Previously hard-coded to 'my_model' while saveModel wrote
        './my_model'; both now use the same configurable path.
        """
        self._model.load_weights(self._saveFile)
        return None
    def saveModel(self):
        """ Save the network weights to self._saveFile.
        """
        self._model.save_weights(self._saveFile)
        return None
    def train(self, train_x,train_y):
        """ Train the network for one epoch (batch size 1) on numpy arrays.
        """
        self._model.fit(train_x,train_y,batch_size=1,epochs = 1)
        return None
    def predict(self,x):
        """Predict the output, given input
        """
        return self._model.predict(x)
    def evaluate(self,test_x,test_y):
        """ Evaluate on test data and print the accuracy.
        """
        evalLoss, evalAcc = self._model.evaluate(test_x,test_y)
        print("Evaluation Accuracy :",evalAcc)
        return None
# ------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: build a 3x3 tic-tac-toe board from the project, play one
    # move, and run the encoded board state through a fresh network.
    # (An earlier mnist experiment that lived here as commented-out code
    # has been removed.)
    from tttBoard import tttBoard
    game = tttBoard(3)
    game.makeMove(5)
    encoded = np.zeros((1, 9))
    net = dnNetwork(layers=[9, 64, 32, 10])
    encoded[0, :] = game.decodeState(game.getState())
    print(encoded)
    print(net.predict(encoded))
    # NOTE(review): this passes decodeState's raw (presumably 1-D) output
    # straight to predict, unlike the batched call above — confirm
    # predict accepts it.
    print(net.predict(game.decodeState(game.getState())).flatten())