1
+ import math
2
+ import torch
3
+ import torch .nn as nn
4
+ from . import block as B
5
+ import torchvision
6
+
7
+ #######################
8
+ # Generator
9
+ #######################
10
+
11
class PPON(nn.Module):
    """Progressive Perception (PPON) generator.

    Three cascaded feature-extraction stages — content (L1), structure
    (SSIM) and perceptual (GAN) — each with its own upsampling/reconstruction
    head.  `forward` returns all three SR outputs; each later output
    residually refines the previous one.
    """

    def __init__(self, in_nc, nf, nb, out_nc, upscale=4, act_type='lrelu'):
        super(PPON, self).__init__()
        # Number of x2 upconv stages; a x3 factor is handled by one x3 stage.
        n_upscale = 1 if upscale == 3 else int(math.log(upscale, 2))

        head = B.conv_layer(in_nc, nf, kernel_size=3)        # shared entry conv
        trunk = [B.RRBlock_32() for _ in range(nb)]          # content (L1) branch
        trunk_tail = B.conv_layer(nf, nf, kernel_size=3)

        ssim_blocks = [B.RRBlock_32() for _ in range(2)]     # structure (SSIM) branch
        perc_blocks = [B.RRBlock_32() for _ in range(2)]     # perceptual (GAN) branch

        if upscale == 3:
            up_c = B.upconv_block(nf, nf, 3, act_type=act_type)
            up_s = B.upconv_block(nf, nf, 3, act_type=act_type)
            up_p = B.upconv_block(nf, nf, 3, act_type=act_type)
        else:
            up_c = [B.upconv_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
            up_s = [B.upconv_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
            up_p = [B.upconv_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]

        def recon_tail():
            # Two 3x3 convs: activated refinement, then linear projection to out_nc.
            return (B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type),
                    B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None))

        tail_c = recon_tail()
        tail_s = recon_tail()
        tail_p = recon_tail()

        # Feature-extraction modules.
        self.CFEM = B.sequential(head, B.ShortcutBlock(B.sequential(*trunk, trunk_tail)))
        self.SFEM = B.sequential(*ssim_blocks)
        self.PFEM = B.sequential(*perc_blocks)

        # Reconstruction modules (upsampling + tail).
        self.CRM = B.sequential(*up_c, *tail_c)  # recon l1
        self.SRM = B.sequential(*up_s, *tail_s)  # recon ssim
        self.PRM = B.sequential(*up_p, *tail_p)  # recon gan

    def forward(self, x):
        """Return (content, structure, perceptual) SR outputs for `x`."""
        feat_c = self.CFEM(x)
        out_c = self.CRM(feat_c)

        feat_s = self.SFEM(feat_c)
        out_s = self.SRM(feat_s) + out_c   # structure output refines content

        feat_p = self.PFEM(feat_s)
        out_p = self.PRM(feat_p) + out_s   # perceptual output refines structure

        return out_c, out_s, out_p
64
+
65
+ #########################
66
+ # Discriminator
67
+ #########################
68
+
69
class Discriminator_192(nn.Module):
    """VGG-style patch discriminator for 192x192 inputs.

    Five (stride-1 3x3, stride-2 4x4) conv pairs halve the spatial size from
    192 down to 3 while growing channels to base_nf*8; a small MLP head maps
    the flattened 512*3*3 features to a single logit.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='lrelu'):
        super(Discriminator_192, self).__init__()

        # Entry: un-normalized 3x3 conv, then first stride-2 downsample (192 -> 96).
        layers = [
            B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type),
            B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                         act_type=act_type),
        ]

        # (cin, cout) per downsampling pair: 96->48->24->12->6->3 spatially.
        channel_plan = [
            (base_nf, base_nf * 2),
            (base_nf * 2, base_nf * 4),
            (base_nf * 4, base_nf * 8),
            (base_nf * 8, base_nf * 8),
            (base_nf * 8, base_nf * 8),
        ]
        for cin, cout in channel_plan:
            # stride-1 3x3 conv changes channels; stride-2 4x4 conv halves H/W.
            layers.append(B.conv_block(cin, cout, kernel_size=3, stride=1,
                                       norm_type=norm_type, act_type=act_type))
            layers.append(B.conv_block(cout, cout, kernel_size=4, stride=2,
                                       norm_type=norm_type, act_type=act_type))

        self.features = B.sequential(*layers)

        self.classifier2 = nn.Sequential(
            nn.Linear(512 * 3 * 3, 128), nn.LeakyReLU(0.2, True), nn.Linear(128, 1))

    def forward(self, x):
        """Return one real/fake logit per batch element."""
        feat = self.features(x)
        flat = feat.view(feat.size(0), -1)
        return self.classifier2(flat)
112
+
113
+ #########################
114
+ # Perceptual Network
115
+ #########################
116
+
117
+ # data range [0, 1]
118
# data range [0, 1]
class VGGFeatureExtractor(nn.Module):
    """Frozen VGG-19 feature extractor for perceptual loss.

    Args:
        feature_layer: index into ``vgg.features`` up to (and including)
            which layers are kept; default 34 corresponds to the conv5_4
            output of the non-BN VGG-19.
        use_bn: if True, use the batch-norm variant of VGG-19.
        use_input_norm: if True, normalize inputs with ImageNet mean/std;
            inputs are assumed to be in [0, 1].
    """

    def __init__(self,
                 feature_layer=34,
                 use_bn=False,
                 use_input_norm=True):
        super(VGGFeatureExtractor, self).__init__()
        if use_bn:
            model = torchvision.models.vgg19_bn(pretrained=True)
        else:
            model = torchvision.models.vgg19(pretrained=True)
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            # ImageNet statistics.  Registered as non-persistent buffers so
            # they follow the module across .to()/.cuda() — the original
            # called .cuda() at construction, which crashed on CPU-only
            # machines and ignored the target device.  persistent=False
            # keeps them out of state_dict, preserving checkpoint keys.
            self.register_buffer(
                'mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
                persistent=False)
            self.register_buffer(
                'std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
                persistent=False)

        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
        # Freeze all weights: the extractor is used only to compute losses.
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        """Return VGG features for `x` (expected in [0, 1])."""
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output
0 commit comments