-
Notifications
You must be signed in to change notification settings - Fork 9
/
generators.py
128 lines (105 loc) · 4.04 KB
/
generators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
'''
Courtesy of: https://github.com/Muzammal-Naseer/Cross-domain-perturbations
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
###########################
# Generator: Resnet
###########################
# Base channel width of the generator. GeneratorResnet's stem uses ngf
# channels, and its two downsampling stages widen this to ngf * 2 and ngf * 4.
ngf: int = 64
class GeneratorResnet(nn.Module):
    """ResNet-style image-to-image generator whose outputs lie in [0, 1].

    Layout (channel counts in terms of the module-level ``ngf``):

    * ``block1``   - reflection pad 3 + 7x7 conv, 3 -> ngf channels
    * ``block2``   - stride-2 3x3 conv, ngf -> ngf * 2 (halves H and W, rounding up)
    * ``block3``   - stride-2 3x3 conv, ngf * 2 -> ngf * 4 (halves H and W, rounding up)
    * ``resblock1`` .. ``resblock6`` - six ResidualBlocks at ngf * 4 channels
    * ``upsampl1`` - stride-2 transposed conv, ngf * 4 -> ngf * 2 (doubles H and W)
    * ``upsampl2`` - stride-2 transposed conv, ngf * 2 -> ngf (doubles H and W)
    * ``blockf``   - reflection pad 3 + 7x7 conv, ngf -> 3 channels
    * ``crop``     - applied only when ``inception=True``; drops one row and one column
    """

    def __init__(self, inception=False):
        """
        :param inception: if True, ``forward`` crops the decoder output by one
            row and one column, e.g. 3x300x300 -> 3x299x299 for Inception-sized
            inputs. The network always uses six residual blocks, regardless
            of this flag.
        """
        super().__init__()
        self.inception = inception
        # Shape comments below assume an input of (3, n, n) with n divisible by 4.
        # (3, n, n) -> (ngf, n, n)
        self.block1 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, ngf, kernel_size=7, padding=0, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True)
        )
        # (ngf, n, n) -> (ngf * 2, n/2, n/2)
        self.block2 = nn.Sequential(
            nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True)
        )
        # (ngf * 2, n/2, n/2) -> (ngf * 4, n/4, n/4)
        self.block3 = nn.Sequential(
            nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True)
        )
        # Six residual blocks at (ngf * 4, n/4, n/4). They are kept as separate
        # named attributes (not a ModuleList) so state_dict keys remain
        # "resblock1.*" .. "resblock6.*".
        self.resblock1 = ResidualBlock(ngf * 4)
        self.resblock2 = ResidualBlock(ngf * 4)
        self.resblock3 = ResidualBlock(ngf * 4)
        self.resblock4 = ResidualBlock(ngf * 4)
        self.resblock5 = ResidualBlock(ngf * 4)
        self.resblock6 = ResidualBlock(ngf * 4)
        # (ngf * 4, n/4, n/4) -> (ngf * 2, n/2, n/2)
        self.upsampl1 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True)
        )
        # (ngf * 2, n/2, n/2) -> (ngf, n, n)
        self.upsampl2 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True)
        )
        # (ngf, n, n) -> (3, n, n)
        self.blockf = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, 3, kernel_size=7, padding=0)
        )
        # Padding order is (left, right, top, bottom); the negative values
        # remove the last column and the first row.
        self.crop = nn.ConstantPad2d((0, -1, -1, 0), 0)

    def forward(self, input):
        """Map a batch of images to generated images with values in [0, 1].

        :param input: image batch with 3 channels, presumably (N, 3, H, W).
            The parameter name shadows the builtin, but it is kept for callers
            that pass it by keyword.
        :return: tensor with 3 channels. Its spatial size equals the input's
            when H and W are divisible by 4 and ``inception`` is False. With
            ``inception=True`` it is one pixel smaller in each dimension than
            the uncropped decoder output.
        """
        x = self.block1(input)
        x = self.block2(x)
        x = self.block3(x)
        for resblock in (self.resblock1, self.resblock2, self.resblock3,
                         self.resblock4, self.resblock5, self.resblock6):
            x = resblock(x)
        x = self.upsampl1(x)
        x = self.upsampl2(x)
        x = self.blockf(x)
        if self.inception:
            x = self.crop(x)
        # Rescale tanh's (-1, 1) range into (0, 1).
        return (torch.tanh(x) + 1) / 2
class ResidualBlock(nn.Module):
    """Residual unit that adds a two-convolution branch back onto its input.

    The branch is: reflection pad 1 -> 3x3 conv (no bias) -> BatchNorm -> ReLU
    -> Dropout(0.5) -> reflection pad 1 -> 3x3 conv (no bias) -> BatchNorm.
    Reflection padding followed by unpadded stride-1 3x3 convolutions keeps
    the spatial size unchanged, so the branch output can be summed with the input.

    :param num_filters: channel count of both the input and the output.
    """

    def __init__(self, num_filters):
        super().__init__()

        def conv3x3():
            # Stride-1, unpadded 3x3 conv. The preceding ReflectionPad2d(1)
            # supplies the border, so H and W are preserved.
            return nn.Conv2d(in_channels=num_filters, out_channels=num_filters,
                             kernel_size=3, stride=1, padding=0, bias=False)

        layers = [
            nn.ReflectionPad2d(1),
            conv3x3(),
            nn.BatchNorm2d(num_filters),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.ReflectionPad2d(1),
            conv3x3(),
            nn.BatchNorm2d(num_filters),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + block(x)``; the result has the same shape as ``x``."""
        return x + self.block(x)
if __name__ == '__main__':
    # Smoke test: build the generator, run one random 224x224 RGB image
    # through it, then report the output shape and the number of trainable
    # parameters (in millions).
    generator = GeneratorResnet()
    sample = torch.rand(1, 3, 224, 224)
    output_shape = generator(sample).size()
    print('Generator output:', output_shape)
    trainable_counts = [p.numel() for p in generator.parameters() if p.requires_grad]
    print('Generator parameters:', sum(trainable_counts) / 1000000)