-
Notifications
You must be signed in to change notification settings - Fork 3
/
simulate_delta_gan.py
64 lines (48 loc) · 1.53 KB
/
simulate_delta_gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
from utils_log import MetricSaver
# `data` is the value the generator is pushed toward (the error is data - g);
# `delta_t` is the default integration step size for the simulation updates.
data, delta_t = 1.0, 0.01
class GAN_simualte(object):
    """Step-by-step simulation of a GAN with scalar players.

    The generator state is the scalar ``g`` and the discriminator state the
    scalar ``d``. A discriminator step passes the error ``data - g`` through
    ``controller_d`` and folds the result into ``d``; a generator step moves
    ``g`` along ``d``. The step size ``delta_t`` and target ``data`` are
    module-level globals read on every step.

    Args:
        gantype: Label stored as ``self.type``; not used by this class.
        controller_d: Callable mapping the scalar error to a control signal.
        damping: Coefficient subtracting a fraction of ``d`` on each D step.
    """

    def __init__(self, gantype, controller_d, damping):
        self.type = gantype
        self.controller_d = controller_d
        self.damping = damping
        # Both players start at zero.
        self.d, self.g = 0., 0.

    def d_step(self):
        """Advance the discriminator state ``d`` by one step."""
        signal = self.controller_d(data - self.g)
        # NOTE(review): the damping term is applied per step and is not
        # scaled by delta_t (unlike the other increments), i.e.
        # d_next = (1 - damping) * d + signal * delta_t. For damping >= 1 the
        # factor on the previous d is non-positive. Confirm this is intended.
        self.d += signal * delta_t - self.damping * self.d

    def g_step(self):
        """Advance the generator state ``g`` by one step along ``d``."""
        self.g = self.g + delta_t * self.d
class PID_controller(object):
    """Discrete-time PID controller applied to a scalar error signal.

    Each call returns ``p * e + i * sum(e * dt) + d * (e - e_prev) / dt``,
    where the sum runs over all errors seen so far (including this one) and
    the derivative term is 0 on the very first call, when there is no
    previous error.

    Args:
        p: Proportional gain.
        i: Integral gain.
        d: Derivative gain.
        dt: Optional step size. When ``None`` (default), the module-level
            ``delta_t`` is read on every call.
    """

    def __init__(self, p, i, d, dt=None):
        self.p = p
        self.i = i
        self.d = d
        self.dt = dt
        # Running integral of the error.
        self.i_buffer = 0.
        # Error passed on the previous call (for the finite difference).
        self.d_buffer = 0.
        # False until the first call: there is no previous error before it.
        self._has_prev_error = False

    def __call__(self, error):
        """Return the control signal for ``error`` and update the state."""
        dt = delta_t if self.dt is None else self.dt
        p_signal = error
        self.i_buffer += error * dt
        i_signal = self.i_buffer
        if self._has_prev_error:
            d_signal = (error - self.d_buffer) / dt
        else:
            # No previous sample: use a zero derivative rather than assuming
            # the prior error was 0, which would inject a spike of
            # error / dt on the first step ("derivative kick").
            d_signal = 0.
        self.d_buffer = error
        self._has_prev_error = True
        return self.p * p_signal + self.i * i_signal + self.d * d_signal
# Controller gains (proportional, integral, derivative) and the
# discriminator damping used for this run.
p, i, d = 1, 0, 0
damping = 2.
n_steps = 200000
out_dir = "./delta_gan/"
# Both savers share one run prefix; "_g" records the generator state and
# "_d" the discriminator state. save_on_update=False plus the explicit
# save() calls below presumably defer writing until the end -- confirm in
# utils_log.MetricSaver.
run_name = "Generator_{}_{}_{}_{}".format(p, i, d, damping)
saver = MetricSaver(run_name + "_g", out_dir, save_on_update=False)
saver1 = MetricSaver(run_name + "_d", out_dir, save_on_update=False)
controller = PID_controller(p, i, d)
gan = GAN_simualte('gan', controller, damping)
# Loop variable is `step`, not `i`, so the integral gain `i` defined above
# is not overwritten by the iteration counter.
for step in range(n_steps):
    gan.d_step()
    gan.g_step()
    saver.update(step, gan.g, save=False)
    saver1.update(step, gan.d, save=False)
saver.save()
saver1.save()