-
Notifications
You must be signed in to change notification settings - Fork 0
/
dndt.py
54 lines (40 loc) · 1.75 KB
/
dndt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import tensorflow as tf
from functools import reduce
class DNDT(tf.keras.layers.Layer):
    """Deep Neural Decision Tree layer.

    Each input feature ``i`` is soft-binned into ``num_cuts[i] + 1`` bins using
    trainable cut points. The per-feature bin memberships are combined with a
    row-wise Kronecker product into leaf memberships. Those are then mapped to
    outputs through a trainable ``leaf_score`` matrix.

    Args:
        num_outputs: Number of output units (columns of ``leaf_score``).
        num_cuts: Sequence giving the number of cut points per input feature;
            input column ``i`` is binned with ``num_cuts[i]`` cuts.
        num_leaves: Number of leaves. Must equal the product of
            ``num_cuts[i] + 1`` over all features, because that is the width
            of the Kronecker product built in ``call``.
        temperature: Softmax temperature used for soft binning. Smaller values
            sharpen the softmax, making bin assignment closer to hard.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``num_leaves`` does not match the product above.
    """

    def __init__(self, num_outputs, num_cuts, num_leaves, temperature=0.1, **kwargs):
        super(DNDT, self).__init__(**kwargs)
        # Fail early with a clear message instead of a matmul shape error in call().
        expected_leaves = reduce(lambda acc, n_cut: acc * (n_cut + 1), num_cuts, 1)
        if num_leaves != expected_leaves:
            raise ValueError(
                'num_leaves must equal the product of (num_cuts[i] + 1) = {}, '
                'got {}'.format(expected_leaves, num_leaves))
        self.temperature = temperature
        # Maps leaf memberships (num_leaves) to outputs (num_outputs).
        self.leaf_score = self.add_weight(
            name='leaf_score',
            shape=(num_leaves, num_outputs),
            initializer='uniform',
            trainable=True,
        )
        self.num_cuts = num_cuts
        # One trainable vector of cut points per input feature.
        self.cuts_list = []
        for i, n_cut in enumerate(self.num_cuts):
            self.cuts_list.append(self.add_weight(
                name='cut_{}'.format(i),
                shape=(n_cut, ),
                initializer='uniform',
                trainable=True,
            ))

    def tf_kron_prod(self, a, b):
        """Return the row-wise Kronecker product of two 2-D tensors.

        For every row ``r`` the output row is the flattened outer product of
        ``a[r, :]`` and ``b[r, :]``.

        Args:
            a: Tensor of shape ``(batch, m)``.
            b: Tensor of shape ``(batch, n)``.

        Returns:
            Tensor of shape ``(batch, m * n)``.
        """
        res = tf.einsum('ij,ik->ijk', a, b)
        # Flatten the two trailing axes into one (requires m and n to be static).
        return tf.reshape(res, [-1, tf.reduce_prod(res.shape[1:])])

    def tf_bin(self, x, cut_idx):
        """Soft-bin one input column using the cut points of feature ``cut_idx``.

        With the ``D`` cut points sorted ascending as ``c_1 <= ... <= c_D``,
        the logit for bin ``k`` (``k = 1 .. D + 1``) is
        ``k * x - (c_1 + ... + c_{k-1})``. A temperature-scaled softmax over
        these logits gives the bin memberships.

        Args:
            x: Tensor of shape ``(batch, 1)`` holding a single input feature.
            cut_idx: Index into ``self.cuts_list`` selecting the cut points.

        Returns:
            Tensor of shape ``(batch, D + 1)``; each row sums to 1.
        """
        cut_points = self.cuts_list[cut_idx]
        D = cut_points.get_shape().as_list()[0]
        # W = [[1, 2, ..., D + 1]]: per-bin slope of the linear logit.
        W = tf.reshape(tf.linspace(1.0, D + 1.0, D + 1), [1, -1])
        # Sort into a local tensor. Writing the result back into
        # self.cuts_list would replace the trainable variable with a plain
        # tensor. Later calls would then read a detached value (no gradients),
        # and under tf.function the tensor would leak out of the traced graph.
        sorted_cuts = tf.sort(cut_points)
        # b = [0, -c_1, -(c_1 + c_2), ...]: cumulative offsets of the bin logits.
        b = tf.cumsum(tf.concat([tf.constant(0.0, shape=[1]), -sorted_cuts], 0))
        h = tf.matmul(x, W) + b
        return tf.nn.softmax(h / self.temperature)

    def call(self, inputs):
        """Compute layer outputs for a batch of inputs.

        Args:
            inputs: 2-D tensor whose column ``i`` is binned with the cut points
                in ``self.cuts_list[i]``. Presumably of shape
                ``(batch, len(num_cuts))``; columns beyond ``len(num_cuts)``
                are not read.

        Returns:
            Tensor of shape ``(batch, num_outputs)``.
        """
        per_feature_bins = [
            self.tf_bin(inputs[:, i:i + 1], i) for i in range(len(self.cuts_list))
        ]
        # Kronecker product across features turns per-feature bins into leaves.
        leaf = reduce(self.tf_kron_prod, per_feature_bins)
        return tf.matmul(leaf, self.leaf_score)