-
Notifications
You must be signed in to change notification settings - Fork 5
/
pi_net.py
88 lines (71 loc) · 2.35 KB
/
pi_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
from tensorflow.python.keras import models, layers, losses, optimizers, utils
from tensorflow.python.keras import backend as K
def PINet_CIFAR10(use_global_pooling=True):
    """Build the PINet model for CIFAR-10 as a Keras functional model.

    Architecture (from the visible dataflow):
      * Four conv stages (128 -> 256 -> 512 -> 1024 filters), each
        Conv2D(3x3, stride 1, same) -> BatchNorm -> ReLU, the first three
        followed by MaxPooling2D(3x3, stride 2, same).
      * GlobalAveragePooling2D (when ``use_global_pooling``), then a
        Dense(2500, relu) projected and reshaped to a 50x50x1 map.
      * Conv2DTranspose(3 filters, 50x50 kernel, stride 1, same) ->
        BatchNorm -> ReLU, flattened to the model output (length 7500).

    Parameters
    ----------
    use_global_pooling : bool, optional
        Apply global average pooling after the conv stack (default True,
        matching the original hard-coded behavior). NOTE(review): with
        False the Dense layer receives a 4-D tensor and the Reshape to
        (1, 50, 50) will fail at build time — confirm before disabling.

    Returns
    -------
    A compiled-graph ``keras.Model`` mapping (32, 32, 3) images to a
    flat decoded vector. The model is not compiled with a loss here.
    """
    input_shape = [32, 32, 3]
    conv_width = 3       # 3x3 kernels throughout the encoder
    conv_stride = 1
    pool_width = 3
    pool_stride = 2

    def _conv_bn_relu(tensor, filters, pool):
        # One encoder stage: Conv2D -> BatchNorm -> ReLU, optionally
        # followed by overlapping max-pooling (3x3 window, stride 2).
        tensor = layers.Conv2D(
            filters,
            conv_width,
            strides=conv_stride,
            padding="same")(tensor)
        tensor = layers.BatchNormalization()(tensor)
        tensor = layers.Activation("relu")(tensor)
        if pool:
            tensor = layers.MaxPooling2D(
                pool_size=pool_width,
                strides=pool_stride,
                padding="same")(tensor)
        return tensor

    model_input = layers.Input(shape=input_shape)

    # Encoder: widths double each stage; the last stage has no pooling.
    x = _conv_bn_relu(model_input, 128, pool=True)
    x = _conv_bn_relu(x, 256, pool=True)
    x = _conv_bn_relu(x, 512, pool=True)
    x = _conv_bn_relu(x, 1024, pool=False)

    if use_global_pooling:
        x = layers.GlobalAveragePooling2D()(x)

    # Project to 2500 units and lay them out as a 50x50 single-channel
    # map (Reshape to channels-first, then Permute to channels-last).
    x_logits1 = layers.Dense(2500, activation="relu")(x)
    x_logits1_reshape = layers.Reshape((1, 50, 50))(x_logits1)
    x_logits1_reshape = layers.Permute((2, 3, 1))(x_logits1_reshape)

    # Decoder head: a single full-size (50x50 kernel) transposed conv
    # back to 3 channels, then BN + ReLU and a flat output vector.
    x_logits2 = layers.Conv2DTranspose(
        3,
        50,
        strides=conv_stride,
        padding="same")(x_logits1_reshape)
    x_logits2 = layers.BatchNormalization()(x_logits2)
    x_logits2 = layers.Activation("relu")(x_logits2)

    model_output = layers.Flatten()(x_logits2)
    return models.Model(model_input, model_output)