-
Notifications
You must be signed in to change notification settings - Fork 5
/
visualize_map.py
118 lines (90 loc) · 3.32 KB
/
visualize_map.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import data_generator
from models import NN_MAP
# Command-line flags (TF1 `tf.app.flags`). The only option is the batch size,
# which is read as FLAGS.n_batch_size when the sample generators are built.
flags = tf.app.flags
flags.DEFINE_integer(
    "n_batch_size",
    512,
    "Batch size to train [512]"
)
FLAGS = flags.FLAGS
class Trainer(object):
    """Restore the map network ``f`` and visualize it on one batch of samples.

    Despite the name, no optimization happens here: ``__init__`` builds the
    graph and restores ``f``'s variables from ``ckpt_dir_map``, and ``train``
    runs a single forward pass whose result is handed to ``visualize``.
    """

    def __init__(self):
        # Sample iterators for the x and y point sets (see define_dataset).
        self.x_generator = None
        self.y_generator = None
        # Graph endpoints: input placeholder, NN_MAP model, its output tensor
        # and the variables that the saver restores.
        self.x = None
        self.f = None
        self.fx = None
        self.f_var_list = None
        # NOTE(review): self.loss and self.ckpt_dir_ot are not read anywhere
        # in this script.
        self.loss = None
        self.ckpt_dir_ot = 'ckpts/stochastic_ot_computation/'
        # Checkpoint path passed to Saver.restore / Saver.save for f.
        self.ckpt_dir_map = 'ckpts/optimal_map_estimation/'
        # Output directory created by define_viz_dir.
        self.visualize_dir_map = 'viz/'
        self.f_saver = None
        self.sess = None
        self.coord = None
        self.threads = None
        self.define_dataset()
        self.define_model()
        self.define_saver()
        self.define_viz_dir()
        self.initialize_session_and_etc()

    def define_dataset(self):
        """Create the two sample iterators and the float32 ``(None, 2)`` input placeholder.

        Both generators are built with ``FLAGS.n_batch_size``. Judging by the
        names, x presumably comes from a single Gaussian and y from a
        4-Gaussian mixture — confirm in ``data_generator``.
        """
        self.x_generator = iter(data_generator.GeneratorGaussian1(FLAGS.n_batch_size))
        self.y_generator = iter(data_generator.GeneratorGaussians4(FLAGS.n_batch_size))
        self.x = tf.placeholder(tf.float32, (None, 2))

    def define_model(self):
        """Build ``NN_MAP`` (scope ``'f'``) on the placeholder; keep its output and variables."""
        self.f = NN_MAP(self.x, 'f')
        self.fx = self.f.output
        self.f_var_list = self.f.var_list

    def define_saver(self):
        """Create a Saver restricted to the variables of ``f``."""
        self.f_saver = tf.train.Saver(self.f_var_list)

    def define_viz_dir(self):
        """Create ``visualize_dir_map`` if it does not already exist."""
        if not os.path.exists(self.visualize_dir_map):
            os.makedirs(self.visualize_dir_map)

    def initialize_session_and_etc(self):
        """Open the session, initialize variables, restore ``f`` and start queue runners.

        The session uses GPU memory growth and soft device placement.
        Variables are initialized first, then ``f``'s variables are
        overwritten from ``ckpt_dir_map``.
        """
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     gpu_options=gpu_options)
        self.sess = tf.Session(config=sess_config)
        self.sess.run(tf.local_variables_initializer())
        self.sess.run(tf.global_variables_initializer())
        self.f_saver.restore(self.sess, self.ckpt_dir_map)
        self.coord = tf.train.Coordinator()
        self.threads = tf.train.start_queue_runners(sess=self.sess, coord=self.coord)

    def train(self):
        """Map one batch of x through ``f`` and plot it.

        Draws one batch from each generator, evaluates ``f(x)`` and calls
        ``visualize(x, y, fx)``. The ``finally`` cleanup runs on success,
        Ctrl-C, or any other exception: it saves ``f``'s variables to
        ``ckpt_dir_map``, stops and joins the queue-runner threads, and
        closes the session. Exceptions other than ``KeyboardInterrupt`` still
        propagate after the cleanup.
        """
        try:
            x = next(self.x_generator)
            y = next(self.y_generator)
            fx = self.sess.run(self.fx, feed_dict={self.x: x})
            visualize(x, y, fx)
        except KeyboardInterrupt:
            # Stopping the coordinator is handled by the finally clause.
            print("Interrupted!")
        finally:
            try:
                # NOTE(review): nothing is trained in this script, so this
                # re-writes the checkpoint with the weights just restored.
                # Consider dropping it.
                self.f_saver.save(self.sess, self.ckpt_dir_map)
                print('Stop')
                self.coord.request_stop()
                self.coord.join(self.threads)
            finally:
                # Always release the session, even if saving or joining fails;
                # it is closed only after the threads using it are joined.
                self.sess.close()
def visualize(x, y, fx, out_dir='viz/'):
    """Save three diagnostic plots of the map ``f`` as PNG files.

    Files written into ``out_dir``:

    * ``XnY.png``  - ``x`` (green) and ``y`` (red) scatter plots.
    * ``XnFx.png`` - ``x`` (green), ``fx`` (blue), and a black arrow from
      ``x[i]`` to ``fx[i]`` for the first ``len(x) // 8`` rows.
    * ``Fx.png``   - seaborn KDE joint plot of ``fx``.

    The first two plots are clipped to ``[-1.5, 1.5]`` on both axes.

    Args:
        x: Source samples indexed as ``x[:, 0]`` / ``x[:, 1]``; presumably an
            ``(n, 2)`` array, since the caller feeds it to a ``(None, 2)``
            placeholder.
        y: Target samples with the same column layout as ``x``.
        fx: ``f(x)``, row-aligned with ``x`` (row ``i`` is the image of ``x[i]``).
        out_dir: Output directory, created if missing. The default ``'viz/'``
            reproduces the original file paths exactly.
    """
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Plot 1: source vs. target samples.
    plt.scatter(x[:, 0], x[:, 1], s=1, c='g')
    plt.scatter(y[:, 0], y[:, 1], s=1, c='r')
    plt.xlim(-1.5, +1.5)
    plt.ylim(-1.5, +1.5)
    plt.savefig(os.path.join(out_dir, 'XnY.png'))
    plt.clf()

    # Plot 2: source samples, their images under f, and displacement arrows.
    plt.scatter(x[:, 0], x[:, 1], s=1, c='g')
    plt.scatter(fx[:, 0], fx[:, 1], s=1, c='b')
    # Draw on the Axes that already holds the scatters. plt.axes() with no
    # arguments adds a brand-new Axes on matplotlib >= 3.6. That new Axes
    # covers the scatter points and receives the xlim/ylim calls below.
    ax = plt.gca()
    # Only the first eighth of the rows get arrows (presumably to keep the
    # plot readable).
    for i in range(x.shape[0] // 8):
        ax.arrow(x[i, 0], x[i, 1], fx[i, 0] - x[i, 0], fx[i, 1] - x[i, 1],
                 head_width=0.03, head_length=0.02, fc='k', ec='k')
    plt.xlim(-1.5, +1.5)
    plt.ylim(-1.5, +1.5)
    plt.savefig(os.path.join(out_dir, 'XnFx.png'))
    plt.clf()

    # Plot 3: density of the mapped samples. Data must be passed by keyword;
    # seaborn >= 0.12 rejects positional x/y in jointplot.
    grid = sns.jointplot(x=fx[:, 0], y=fx[:, 1], kind='kde')
    grid.savefig(os.path.join(out_dir, 'Fx.png'))
if __name__ == '__main__':
    # Build the graph, restore f from its checkpoint, and write the plots.
    Trainer().train()
    print("Done!")