-
Notifications
You must be signed in to change notification settings - Fork 18
/
base_model.py
163 lines (142 loc) · 7.93 KB
/
base_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under a Microsoft Research License.
import numpy as np
import tensorflow.compat.v1 as tf # type: ignore
import yaml
from src.solvers import modified_euler_integrate, integrate_while
from src.utils import default_get_value, variable_summaries
from src.procdata import ProcData
def power(x, a):
    """Elementwise power x**a computed via the identity x**a = exp(a * log(x)).

    Only meaningful for strictly positive x: log of zero/negative inputs
    yields -inf/NaN, which then propagates through exp.
    """
    log_x = tf.math.log(x)
    return tf.exp(a * log_x)
def log_prob_laplace(x_obs, x_post_sample, log_precisions, precisions):
    """Elementwise Laplace log-density of x_obs with location x_post_sample.

    Parameterized by precision b^-1 (inverse scale):
    log p(x) = log(1/2) + log(precision) - precision * |x - loc|.
    Returns a tensor of per-element log-probabilities (no reduction).
    """
    abs_residual = tf.abs(x_post_sample - x_obs)
    return tf.math.log(0.5) + log_precisions - precisions * abs_residual
def log_prob_gaussian(x_obs, x_post_sample, log_precisions, precisions):
    """Elementwise Gaussian log-density of x_obs with mean x_post_sample.

    Parameterized by precision (inverse variance); see
    https://en.wikipedia.org/wiki/Normal_distribution
    log p(x) = -log(2*pi)/2 + log(precision)/2 - precision * (x - mean)^2 / 2.
    Returns a tensor of per-element log-probabilities (no reduction).
    """
    squared_residual = tf.square(x_post_sample - x_obs)
    normalizer = -0.5 * tf.math.log(2.0 * np.pi)
    return normalizer + 0.5 * log_precisions - 0.5 * precisions * squared_residual
def expand_constant_precisions(precision_list):
    """Stack per-signal constant precisions and insert a broadcast axis.

    e.g. precision_list = [theta.prec_x, theta.prec_fp, theta.prec_fp, theta.prec_fp]

    :param precision_list: list of precision tensors, one per observed signal.
    :return: (log_precisions, precisions) with the signals stacked on the last
        axis and a singleton axis inserted at position 2 (for time broadcast).
    """
    stacked = tf.stack(precision_list, axis=-1)
    log_stacked = tf.math.log(stacked)
    return tf.expand_dims(log_stacked, 2), tf.expand_dims(stacked, 2)
def expand_decayed_precisions(precision_list):  # pylint: disable=unused-argument
    """Placeholder for time-decayed precision expansion; not implemented yet.

    Note that 'decayed' precision_type is currently handled directly in
    BaseModel.expand_precisions_by_time rather than here.
    """
    raise NotImplementedError("TODO: expand_decayed_precisions")
class BaseModel:
    """Base class for ODE-based generative models fit by importance-weighted VI.

    Subclasses must implement ``gen_reaction_equations`` and
    ``initialize_state``; this class provides the shared machinery:
    ODE simulation with a choice of solver, mapping latent states to
    observations, and precision/likelihood computation.
    """
    # pylint: disable=attribute-defined-outside-init

    def __init__(self, params, procdata: ProcData):
        """Store configuration and dataset metadata.

        :param params: dict-like configuration; optional keys 'use_laplace'
            (default False) and 'precision_type' (default 'constant') are
            read here via default_get_value.
        :param procdata: processed-data descriptor supplying device and
            condition metadata used during simulation and conditioning.
        """
        self.params = params
        self.relevance = procdata.relevance_vectors
        self.default_devices = procdata.default_devices
        self.device_depth = procdata.device_depth
        self.n_treatments = len(procdata.conditions)
        # Observation noise model: Laplace if True, Gaussian otherwise
        # (see log_prob_observations).
        self.use_laplace = default_get_value(self.params, 'use_laplace', False, verbose=True)
        # 'constant' or 'decayed'; consumed by expand_precisions_by_time.
        self.precision_type = default_get_value(self.params, 'precision_type', 'constant', verbose=True)
        # Filled in by subclasses.
        self.species = None
        self.nspecies = None
        #self.layers = []

    def gen_reaction_equations(self, theta, conditions, dev_1hot, condition_on_device=True):
        """Return the d(state)/dt function for the ODE solver. Must be overridden."""
        raise NotImplementedError("TODO: write your gen_reaction_equations")

    def get_precision_list(self, theta):
        """Per-signal precision parameters, ordered to match observe()'s outputs."""
        return [theta.prec_x, theta.prec_rfp, theta.prec_yfp, theta.prec_cfp]

    def device_conditioner(self, param, param_name, dev_1hot, kernel_initializer='glorot_uniform', use_bias=False, activation=tf.nn.relu):
        """
        Returns a 1D parameter conditioned on device
        ::NOTE:: condition_on_device is a closure over n_iwae, n_batch, dev_1hot_rep

        :param param: tensor of shape [n_batch, n_iwae] holding the raw parameter.
        :param param_name: key into self.relevance / self.default_devices.
        :param dev_1hot: one-hot device encoding, tiled once per IW sample.
        :return: tensor of shape [n_batch, n_iwae] — the conditioned parameter.
        """
        n_iwae = tf.shape(param)[1]
        n_batch = tf.shape(param)[0]
        param_flat = tf.reshape(param, [n_iwae * n_batch, 1])
        # A fresh Dense(1) layer per call; its output modulates the parameter.
        cond_nn = tf.keras.layers.Dense(1, use_bias=use_bias, activation=activation, kernel_initializer=kernel_initializer)
        # tile devices, one per iwae sample; relevance masks irrelevant devices
        dev_1hot_rep = tf.tile(dev_1hot * self.relevance[param_name], [n_iwae, 1])
        param_cond = cond_nn(dev_1hot_rep)
        # Default devices get a multiplicative (1 + delta) correction so the raw
        # parameter is recovered when the network outputs zero; other devices
        # are scaled directly by the network output.
        if param_name in self.default_devices:
            return tf.reshape(param_flat * (1.0 + param_cond), [n_batch, n_iwae])
        else:
            return tf.reshape(param_flat * param_cond, [n_batch, n_iwae])

    def initialize_state(self, theta, treatments):
        """Return the initial ODE state for given parameters/treatments. Must be overridden."""
        raise NotImplementedError("TODO: write your initialize_state")

    def simulate(self, theta, times, conditions, dev_1hot, solver, condition_on_device=True):
        """Integrate the model ODEs over `times` with the requested solver.

        :param solver: one of 'modeuler', 'modeulerwhile', 'rk4'.
        :return: (t_state_tr, f_state_tr) — trajectory tensors transposed so
            the time axis is last; f_state_tr is None for the while-loop
            solvers, which do not return it.
        :raises NotImplementedError: for an unrecognized solver name.
        """
        init_state = self.initialize_state(theta, conditions)
        d_states_d_t = self.gen_reaction_equations(theta, conditions, dev_1hot, condition_on_device)
        if solver == 'modeuler':
            # Evaluate ODEs using Modified-Euler
            t_state, f_state = modified_euler_integrate(d_states_d_t, init_state, times)
            # swap the last two axes (species <-> time)
            t_state_tr = tf.transpose(t_state, [0, 1, 3, 2])
            f_state_tr = tf.transpose(f_state, [0, 1, 3, 2])
        elif solver == 'modeulerwhile':
            # Evaluate ODEs using Modified-Euler
            t_state, f_state = integrate_while(d_states_d_t, init_state, times, algorithm='modeuler')
            # while-loop solver stacks time first; move it after batch/iwae
            t_state_tr = tf.transpose(t_state, [1, 2, 0, 3])
            f_state_tr = None
        elif solver == 'rk4':
            # Evaluate ODEs using 4th order Runge-Kutta
            t_state, f_state = integrate_while(d_states_d_t, init_state, times, algorithm='rk4')
            t_state_tr = tf.transpose(t_state, [1, 2, 0, 3])
            f_state_tr = None
        else:
            raise NotImplementedError("Solver <%s> is not implemented" % solver)
        return t_state_tr, f_state_tr

    @classmethod
    def observe(cls, x_sample, _theta):
        """Map latent state trajectories to the 4 observed signals.

        Channel 0 is observed directly; channels 1-3 are products of channel 0
        with (sums of) other state channels — presumably OD times per-cell
        fluorescent-protein totals; confirm against the subclass species order.
        """
        x_predict = [
            x_sample[:, :, :, 0],
            x_sample[:, :, :, 0] * x_sample[:, :, :, 1],
            x_sample[:, :, :, 0] * (x_sample[:, :, :, 2] + x_sample[:, :, :, 4]),
            x_sample[:, :, :, 0] * (x_sample[:, :, :, 3] + x_sample[:, :, :, 5])]
        x_predict = tf.stack(x_predict, axis=-1)
        return x_predict

    def add_time_dimension(self, p, x):
        """Tile tensor `p` along its (singleton) time axis to match x's time length."""
        time_steps = x.shape[1]
        p = tf.tile(p, [1, 1, time_steps, 1], name="time_added")
        return p

    def expand_precisions_by_time(self, theta, _x_predict, x_obs, _x_sample):
        """Expand per-signal precisions across the time axis of x_obs.

        For precision_type == 'decayed', precisions are additionally divided
        by a linear ramp over timesteps, down-weighting later observations.
        :return: (log_precisions, precisions) broadcastable against x_obs.
        """
        precision_list = self.get_precision_list(theta)
        log_prec, prec = self.expand_precisions(precision_list)
        log_prec = self.add_time_dimension(log_prec, x_obs)
        prec = self.add_time_dimension(prec, x_obs)
        if self.precision_type == "decayed":
            time_steps = x_obs.shape[1]
            # NOTE(review): .value implies time_steps is a TF1 Dimension here,
            # i.e. x_obs has a static time length — confirm with callers.
            lin_timesteps = tf.reshape(tf.linspace(1.0, time_steps.value, time_steps.value), [1, 1, time_steps, 1])
            prec = prec / lin_timesteps
            # keep log-precision consistent with the scaled precision
            log_prec = log_prec - tf.math.log(lin_timesteps)
        return log_prec, prec

    @classmethod
    def expand_precisions(cls, precision_list):
        """Expand precisions assuming they are constants; subclasses may override."""
        return expand_constant_precisions(precision_list)

    def log_prob_observations(self, x_predict, x_obs, theta, x_sample):
        """Log-likelihood of observations, summed over time, per signal.

        Uses a Laplace or Gaussian elementwise density depending on
        self.use_laplace.
        """
        log_precisions, precisions = self.expand_precisions_by_time(theta, x_predict, x_obs, x_sample)
        # expand x_obs for the iw samples in x_post_sample
        x_obs_ = tf.expand_dims(x_obs, 1)
        lpfunc = log_prob_laplace if self.use_laplace else log_prob_gaussian
        log_prob = lpfunc(x_obs_, x_predict, log_precisions, precisions)
        # sum along the time and observed species axes
        #log_prob = tf.reduce_sum(log_prob, [2, 3])
        # sum along the time axis
        log_prob = tf.reduce_sum(log_prob, 2)
        return log_prob
class NeuralPrecisions(object):
    """State- and time-dependent precisions driven by small neural networks.

    Two Dense heads share a single hidden layer: ``act`` produces a
    production term and ``deg`` a decay rate for the 4 precision state
    variables, which are integrated alongside the model species via
    __call__ returning their time derivatives.
    """

    def __init__(self, nspecies, n_hidden_precisions, inputs = None, hidden_activation = tf.nn.tanh):
        '''Initialize neural precisions layers.

        :param nspecies: number of model species (non-precision state vars).
        :param n_hidden_precisions: width of the shared hidden layer.
        :param inputs: input feature count; defaults to nspecies + 1 (species
            plus scalar time, matching the concat in __call__).
        :param hidden_activation: activation for the shared hidden layer.
        '''
        self.nspecies = nspecies
        if inputs is None:
            inputs = self.nspecies+1
        # NOTE(review): `inp` is placed in BOTH Sequential models below, so the
        # act and deg heads share the same hidden layer and its weights.
        inp = tf.keras.layers.Dense(n_hidden_precisions, activation = hidden_activation, use_bias=True, name = "prec_hidden", input_shape=(inputs,))
        # Sigmoid outputs with non-negative bias keep both heads in (0, 1).
        act_layer = tf.keras.layers.Dense(4, activation = tf.nn.sigmoid, name = "prec_act", bias_constraint = tf.keras.constraints.NonNeg())
        deg_layer = tf.keras.layers.Dense(4, activation = tf.nn.sigmoid, name = "prec_deg", bias_constraint = tf.keras.constraints.NonNeg())
        self.act = tf.keras.Sequential([inp, act_layer])
        self.deg = tf.keras.Sequential([inp, deg_layer])
        # Record TensorBoard summaries for each layer's kernel and bias.
        # Assumes all three layers are built by now (input_shape on `inp`
        # triggers building the Sequentials) — TODO confirm.
        for layer in [inp, act_layer, deg_layer]:
            weights, bias = layer.weights
            variable_summaries(weights, layer.name + "_kernel", False)
            variable_summaries(bias, layer.name + "_bias", False)

    def __call__(self, t, state, n_batch, n_iwae):
        """Time derivative of the 4 precision state variables.

        :param t: scalar time, concatenated as an input feature.
        :param state: tensor [n_batch, n_iwae, nspecies + 4]; the last 4
            entries of the final axis are the precision state variables.
        :return: tensor [n_batch, n_iwae, 4] of dP/dt values.
        """
        # Split species from precision state and flatten batch x iwae.
        reshaped_state = tf.reshape(state[:,:,:-4], [n_batch*n_iwae, self.nspecies])
        reshaped_var_state = tf.reshape(state[:,:,-4:], [n_batch*n_iwae, 4])
        # Broadcast scalar time so it can be concatenated as a feature column.
        t_expanded = tf.tile( [[t]], [n_batch*n_iwae, 1] )
        ZZ_vrs = tf.concat( [ t_expanded, reshaped_state ], axis=1 )
        # dP/dt = production(act) - decay_rate(deg) * P
        vrs = tf.reshape(self.act(ZZ_vrs) - self.deg(ZZ_vrs)*reshaped_var_state, [n_batch, n_iwae, 4])
        return vrs