/
M_calculate.py
105 lines (77 loc) · 2.92 KB
/
M_calculate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import collections
from math import inf
from os import error
from tempfile import tempdir
import torch
import numpy as np
def joke(feature_dir='FeatureMap/Layer1', qbest_path='Qbest.pth', max_shift=20):
    """Estimate the theoretical requantization multiplier M of layer 1.

    Re-runs the first convolution (5x5 kernel, stride 1, padding 2, no bias)
    followed by ReLU on the dumped layer-1 input. The weights are the raw
    *integer* codes of the quantized checkpoint weights. The result is then
    compared with the dumped layer-1 output. M is the reciprocal of the mean
    ratio ``inference / expected``, taken over every position where the
    expected output is non-zero. For each shift ``S`` in
    ``range(max_shift)``, the matching fixed-point multiplier
    ``M0 = M * 2**S`` is printed at full float precision.

    Args:
        feature_dir: directory containing the whitespace-separated text dumps
            ``conv_in`` (reshaped to (1, 1, 25, 25)) and ``conv_out``
            (reshaped to (1, 56, 25, 25)).
        qbest_path: checkpoint loaded with ``torch.load``. Its
            ``'first_part.0.weight'`` entry must be a quantized tensor whose
            integer representation has shape (56, 1, 5, 5).
        max_shift: number of shift values to print (S = 0 .. max_shift - 1).

    Returns:
        The estimated M as a Python float (``inf`` if the mean ratio is 0).

    Raises:
        ValueError: if every value in ``conv_out`` is zero, so no ratio can
            be formed.
    """
    def read_map(name, shape):
        # Whitespace-separated text dump -> float32 tensor of ``shape``.
        with open(f'{feature_dir}/{name}', 'r') as f:
            values = np.array(f.read().split(), dtype='f')
        return torch.from_numpy(values.reshape(shape))

    fin = read_map('conv_in', (1, 1, 25, 25))
    fout = read_map('conv_out', (1, 56, 25, 25))

    checkpoint = torch.load(qbest_path,
                            map_location=lambda storage, loc: storage)
    # Use the integer codes of the quantized weights (as float32). The
    # convolution then yields the integer accumulator before requantization.
    weight = torch.int_repr(checkpoint['first_part.0.weight']).float()

    with torch.no_grad():
        inference = torch.relu(
            torch.nn.functional.conv2d(fin, weight, stride=1, padding=2))

    nonzero = fout != 0
    if not bool(nonzero.any()):
        raise ValueError('conv_out contains only zeros; cannot estimate M')
    # float64 keeps the low digits of M when averaging up to 56*25*25 ratios.
    ratios = inference[nonzero].double() / fout[nonzero].double()
    M = (1.0 / ratios.mean()).item()  # theoretical M value

    # Print Python floats (not 0-d tensors, whose repr shows only 4 decimals).
    for s in range(max_shift):
        print(f'M0 = {M * 2 ** s}\tS = {s}')
    return M
def MDebug(feature_dir='FeatureMap/Layer1', qbest_path='./vsd/Qbest.pth',
           m0=565, shift=15):
    """Compare the fixed-point requantized layer-1 output with the dump.

    Re-runs the first convolution (5x5, stride 1, padding 2, no bias) plus
    ReLU, using the integer codes of the quantized checkpoint weights. The
    result is scaled by ``m0 / 2**shift`` and compared element-wise (after
    ``torch.round`` on both sides) with the dumped ``conv_out``. For every
    mismatching position, the computed value, the expected value and
    ``---`` are printed. The total number of mismatches is printed last.

    Args:
        feature_dir: directory containing the text dumps ``conv_in``
            (reshaped to (1, 1, 25, 25)) and ``conv_out`` (reshaped to
            (1, 56, 25, 25)).
        qbest_path: checkpoint whose ``'first_part.0.weight'`` entry is a
            quantized tensor with integer shape (56, 1, 5, 5).
        m0: integer multiplier of the fixed-point scale.
        shift: right-shift amount, i.e. the result is divided by
            ``2**shift``.

    Returns:
        The number of mismatching positions (the same value that is printed).
    """
    def read_map(name, shape):
        # Whitespace-separated text dump -> float32 tensor of ``shape``.
        with open(f'{feature_dir}/{name}', 'r') as f:
            values = np.array(f.read().split(), dtype='f')
        return torch.from_numpy(values.reshape(shape))

    fin = read_map('conv_in', (1, 1, 25, 25))
    fout = read_map('conv_out', (1, 56, 25, 25))

    checkpoint = torch.load(qbest_path,
                            map_location=lambda storage, loc: storage)
    # Integer codes of the quantized weights, as float32.
    weight = torch.int_repr(checkpoint['first_part.0.weight']).float()

    with torch.no_grad():
        inference = torch.relu(
            torch.nn.functional.conv2d(fin, weight, stride=1, padding=2))
    inference = torch.flatten(inference)
    fout = torch.flatten(fout)

    # Requantize: (acc * m0) >> shift.
    # Recent PyTorch rejects bitwise shifts on floating-point tensors. Older
    # releases defined ``float >> n`` as ``x / 2**n``, and the explicit
    # division below reproduces that result on any version.
    # NOTE(review): a hardware integer shift would truncate instead; confirm
    # which rounding the reference output uses.
    # Float reference for comparison: inference * 0.017245340311241277
    inference = inference * m0 / 2 ** shift

    mismatch = torch.round(inference) != torch.round(fout)
    for got, expected in zip(inference[mismatch], fout[mismatch]):
        print(got)
        print(expected)
        print('---')
    mismatches = int(mismatch.sum())
    print(mismatches)
    return mismatches
if __name__ == "__main__":
    # Script entry point: run only the fixed-point requantization check.
    # joke() (the M estimate) has to be invoked explicitly.
    MDebug()