-
Notifications
You must be signed in to change notification settings - Fork 2
/
Network.py
130 lines (111 loc) · 5.71 KB
/
Network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch.nn as nn
import torch.nn.init as init
from cfg import par
class Denoising_Net_gray(nn.Module):
    """Residual denoising network for the grayscale configuration.

    Four parallel input branches (``InputNet0``..``InputNet3``) each embed a
    2-channel slice of the input — one sub-image channel plus the shared
    channel at index 4 (assumes x has at least 5 channels — TODO confirm
    against the data pipeline) — into an ``n_channel`` feature map. The
    branch outputs are summed and fed to a deep conv/ReLU trunk that
    predicts the noise residual ``z``, which is subtracted from the first
    ``output_channel`` feature maps (residual learning).

    Args:
        depth: Number of conv+ReLU pairs in the denoising trunk.
        input_channel: Channels consumed by each input branch
            (taken from ``cfg.par``; presumably 2 for the gray model).
        n_channel: Width of the hidden feature maps.
        output_channel: Channels of the predicted residual / output
            (taken from ``cfg.par``; presumably 4 for the gray model).
    """

    def __init__(self, depth=15, input_channel=par.input_channel, n_channel=72, output_channel=par.output_channel):
        super(Denoising_Net_gray, self).__init__()
        # Remember the residual width so forward() does not rely on a
        # hard-coded slice (original code assumed output_channel == 4).
        self.output_channel = output_channel
        # Denoising trunk: `depth` conv+ReLU pairs, then a projection to
        # the residual channels. Bias-free convs, as is common for
        # DnCNN-style residual denoisers.
        layers = []
        for _ in range(depth):
            layers.append(
                nn.Conv2d(n_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(n_channel, output_channel, kernel_size=(3, 3), padding=(1, 1), bias=False))
        self.denoisingNet = nn.Sequential(*layers)
        # Input branches of decreasing depth (3, 2, 2, 1 conv layers).
        self.InputNet0 = nn.Sequential(
            nn.Conv2d(input_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True)
        )
        self.InputNet1 = nn.Sequential(
            nn.Conv2d(input_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True)
        )
        self.InputNet2 = nn.Sequential(
            nn.Conv2d(input_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True)
        )
        self.InputNet3 = nn.Sequential(
            nn.Conv2d(input_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True)
        )
        self._initialize_weights()

    def forward(self, x):
        """Denoise ``x``.

        Each branch sees one sub-image channel paired with the shared
        channel 4. Returns the first ``output_channel`` summed feature
        maps minus the predicted residual.
        """
        x0 = self.InputNet0(x[:, [0, 4], :, :])
        x1 = self.InputNet1(x[:, [1, 4], :, :])
        x2 = self.InputNet2(x[:, [2, 4], :, :])
        x3 = self.InputNet3(x[:, [3, 4], :, :])
        x = x0 + x1 + x2 + x3
        z = self.denoisingNet(x)
        # Was hard-coded x[:, 0:4, ...]; use the configured width so the
        # slice always matches z's channel count.
        return x[:, 0:self.output_channel, :, :] - z

    def _initialize_weights(self):
        """Kaiming-normal init for convs; zero biases; unit BN scale.

        BatchNorm branch is dead code today (no BN layers are created)
        but kept for safety if the architecture grows one.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.xavier_uniform_(m.weight)
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                # init.orthogonal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
class Denoising_Net_color(nn.Module):
    """Residual denoising network for the color configuration.

    Mirrors ``Denoising_Net_gray`` but with a shallower trunk (depth 12),
    wider hidden features (108), and 4-channel branch inputs: each branch
    takes three sub-image channels plus the shared channel at index 12
    (assumes x has at least 13 channels — TODO confirm against the data
    pipeline). Branch outputs are summed and a conv/ReLU trunk predicts
    the noise residual, subtracted from the first ``output_channel``
    feature maps.

    Args:
        depth: Number of conv+ReLU pairs in the denoising trunk.
        input_channel: Channels consumed by each input branch
            (taken from ``cfg.par``; presumably 4 for the color model).
        n_channel: Width of the hidden feature maps.
        output_channel: Channels of the predicted residual / output
            (taken from ``cfg.par``; presumably 12 for the color model).
    """

    def __init__(self, depth=12, input_channel=par.input_channel, n_channel=108, output_channel=par.output_channel):
        super(Denoising_Net_color, self).__init__()
        # Remember the residual width so forward() does not rely on a
        # hard-coded slice (original code assumed output_channel == 12).
        self.output_channel = output_channel
        # Denoising trunk: `depth` bias-free conv+ReLU pairs, then a
        # projection to the residual channels.
        layers = []
        for _ in range(depth):
            layers.append(
                nn.Conv2d(n_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(n_channel, output_channel, kernel_size=(3, 3), padding=(1, 1), bias=False))
        self.denoisingNet = nn.Sequential(*layers)
        # Input branches of decreasing depth (3, 2, 2, 1 conv layers).
        self.InputNet0 = nn.Sequential(
            nn.Conv2d(input_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True)
        )
        self.InputNet1 = nn.Sequential(
            nn.Conv2d(input_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True)
        )
        self.InputNet2 = nn.Sequential(
            nn.Conv2d(input_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True)
        )
        self.InputNet3 = nn.Sequential(
            nn.Conv2d(input_channel, n_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
            nn.ReLU(inplace=True)
        )
        self._initialize_weights()

    def forward(self, x):
        """Denoise ``x``.

        Each branch sees three sub-image channels paired with the shared
        channel 12. Returns the first ``output_channel`` summed feature
        maps minus the predicted residual.
        """
        x0 = self.InputNet0(x[:, [0, 4, 8, 12], :, :])
        x1 = self.InputNet1(x[:, [1, 5, 9, 12], :, :])
        x2 = self.InputNet2(x[:, [2, 6, 10, 12], :, :])
        x3 = self.InputNet3(x[:, [3, 7, 11, 12], :, :])
        x = x0 + x1 + x2 + x3
        z = self.denoisingNet(x)
        # Was hard-coded x[:, 0:12, ...]; use the configured width so the
        # slice always matches z's channel count.
        return x[:, 0:self.output_channel, :, :] - z

    def _initialize_weights(self):
        """Kaiming-normal init for convs; zero biases; unit BN scale.

        BatchNorm branch is dead code today (no BN layers are created)
        but kept for safety if the architecture grows one.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.xavier_uniform_(m.weight)
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                # init.orthogonal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)