-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_cov.py
72 lines (59 loc) · 2.64 KB
/
plot_cov.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import math
import numpy as np
import torch
from model import *
import matplotlib.pyplot as plt
def comp_gen(g_optim, d_optim, *, epochs=2000, step=10,
             root="./checkpoints/covariance"):
    """Return, per saved generator checkpoint, the squared distance of V^T V from sigma.

    Loads ``generator-epoch_{e}.tar`` for e = 1, 1 + step, ... (<= ``epochs``)
    from ``{root}/{g_optim}-{d_optim}/``. Each checkpoint goes into an
    ``AffineNet`` in double precision. The function then computes
    ``sum((V^T V - sigma) ** 2)`` with ``sigma = [[1, 0], [0, 0.04]]``.
    That is the *squared* Frobenius norm; ``comp_dis`` takes a square root.

    Args:
        g_optim: Generator optimizer name used in the run directory name.
        d_optim: Discriminator optimizer name used in the run directory name.
        epochs: Last epoch that may be loaded (default 2000, as originally hard-coded).
        step: Epoch spacing between loaded checkpoints (default 10).
        root: Directory holding the ``{g_optim}-{d_optim}`` run folders.

    Returns:
        float64 ``np.ndarray`` with one value per loaded checkpoint
        (200 entries with the defaults).

    Raises:
        Whatever ``torch.load`` raises if a checkpoint file is missing or unreadable.
    """
    generator = AffineNet().double()
    sigma = torch.tensor([[1., 0.], [0., 0.04]], dtype=torch.double)
    checkpoint_epochs = range(1, epochs + 1, step)
    dist = np.zeros(len(checkpoint_epochs))
    # enumerate gives the slot index directly; int(i / step) would be off by one for step == 1.
    for k, epoch in enumerate(checkpoint_epochs):
        path = "{}/{:s}-{:s}/generator-epoch_{:d}.tar".format(root, g_optim, d_optim, epoch)
        state = torch.load(path, map_location='cpu')['model_state_dict']
        generator.load_state_dict(state)
        # Read V after loading (don't rely on a reference captured before the load),
        # and skip autograd: only the numeric value is needed.
        with torch.no_grad():
            v = generator.V
            # NOTE(review): this is V^T V while the plot title reads VV^T -- confirm V's orientation.
            diff = v.t().mm(v) - sigma
            dist[k] = float((diff ** 2).sum())
    return dist
def comp_gen2(g_optim, d_optim, *, epochs=2000, step=10,
              root="./checkpoints/covariance"):
    """Collect the ``'generator_norm'`` value stored in each saved generator checkpoint.

    Loads ``generator-epoch_{e}.tar`` for e = 1, 1 + step, ... (<= ``epochs``)
    from ``{root}/{g_optim}-{d_optim}/``. It reads the ``'generator_norm'`` entry
    of each checkpoint dict without rebuilding a model.

    Args:
        g_optim: Generator optimizer name used in the run directory name.
        d_optim: Discriminator optimizer name used in the run directory name.
        epochs: Last epoch that may be loaded (default 2000, as originally hard-coded).
        step: Epoch spacing between loaded checkpoints (default 10).
        root: Directory holding the ``{g_optim}-{d_optim}`` run folders.

    Returns:
        float64 ``np.ndarray`` with one value per loaded checkpoint
        (200 entries with the defaults).

    Raises:
        Whatever ``torch.load`` raises if a checkpoint is missing, and KeyError if
        a checkpoint has no ``'generator_norm'`` entry.
    """
    checkpoint_epochs = range(1, epochs + 1, step)
    dist = np.zeros(len(checkpoint_epochs))
    # enumerate gives the slot index directly; int(i / step) would be off by one for step == 1.
    for k, epoch in enumerate(checkpoint_epochs):
        path = "{}/{:s}-{:s}/generator-epoch_{:d}.tar".format(root, g_optim, d_optim, epoch)
        # float() accepts both Python numbers and single-element tensors.
        dist[k] = float(torch.load(path, map_location='cpu')['generator_norm'])
    return dist
def comp_dis(g_optim, d_optim, *, epochs=2000, step=10,
             root="./checkpoints/covariance"):
    """Return, per saved discriminator checkpoint, the Frobenius norm of sym(W).

    Loads ``discriminator-epoch_{e}.tar`` for e = 1, 1 + step, ... (<= ``epochs``)
    from ``{root}/{g_optim}-{d_optim}/``. Each checkpoint goes into a
    ``QuadraticNet(2)`` in double precision. The function then computes
    ``sqrt(sum(((W + W^T) / 2) ** 2))``, the Frobenius norm of the
    symmetric part of ``W``.

    Args:
        g_optim: Generator optimizer name used in the run directory name.
        d_optim: Discriminator optimizer name used in the run directory name.
        epochs: Last epoch that may be loaded (default 2000, as originally hard-coded).
        step: Epoch spacing between loaded checkpoints (default 10).
        root: Directory holding the ``{g_optim}-{d_optim}`` run folders.

    Returns:
        float64 ``np.ndarray`` with one value per loaded checkpoint
        (200 entries with the defaults).

    Raises:
        Whatever ``torch.load`` raises if a checkpoint file is missing or unreadable.
    """
    discriminator = QuadraticNet(2).double()
    checkpoint_epochs = range(1, epochs + 1, step)
    dist = np.zeros(len(checkpoint_epochs))
    # enumerate gives the slot index directly; int(i / step) would be off by one for step == 1.
    for k, epoch in enumerate(checkpoint_epochs):
        path = "{}/{:s}-{:s}/discriminator-epoch_{:d}.tar".format(root, g_optim, d_optim, epoch)
        state = torch.load(path, map_location='cpu')['model_state_dict']
        discriminator.load_state_dict(state)
        # Only the value is needed, so don't build an autograd graph.
        with torch.no_grad():
            w = discriminator.W
            sym = (w + w.t()) / 2
            dist[k] = math.sqrt(float((sym ** 2).sum()))
    return dist
if __name__ == "__main__":
    epochs = 2000
    # One point per saved checkpoint (every 10 epochs). The comp_* helpers return
    # 200 points by default, so this assumes epochs <= 2000.
    n_points = epochs // 10
    # NOTE(review): checkpoints are loaded for epochs 1, 11, ..., 1991 but are
    # plotted at x = 0, 10, ..., 1990 -- confirm the intended epoch offset.
    xs = range(0, epochs, 10)

    # (legend label, generator optimizer, discriminator optimizer, line style).
    # Uncomment an entry to add that run to both panels.
    runs = [
        # ('gda', 'gd', 'gd', '-'),
        ('gdn', 'gd', 'newton', '--'),
        # ('sd', 'sd', 'gd', ':'),
        ('fr', 'gd', 'fr', '-.'),
    ]

    ax1 = plt.subplot(121)
    ax2 = plt.subplot(122)
    for label, g_optim, d_optim, style in runs:
        ax1.plot(xs, comp_gen2(g_optim, d_optim)[:n_points], label=label, linestyle=style)
        ax2.plot(xs, comp_dis(g_optim, d_optim)[:n_points], label=label, linestyle=style)

    # Left panel: the 'generator_norm' value stored in each generator checkpoint.
    # Set the scale on the explicit Axes rather than via plt's "current axes".
    ax1.legend()
    ax1.set_yscale('log')
    ax1.set_title(r'$||VV^T - \Sigma||_2$')

    # Right panel: Frobenius norm of the symmetric part of the discriminator's W.
    ax2.legend()
    ax2.set_yscale('log')
    ax2.set_title(r'$||(W + W^T)/2||_2$')

    # Save before show(): show() may close the figure, leaving a blank image
    # for any later savefig call.
    # plt.savefig('images/gd-gd.png')
    plt.show()