# Decoder.py
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
    """Convolutional decoder that upsamples a feature map by 16x.

    Five stages of 3x3 convolutions (stride 1, padding 1, so every conv keeps
    the spatial size) are separated by four bilinear 2x upsampling steps.
    Channel widths go 512 -> 512 -> 256 -> 128 -> 64, and a final 3x3 conv
    projects to ``out_channels``. Every conv except the last is followed by
    ReLU; the last conv has no activation, so the output values are unbounded.

    Args:
        in_channels: Number of channels of the input feature map
            (default 512).
        out_channels: Number of channels of the output map (default 1).

    Shape:
        - Input: ``(N, in_channels, H, W)``
        - Output: ``(N, out_channels, 16 * H, 16 * W)``
    """

    def __init__(self, in_channels: int = 512, out_channels: int = 1):
        super().__init__()
        # Attribute names and registration order are kept exactly as before
        # so existing checkpoints (keys like "conv6_1.weight") still load.

        # Stage 6: input resolution.
        self.conv6_1 = nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1)
        self.conv6_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv6_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.upsample1 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

        # Stage 7: 2x input resolution.
        self.conv7_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv7_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv7_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.upsample2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

        # Stage 8: 4x input resolution.
        self.conv8_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.conv8_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv8_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.upsample3 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

        # Stage 9: 8x input resolution (two convs only).
        self.conv9_1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.conv9_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.upsample4 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

        # Stage 10: 16x input resolution; conv10_3 is the output projection.
        self.conv10_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv10_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv10_3 = nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, xb):
        """Decode ``xb`` into a map 16x larger in each spatial dimension.

        Args:
            xb: Batched 4-D tensor ``(N, in_channels, H, W)``; bilinear
                upsampling requires 4-D input.

        Returns:
            Tensor of shape ``(N, out_channels, 16 * H, 16 * W)``. No
            activation is applied to the final convolution.
        """
        xb = F.relu(self.conv6_1(xb))
        xb = F.relu(self.conv6_2(xb))
        xb = F.relu(self.conv6_3(xb))
        xb = self.upsample1(xb)

        xb = F.relu(self.conv7_1(xb))
        xb = F.relu(self.conv7_2(xb))
        xb = F.relu(self.conv7_3(xb))
        xb = self.upsample2(xb)

        xb = F.relu(self.conv8_1(xb))
        xb = F.relu(self.conv8_2(xb))
        xb = F.relu(self.conv8_3(xb))
        xb = self.upsample3(xb)

        xb = F.relu(self.conv9_1(xb))
        xb = F.relu(self.conv9_2(xb))
        xb = self.upsample4(xb)

        xb = F.relu(self.conv10_1(xb))
        xb = F.relu(self.conv10_2(xb))
        # Raw output: no activation after the final projection.
        xb = self.conv10_3(xb)
        return xb