-
Notifications
You must be signed in to change notification settings - Fork 0
/
Oracle.py
132 lines (111 loc) · 4.94 KB
/
Oracle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from dataloader_new import SmilesTokenizer
import numpy as np
import tensorflow as tf
############# Changes from Original Paper #######################
'''
1) I basically just copy and pasted the Generator class, because it basically works the same as generator pre-training
'''
class Oracle(object):
    """Bidirectional-LSTM sequence classifier used as the "oracle" scorer.

    Per the author's note above, this is essentially a copy of the Generator
    class, because it is trained the same way as generator pre-training.
    Inputs are one-hot encoded SMILES token sequences of shape
    ``(batch, time, table_len)``; the output is a softmax over the token table.
    """

    def __init__(self, train_data_loader, validation_data_loader, units=256,
                 leaky_relu_alpha=0.1, num_layers=1, opt=None,
                 dropout_keep_prob=1.0, l2_reg_lambda=0.0,
                 loss='categorical_crossentropy', metrics=['loss', 'val_loss']):
        """Store hyper-parameters, then build and compile the Keras model.

        Args:
            train_data_loader: Keras-Sequence-style loader for training batches.
            validation_data_loader: same, for validation batches.
            units: LSTM units per direction in each recurrent layer.
            leaky_relu_alpha: negative-slope coefficient for the LeakyReLU.
            num_layers: number of stacked bidirectional LSTM layers (>= 1).
            opt: a tf.keras optimizer; defaults to Adam(learning_rate=0.01).
            dropout_keep_prob: probability of KEEPING a unit (1.0 = no dropout).
            l2_reg_lambda: L2 kernel-regularization strength.
            loss: Keras loss identifier used at compile time.
            metrics: metric list passed to ``model.compile``.
        """
        assert num_layers >= 1
        self.st = SmilesTokenizer()
        self.train_dl = train_data_loader
        self.val_dl = validation_data_loader
        self.table_len = self.st.table_len
        # BUG FIX: the original signature instantiated Adam in the default
        # argument (evaluated once at import and shared by every instance)
        # and used the deprecated `lr` keyword. Create a fresh optimizer per
        # instance instead; passing `opt` explicitly behaves as before.
        self.opt = opt if opt is not None else tf.keras.optimizers.Adam(learning_rate=0.01)
        self.rewards = []
        self.num_layers = num_layers
        self.units = units
        self.leaky_alpha = leaky_relu_alpha
        self.loss = loss
        self.metrics = metrics
        self.dropout_keep_prob = dropout_keep_prob
        self.kernel_regularizer = tf.keras.regularizers.l2(l2=l2_reg_lambda)
        # BUG FIX: build_model() reads num_layers/units/etc., so it must run
        # AFTER those attributes exist. The original called it first and
        # raised AttributeError on construction.
        self.model = self.build_model()

    def build_model(self, metrics=['loss', 'val_loss']):
        """Build and compile the stacked bidirectional-LSTM classifier.

        All recurrent layers except the last return full sequences so the
        next LSTM can consume them; the final layer collapses the sequence
        to a single vector before the softmax head.

        Returns:
            The compiled ``tf.keras.Sequential`` model.
        """
        # Keras Dropout takes the fraction to DROP, but our hyper-parameter
        # is a keep probability — convert. (BUG FIX: the original passed the
        # keep prob directly, so the default keep_prob=1.0 dropped every unit.
        # Also fixes a stray double comma that made the file a SyntaxError.)
        drop_rate = 1.0 - self.dropout_keep_prob

        model = tf.keras.Sequential()
        model.add(tf.keras.layers.InputLayer(
            input_shape=[None, self.table_len]
        ))
        # Intermediate recurrent layers (runs zero times when num_layers == 1).
        for _ in range(self.num_layers - 1):
            model.add(tf.keras.layers.Bidirectional(
                tf.keras.layers.LSTM(
                    units=self.units,
                    return_sequences=True,
                    activation='tanh',
                    # NOTE(review): 'tanh' here is unusual (Keras default
                    # recurrent_activation is 'sigmoid') but matches the
                    # original author's choice — confirm before changing.
                    recurrent_activation='tanh',
                    kernel_regularizer=self.kernel_regularizer
                )))
            model.add(tf.keras.layers.BatchNormalization())
            model.add(tf.keras.layers.Dropout(rate=drop_rate))
        # Final recurrent layer: no sequence output.
        model.add(tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(
                units=self.units,
                return_sequences=False,
                activation='tanh',
                recurrent_activation='tanh',
                kernel_regularizer=self.kernel_regularizer
            )))
        # CONSISTENCY FIX: the original's two branches ordered
        # BatchNorm/LeakyReLU/Dropout differently; both now use the
        # single-layer branch's order (BN -> LeakyReLU -> Dropout).
        model.add(tf.keras.layers.BatchNormalization())
        model.add(tf.keras.layers.LeakyReLU(alpha=self.leaky_alpha))
        model.add(tf.keras.layers.Dropout(rate=drop_rate))
        model.add(tf.keras.layers.Dense(
            units=self.table_len,
            activation='softmax',
            kernel_regularizer=self.kernel_regularizer
        ))
        model.compile(
            optimizer=self.opt,
            loss=self.loss,
            metrics=self.metrics
        )
        return model

    def compile_model(self, model, loss="categorical_crossentropy", metrics=['loss', 'val_loss']):
        """Re-compile ``self.model`` with the given loss and metrics.

        The ``model`` argument is kept for backward compatibility but is
        unused — the original also always compiled ``self.model``.
        """
        # BUG FIX: honour the `loss` argument; the original hard-coded the
        # string and silently ignored what the caller passed.
        self.model.compile(
            optimizer=self.opt,
            loss=loss,
            metrics=metrics
        )
        print("self.model compiled!")

    def load_weights(self, filepath):
        """Load model weights from `filepath` (Keras checkpoint format)."""
        print("Loading weights from " + str(filepath))
        self.model.load_weights(filepath)
        print("Weights Loaded")

    def save_weights(self, filepath):
        """Save model weights to `filepath` (Keras checkpoint format)."""
        print("Saving weights to " + str(filepath))
        self.model.save_weights(filepath)
        print("Weights Saved")

    def train(self, num_epochs, verbose=1, callbacks=[], save_weights=False, filepath=''):
        """Fit the model on the training loader, validating each epoch.

        Args:
            num_epochs: number of epochs to train.
            verbose: Keras fit verbosity level.
            callbacks: list of Keras callbacks.
            save_weights: when True, save weights to `filepath` afterwards.
            filepath: destination for the optional weight checkpoint.

        Returns:
            The Keras ``History`` object from ``model.fit``.
        """
        history = self.model.fit(
            self.train_dl,
            steps_per_epoch=len(self.train_dl),
            epochs=num_epochs,
            verbose=verbose,
            # BUG FIX: the attribute is `val_dl`; the original referenced a
            # nonexistent `self.valid_dl` and crashed on the first call.
            validation_data=self.val_dl,
            validation_steps=len(self.val_dl),
            shuffle=True,
            callbacks=callbacks,
        )
        if save_weights:
            # BUG FIX: the original called `save_weights(filepath)`, which is
            # the local boolean parameter, not the method -> TypeError.
            self.save_weights(filepath)
        return history