Skip to content

Commit c74324c

Browse files
committed
test code
1 parent 792c656 commit c74324c

File tree

5 files changed

+575
-0
lines changed

5 files changed

+575
-0
lines changed

models/architecture.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import math
2+
import torch
3+
import torch.nn as nn
4+
from . import block as B
5+
import torchvision
6+
7+
#######################
8+
# Generator
9+
#######################
10+
11+
class PPON(nn.Module):
12+
def __init__(self, in_nc, nf, nb, out_nc, upscale=4, act_type='lrelu'):
13+
super(PPON, self).__init__()
14+
n_upscale = int(math.log(upscale, 2))
15+
if upscale == 3:
16+
n_upscale = 1
17+
18+
fea_conv = B.conv_layer(in_nc, nf, kernel_size=3) # common
19+
rb_blocks = [B.RRBlock_32() for _ in range(nb)] # L1
20+
LR_conv = B.conv_layer(nf, nf, kernel_size=3)
21+
22+
ssim_branch = [B.RRBlock_32() for _ in range(2)] # SSIM
23+
gan_branch = [B.RRBlock_32() for _ in range(2)] # Gan
24+
25+
upsample_block = B.upconv_block
26+
27+
if upscale == 3:
28+
upsampler = upsample_block(nf, nf, 3, act_type=act_type)
29+
upsampler_ssim = upsample_block(nf, nf, 3, act_type=act_type)
30+
upsampler_gan = upsample_block(nf, nf, 3, act_type=act_type)
31+
else:
32+
upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
33+
upsampler_ssim = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
34+
upsampler_gan = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
35+
36+
HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
37+
HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
38+
39+
HR_conv0_S = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
40+
HR_conv1_S = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
41+
42+
HR_conv0_P = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
43+
HR_conv1_P = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
44+
45+
self.CFEM = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)))
46+
self.SFEM = B.sequential(*ssim_branch)
47+
self.PFEM = B.sequential(*gan_branch)
48+
49+
self.CRM = B.sequential(*upsampler, HR_conv0, HR_conv1) # recon l1
50+
self.SRM = B.sequential(*upsampler_ssim, HR_conv0_S, HR_conv1_S) # recon ssim
51+
self.PRM = B.sequential(*upsampler_gan, HR_conv0_P, HR_conv1_P) # recon gan
52+
53+
def forward(self, x):
54+
out_CFEM = self.CFEM(x)
55+
out_c = self.CRM(out_CFEM)
56+
57+
out_SFEM = self.SFEM(out_CFEM)
58+
out_s = self.SRM(out_SFEM) + out_c
59+
60+
out_PFEM = self.PFEM(out_SFEM)
61+
out_p = self.PRM(out_PFEM) + out_s
62+
63+
return out_c, out_s, out_p
64+
65+
#########################
66+
# Discriminator
67+
#########################
68+
69+
class Discriminator_192(nn.Module):
70+
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='lrelu'):
71+
super(Discriminator_192, self).__init__()
72+
73+
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type) # 3-->64
74+
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, # 64-->64, 96*96
75+
act_type=act_type)
76+
77+
conv2 = B.conv_block(base_nf, base_nf * 2, kernel_size=3, stride=1, norm_type=norm_type, # 64-->128
78+
act_type=act_type)
79+
conv3 = B.conv_block(base_nf * 2, base_nf * 2, kernel_size=4, stride=2, norm_type=norm_type, # 128-->128, 48*48
80+
act_type=act_type)
81+
82+
conv4 = B.conv_block(base_nf * 2, base_nf * 4, kernel_size=3, stride=1, norm_type=norm_type, # 128-->256
83+
act_type=act_type)
84+
conv5 = B.conv_block(base_nf * 4, base_nf * 4, kernel_size=4, stride=2, norm_type=norm_type, # 256-->256, 24*24
85+
act_type=act_type)
86+
87+
conv6 = B.conv_block(base_nf * 4, base_nf * 8, kernel_size=3, stride=1, norm_type=norm_type, # 256-->512
88+
act_type=act_type)
89+
conv7 = B.conv_block(base_nf * 8, base_nf * 8, kernel_size=4, stride=2, norm_type=norm_type, # 512-->512 12*12
90+
act_type=act_type)
91+
92+
conv8 = B.conv_block(base_nf * 8, base_nf * 8, kernel_size=3, stride=1, norm_type=norm_type, # 512-->512
93+
act_type=act_type)
94+
conv9 = B.conv_block(base_nf * 8, base_nf * 8, kernel_size=4, stride=2, norm_type=norm_type, # 512-->512 6*6
95+
act_type=act_type)
96+
conv10 = B.conv_block(base_nf * 8, base_nf * 8, kernel_size=3, stride=1, norm_type=norm_type,
97+
act_type=act_type)
98+
conv11 = B.conv_block(base_nf * 8, base_nf * 8, kernel_size=4, stride=2, norm_type=norm_type, # 3*3
99+
act_type=act_type)
100+
101+
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,
102+
conv9, conv10, conv11)
103+
104+
self.classifier2 = nn.Sequential(
105+
nn.Linear(512 * 3 * 3, 128), nn.LeakyReLU(0.2, True), nn.Linear(128, 1))
106+
107+
def forward(self, x):
108+
x = self.features(x)
109+
x = x.view(x.size(0), -1)
110+
x = self.classifier2(x)
111+
return x
112+
113+
#########################
114+
# Perceptual Network
115+
#########################
116+
117+
# data range [0, 1]
118+
class VGGFeatureExtractor(nn.Module):
119+
def __init__(self,
120+
feature_layer=34,
121+
use_bn=False,
122+
use_input_norm=True):
123+
super(VGGFeatureExtractor, self).__init__()
124+
if use_bn:
125+
model = torchvision.models.vgg19_bn(pretrained=True)
126+
else:
127+
model = torchvision.models.vgg19(pretrained=True)
128+
self.use_input_norm = use_input_norm
129+
if self.use_input_norm:
130+
self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
131+
self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()
132+
133+
self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
134+
for k, v in self.features.named_parameters():
135+
v.requires_grad = False
136+
137+
def forward(self, x):
138+
if self.use_input_norm:
139+
x = (x - self.mean) / self.std
140+
output = self.features(x)
141+
return output

models/block.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import torch
2+
import torch.nn as nn
3+
from collections import OrderedDict
4+
5+
def conv_layer(in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1):
6+
padding = int((kernel_size - 1) / 2) * dilation
7+
return nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=True, dilation=dilation, groups=groups)
8+
9+
def norm(norm_type, nc):
10+
norm_type = norm_type.lower()
11+
if norm_type == 'batch':
12+
layer = nn.BatchNorm2d(nc, affine=True)
13+
elif norm_type == 'instance':
14+
layer = nn.InstanceNorm2d(nc, affine=False)
15+
else:
16+
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
17+
return layer
18+
19+
20+
def pad(pad_type, padding):
21+
pad_type = pad_type.lower()
22+
if padding == 0:
23+
return None
24+
if pad_type == 'reflect':
25+
layer = nn.ReflectionPad2d(padding)
26+
elif pad_type == 'replicate':
27+
layer = nn.ReplicationPad2d(padding)
28+
else:
29+
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
30+
return layer
31+
32+
def get_valid_padding(kernel_size, dilation):
33+
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
34+
padding = (kernel_size - 1) // 2
35+
return padding
36+
37+
38+
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
39+
pad_type='zero', norm_type=None, act_type='relu'):
40+
41+
padding = get_valid_padding(kernel_size, dilation)
42+
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
43+
padding = padding if pad_type == 'zero' else 0
44+
45+
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
46+
dilation=dilation, bias=bias, groups=groups)
47+
a = activation(act_type) if act_type else None
48+
n = norm(norm_type, out_nc) if norm_type else None
49+
return sequential(p, c, n, a)
50+
51+
52+
def activation(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
53+
act_type = act_type.lower()
54+
if act_type == 'relu':
55+
layer = nn.ReLU(inplace)
56+
elif act_type == 'lrelu':
57+
layer = nn.LeakyReLU(neg_slope, inplace)
58+
elif act_type == 'prelu':
59+
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
60+
else:
61+
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
62+
return layer
63+
64+
65+
class ShortcutBlock(nn.Module):
66+
#Elementwise sum the output of a submodule to its input
67+
def __init__(self, submodule):
68+
super(ShortcutBlock, self).__init__()
69+
self.sub = submodule
70+
71+
def forward(self, x):
72+
output = x + self.sub(x)
73+
return output
74+
75+
def __repr__(self):
76+
tmpstr = 'Identity + \n|'
77+
modstr = self.sub.__repr__().replace('\n', '\n|')
78+
tmpstr = tmpstr + modstr
79+
return tmpstr
80+
81+
82+
def sequential(*args):
83+
if len(args) == 1:
84+
if isinstance(args[0], OrderedDict):
85+
raise NotImplementedError('sequential does not support OrderedDict input.')
86+
return args[0]
87+
modules = []
88+
for module in args:
89+
if isinstance(module, nn.Sequential):
90+
for submodule in module.children():
91+
modules.append(submodule)
92+
elif isinstance(module, nn.Module):
93+
modules.append(module)
94+
return nn.Sequential(*modules)
95+
96+
97+
class _ResBlock_32(nn.Module):
98+
def __init__(self, nc=64):
99+
super(_ResBlock_32, self).__init__()
100+
self.c1 = conv_layer(nc, nc, 3, 1, 1)
101+
self.d1 = conv_layer(nc, nc//2, 3, 1, 1) # rate=1
102+
self.d2 = conv_layer(nc, nc//2, 3, 1, 2) # rate=2
103+
self.d3 = conv_layer(nc, nc//2, 3, 1, 3) # rate=3
104+
self.d4 = conv_layer(nc, nc//2, 3, 1, 4) # rate=4
105+
self.d5 = conv_layer(nc, nc//2, 3, 1, 5) # rate=5
106+
self.d6 = conv_layer(nc, nc//2, 3, 1, 6) # rate=6
107+
self.d7 = conv_layer(nc, nc//2, 3, 1, 7) # rate=7
108+
self.d8 = conv_layer(nc, nc//2, 3, 1, 8) # rate=8
109+
self.act = activation('lrelu')
110+
self.c2 = conv_layer(nc * 4, nc, 1, 1, 1) # 256-->64
111+
112+
def forward(self, input):
113+
output1 = self.act(self.c1(input))
114+
d1 = self.d1(output1)
115+
d2 = self.d2(output1)
116+
d3 = self.d3(output1)
117+
d4 = self.d4(output1)
118+
d5 = self.d5(output1)
119+
d6 = self.d6(output1)
120+
d7 = self.d7(output1)
121+
d8 = self.d8(output1)
122+
123+
add1 = d1 + d2
124+
add2 = add1 + d3
125+
add3 = add2 + d4
126+
add4 = add3 + d5
127+
add5 = add4 + d6
128+
add6 = add5 + d7
129+
add7 = add6 + d8
130+
131+
combine = torch.cat([d1, add1, add2, add3, add4, add5, add6, add7], 1)
132+
output2 = self.c2(self.act(combine))
133+
output = input + output2.mul(0.2)
134+
135+
return output
136+
137+
class RRBlock_32(nn.Module):
138+
def __init__(self):
139+
super(RRBlock_32, self).__init__()
140+
self.RB1 = _ResBlock_32()
141+
self.RB2 = _ResBlock_32()
142+
self.RB3 = _ResBlock_32()
143+
144+
def forward(self, input):
145+
out = self.RB1(input)
146+
out = self.RB2(out)
147+
out = self.RB3(out)
148+
return out.mul(0.2) + input
149+
150+
def upconv_block(in_channels, out_channels, upscale_factor=2, kernel_size=3, stride=1, act_type='relu'):
151+
upsample = nn.Upsample(scale_factor=upscale_factor, mode='nearest')
152+
conv = conv_layer(in_channels, out_channels, kernel_size, stride)
153+
act = activation(act_type)
154+
return sequential(upsample, conv, act)
155+
156+
def pixelshuffle_block(in_channels, out_channels, upscale_factor=2, kernel_size=3, stride=1):
157+
conv = conv_layer(in_channels, out_channels * (upscale_factor ** 2), kernel_size, stride)
158+
pixel_shuffle = nn.PixelShuffle(upscale_factor)
159+
return sequential(conv, pixel_shuffle)

0 commit comments

Comments
 (0)