-
Notifications
You must be signed in to change notification settings - Fork 2
/
test.py
72 lines (53 loc) · 1.78 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from torch import nn
from torch import optim
from KroneckerProduct import *
import sys
from torch.autograd import Variable
def write_flush(text, stream=sys.stdout):
    """Write ``text`` to ``stream`` and flush immediately.

    Used for single-line progress output ('\r'-style overwrites), which
    would otherwise linger in the stream's buffer.
    """
    stream.write(text)
    stream.flush()
class Model(nn.Module):
    """Holds a trainable factor ``A`` and computes kron(A, B) for a given B.

    The learnable parameter ``A`` has shape ``A_shape`` and is initialised
    from random integers in [0, 100) cast to float. ``forward`` returns the
    batched Kronecker product of ``A`` with the supplied ``B``.
    """

    def __init__(self, A_shape, B_shape):
        super(Model, self).__init__()
        self.A_shape = A_shape
        self.B_shape = B_shape
        # Kronecker operator is built from the per-sample (non-batch) dims.
        self.kronecker = KroneckerProduct(A_shape[1:], B_shape[1:])
        # Assigning an nn.Parameter attribute registers it automatically
        # (equivalent to register_parameter('A', ...)).
        self.A = nn.Parameter(torch.randint(100, size=A_shape).float())

    def forward(self, B):
        return self.kronecker(self.A, B)
def main():
    """Recover a Kronecker factor by gradient descent.

    Builds a target ``C_target = kron(A_target, B)`` from random integer
    tensors, then trains ``Model`` (which owns a learnable ``A``) with SGD
    on the MSE between its output and ``C_target``, printing live progress
    and the final tensors.
    """
    # Fix: the original called .cuda() unconditionally and crashed on
    # CPU-only machines; fall back to CPU when CUDA is unavailable.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    batch_size = 1
    A_shape = (batch_size, 2, 2)
    B_shape = (batch_size, 2, 3)

    kronecker = KroneckerProduct(A_shape[1:], B_shape[1:]).to(device)
    A_target = torch.randint(10, size=A_shape).float().to(device)
    B = torch.randint(20, size=B_shape).float().to(device)
    C_target = kronecker(A_target, B)

    model = Model(A_shape, B_shape).to(device)
    optimiser = optim.SGD(model.parameters(), lr=1e-5, momentum=0.9)

    n_epochs = 4000
    for epoch in range(n_epochs):
        optimiser.zero_grad()
        C = model(B)
        loss = torch.mean((C_target - C) ** 2)
        loss.backward()
        optimiser.step()
        # '\r' + padding overwrites the current console line in place.
        write_flush('\r'+' '*100+'\rEpoch %d: Loss = %.4f' %(epoch, loss.item()))
        if epoch % (n_epochs//10) == 0:
            write_flush('\n')
    write_flush('\n')

    print('\nLearnt A: ')
    print(model.A)
    print('\nA_target:')
    print(A_target)
    print('\nB:')
    print(B)
    print('\nC_target: ')
    print(C_target)
    print('\nFinal C:')
    print(C)
if __name__ == '__main__':
    main()