/
cyclegan.py
105 lines (79 loc) · 3.39 KB
/
cyclegan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# -*- coding: utf-8 -*-
"""CycleGAN.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1UdtJ4GqZj478h8eZSCHFuTiGTAlr8urK
"""
from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.models import Model
from keras.models import Input
from keras.layers import Conv2D, Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from keras.layers import BatchNormalization
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.utils.vis_utils import plot_model
# NOTE: The three lines below are Colab/IPython shell commands, not Python.
# They are kept as comments so this file parses as a plain Python module.
# Run them in a notebook cell (or the equivalent in a shell) BEFORE the
# imports at the top of this file, because `keras_contrib` must already be
# installed for the InstanceNormalization import to succeed.
#
#   !git clone https://www.github.com/keras-team/keras-contrib.git
#   !python setup.py install
#   !pip install keras_contrib
def define_discriminator(image_shape):
    """Build and compile the discriminator network.

    The network is a stack of 4x4, stride-2 convolutions that repeatedly
    halve the spatial resolution, followed by a single-filter convolution
    that emits a map of per-patch scores rather than one scalar.

    Args:
        image_shape: Shape tuple passed to ``Input(shape=...)``.

    Returns:
        A Keras ``Model`` mapping the input image to ``patch_out``, compiled
        with MSE loss (weighted by 0.5) and ``Adam(lr=0.0002, beta_1=0.5)``.
    """
    weight_init = RandomNormal(stddev=0.02)
    source_image = Input(shape=image_shape)

    # First downsampling stage: convolution + LeakyReLU, no normalisation.
    features = Conv2D(
        64, (4, 4), strides=(2, 2), padding='same',
        kernel_initializer=weight_init,
    )(source_image)
    features = LeakyReLU(alpha=0.2)(features)

    # Remaining downsampling stages share one pattern and differ only in
    # their filter count.
    for n_filters in (128, 256, 512):
        features = Conv2D(
            n_filters, (4, 4), strides=(2, 2), padding='same',
            kernel_initializer=weight_init,
        )(features)
        features = InstanceNormalization(axis=-1)(features)
        features = LeakyReLU(alpha=0.2)(features)

    # Single-filter output layer. No ``padding`` argument is given, so the
    # Keras default applies here, unlike the 'same'-padded layers above.
    patch_out = Conv2D(
        1, (4, 4), strides=(2, 2), kernel_initializer=weight_init,
    )(features)

    discriminator = Model(source_image, patch_out)
    discriminator.compile(
        loss='mse',
        optimizer=Adam(lr=0.0002, beta_1=0.5),
        loss_weights=[0.5],
    )
    return discriminator
# Build the discriminator for 256x256 inputs with 3 channels, print its layer
# summary and write an architecture diagram to disk.
image_shape = (256, 256, 3)
model = define_discriminator(image_shape)
model.summary()
plot_model(
    model,
    to_file='discriminator_model_plot.png',
    show_layer_names=True,
    show_shapes=True,
)
def resnet_block(n_filters, input_layer):
    """Append one residual-style block to ``input_layer``.

    The block applies two 3x3 'same'-padded convolutions, each followed by
    instance normalisation, with a ReLU between them. The result is then
    merged with the block's input using ``Concatenate`` (not addition), so
    the output carries the channels of both tensors.

    Args:
        n_filters: Number of filters for each of the two convolutions.
        input_layer: Tensor the block is applied to.

    Returns:
        The tensor produced by concatenating the convolved branch with
        ``input_layer``.
    """
    weight_init = RandomNormal(stddev=0.02)

    # First convolutional stage.
    branch = Conv2D(
        n_filters, (3, 3), padding='same', kernel_initializer=weight_init,
    )(input_layer)
    branch = InstanceNormalization(axis=-1)(branch)
    branch = Activation('relu')(branch)

    # Second convolutional stage (no activation before the merge).
    branch = Conv2D(
        n_filters, (3, 3), padding='same', kernel_initializer=weight_init,
    )(branch)
    branch = InstanceNormalization(axis=-1)(branch)

    # Skip connection: stack the processed branch with the untouched input.
    return Concatenate()([branch, input_layer])
def define_generator(image_shape=(256,256,3), n_resnet=9):
    """Build the (uncompiled) generator network.

    Structure: a 7x7 convolution, two stride-2 downsampling convolutions,
    ``n_resnet`` residual-style blocks, two stride-2 transposed convolutions
    that upsample again, and a final 7x7 convolution with 3 filters followed
    by instance normalisation and ``tanh``.

    Args:
        image_shape: Shape tuple passed to ``Input(shape=...)``.
        n_resnet: Number of ``resnet_block`` calls chained in the middle of
            the network. Defaults to 9.

    Returns:
        A Keras ``Model`` mapping the input image to the ``tanh`` output. The
        model is not compiled here.
    """
    init = RandomNormal(stddev=0.02)
    in_image = Input(shape=image_shape)

    # Entry convolution at full resolution.
    g = Conv2D(64, (7, 7), padding='same', kernel_initializer=init)(in_image)
    g = InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)

    # Downsampling: each stage halves the spatial size.
    for n_filters in (128, 256):
        g = Conv2D(n_filters, (3, 3), strides=(2, 2), padding='same',
                   kernel_initializer=init)(g)
        g = InstanceNormalization(axis=-1)(g)
        g = Activation('relu')(g)

    # Residual-style blocks. The count now honours ``n_resnet``; it was
    # previously hard-coded to 9, which silently ignored the argument.
    for _ in range(n_resnet):
        g = resnet_block(256, g)

    # Upsampling: each stage doubles the spatial size.
    for n_filters in (128, 64):
        g = Conv2DTranspose(n_filters, (3, 3), strides=(2, 2), padding='same',
                            kernel_initializer=init)(g)
        g = InstanceNormalization(axis=-1)(g)
        g = Activation('relu')(g)

    # Output projection to 3 channels, squashed by tanh.
    g = Conv2D(3, (7, 7), padding='same', kernel_initializer=init)(g)
    g = InstanceNormalization(axis=-1)(g)
    out_image = Activation('tanh')(g)

    model = Model(in_image, out_image)
    return model
# Build the generator with its default arguments, print its layer summary and
# write an architecture diagram to disk.
model = define_generator()
model.summary()
plot_model(
    model,
    to_file='generator_model.png',
    show_layer_names=True,
    show_shapes=True,
)