"""Module containing, RevNet and DIQT Network
Created by Stefano B. Blumberg to illustrate methodology utilised in:
Deeper Image Quality Transfer: Training Low-Memory Neural Networks for 3D Images (MICCAI 2018)
"""
import torch as t ; import torch.nn as nn
class ESPCN_RN(nn.Module):
    """ESPCN_RN-N (N := no_RevNet_layers), from
    Deeper Image Quality Transfer: Training Low-Memory Neural Networks for 3D Images.

    We have omitted the final shuffle for simplicity.
    Backpropagation is performed manually; please see the Main file for how to integrate this.
    """

    def __init__(self,
                 no_RevNet_layers,
                 no_chans_in=6,
                 no_chans_out=48,
                 memory_efficient=True):
        """
        Args:
            no_RevNet_layers (int): Number of RevNet layers per stack
            no_chans_in (int): Number of input channels
            no_chans_out (int): Number of output channels
            memory_efficient (bool): Use the memory-efficient technique
        """
        super().__init__()

        noChansin0, noChansout0 = no_chans_in, 50
        self.rn0 = RevNet(noChansin0 // 2,
                          noChansin0 // 2,
                          no_RevNet_layers=no_RevNet_layers,
                          memory_efficient=memory_efficient)
        self.conv0 = nn.Sequential(nn.Conv3d(noChansin0, noChansout0, kernel_size=3, padding=0),
                                   nn.ReLU())

        noChansin1, noChansout1 = 50, 100
        self.rn1 = RevNet(noChansin1 // 2,
                          noChansin1 // 2,
                          no_RevNet_layers=no_RevNet_layers,
                          memory_efficient=memory_efficient)
        self.conv1 = nn.Sequential(nn.Conv3d(noChansin1, noChansout1, kernel_size=1, padding=0),
                                   nn.ReLU())

        noChansin2, noChansout2 = 100, no_chans_out
        self.rn2 = RevNet(noChansin2 // 2,
                          noChansin2 // 2,
                          no_RevNet_layers=no_RevNet_layers,
                          memory_efficient=memory_efficient)
        self.conv2 = nn.Sequential(nn.Conv3d(noChansin2, noChansout2, kernel_size=3, padding=0))
    def forward(self, X):
        """Memory-efficient forward pass; cache intermediate activations as attributes."""
        self.inpConv0 = self.rn0(X)
        X = self.conv0(self.inpConv0)
        self.inpConv1 = self.rn1(X)
        X = self.conv1(self.inpConv1)
        self.inpConv2 = self.rn2(X)
        X = self.conv2(self.inpConv2)
        return X
    def backward(self, Y, YGrad):
        """Memory-efficient backward pass; computes parameter gradients.

        Args:
            Y: Network output from the forward pass
            YGrad: Gradient of the loss with respect to the network output
        """
        # For each block, recompute the convolution from its cached input, backpropagate
        # the incoming gradient through it, then let the RevNet stack reconstruct its own
        # inputs and pass the gradient further back.
        conv2_root = self.inpConv2.requires_grad_()
        Y = self.conv2(conv2_root)
        t.autograd.backward([Y], [YGrad])
        _, YGrad = self.rn2.backward(self.inpConv2, conv2_root.grad)

        conv1_root = self.inpConv1.requires_grad_()
        Y = self.conv1(conv1_root)
        t.autograd.backward([Y], [YGrad])
        _, YGrad = self.rn1.backward(self.inpConv1, conv1_root.grad)

        conv0_root = self.inpConv0.requires_grad_()
        Y = self.conv0(conv0_root)
        t.autograd.backward([Y], [YGrad])
        _, _ = self.rn0.backward(self.inpConv0, conv0_root.grad)
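

# --- Usage sketch (illustrative; not part of the original Main file) -----------------
# A minimal example of how the manual forward/backward pair of ESPCN_RN might be driven
# from a training loop. The loss, optimiser and the helper name below are assumptions
# made here for illustration only.
def _example_train_step(model, optimiser, lowres_patch, highres_patch):
    """One hypothetical training step using ESPCN_RN's manual backward pass."""
    optimiser.zero_grad()
    prediction = model(lowres_patch)  # caches inpConv0/1/2 on the model
    # Compute dLoss/dPrediction explicitly on a detached copy, since model.backward
    # expects the output gradient rather than a loss.backward() call through the network.
    pred_leaf = prediction.detach().requires_grad_()
    loss = nn.functional.mse_loss(pred_leaf, highres_patch)
    loss.backward()                              # fills pred_leaf.grad
    model.backward(prediction, pred_leaf.grad)   # manual, memory-efficient backprop
    optimiser.step()
    return loss.detach()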
class RevNet(nn.Module):
    """RevNet class."""

    def __init__(self,
                 noChans1=3,
                 noChans2=3,
                 no_RevNet_layers=10,
                 whole=True,
                 memory_efficient=True):
        super().__init__()
        self.noChans1 = noChans1
        self.no_RevNet_layers = no_RevNet_layers
        self.whole = whole
        self.memory_efficient = memory_efficient
        self.CreateFG(noChans1, noChans2, no_RevNet_layers)
    def CreateFG(self, noChans1, noChans2, no_RevNet_layers):
        """Construct the residual functions; split the input channels into noChans1 + noChans2.

        e.g. F5 is the residual function F for layer 5.
        """
        for layer in range(no_RevNet_layers):
            # An example of another (basic) residual block:
            # Basic = nn.Sequential(nn.BatchNorm3d(noChans2),
            #                       nn.ReLU(),
            #                       nn.Conv3d(noChans2, noChans2, kernel_size=3, padding=1),
            #                       nn.BatchNorm3d(noChans2),
            #                       nn.ReLU(),
            #                       nn.Conv3d(noChans2, noChans1, kernel_size=3, padding=1))
            residual_1 = nn.Sequential(nn.BatchNorm3d(noChans1),
                                       nn.ReLU(),
                                       nn.Conv3d(noChans1, noChans1, kernel_size=1, padding=0),
                                       nn.BatchNorm3d(noChans1),
                                       nn.ReLU(),
                                       nn.Conv3d(noChans1, noChans1, kernel_size=3, padding=1),
                                       nn.BatchNorm3d(noChans1),
                                       nn.ReLU(),
                                       nn.Conv3d(noChans1, noChans2, kernel_size=1, padding=0))
            residual_2 = nn.Sequential(nn.BatchNorm3d(noChans1),
                                       nn.ReLU(),
                                       nn.Conv3d(noChans1, noChans1, kernel_size=1, padding=0),
                                       nn.BatchNorm3d(noChans1),
                                       nn.ReLU(),
                                       nn.Conv3d(noChans1, noChans1, kernel_size=3, padding=1),
                                       nn.BatchNorm3d(noChans1),
                                       nn.ReLU(),
                                       nn.Conv3d(noChans1, noChans2, kernel_size=1, padding=0))
            # setattr on an nn.Module registers these blocks as submodules, so their
            # parameters are tracked by the parent RevNet.
            setattr(self, 'F' + str(layer), residual_1)
            setattr(self, 'G' + str(layer), residual_2)
    def ForwardPassLayer(self, layer_no, X):
        """Forward pass of a single RevNet layer.

        Args:
            layer_no (int): Index of the layer in the stack
            X: Tuple of tensors (x1, x2), the input to layer layer_no
        Returns:
            Output of layer layer_no as a tuple of tensors (y1, y2)
        """
        x1, x2 = X
        z1 = x1 + getattr(self, 'F' + str(layer_no))(x2)
        y2 = x2 + getattr(self, 'G' + str(layer_no))(z1)
        y1 = z1
        return (y1, y2)
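    # Because each layer is the coupling
    #     z1 = x1 + F(x2),   y1 = z1,   y2 = x2 + G(z1),
    # its input can be recovered from its output alone:
    #     x2 = y2 - G(y1),   x1 = y1 - F(x2).
    # BackwardsPassLayer below exploits this, so intermediate activations need not be
    # stored during the forward pass.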
    def forward(self, X):
        """Forward pass through the reversible layers in a block.

        If self.whole is True the input is a single tensor that must be split into two
        channel groups; otherwise it is already passed as a tuple.
        """
        if self.whole:
            X = (X[:, 0:self.noChans1, ...].contiguous(), X[:, self.noChans1:, ...].contiguous())
        for layer_no in range(self.no_RevNet_layers):
            if self.memory_efficient:
                # Memory-efficient option: run under no_grad so this layer's
                # intermediate activations are not stored by autograd.
                with t.no_grad():
                    Y = self.ForwardPassLayer(layer_no, X)
                X = Y
            else:
                # Standard forward pass; autograd stores the activations as usual.
                Y = self.ForwardPassLayer(layer_no, X)
                X = Y
            del Y
            t.cuda.empty_cache()
        X = t.cat((X[0], X[1]), 1)
        return X
    def BackwardsPassLayer(self, layer_no, Y, YHat):
        """Memory-efficient backwards pass for a single RevNet layer.

        Assumes self.memory_efficient is True.
        Args:
            layer_no (int): RevNet layer number in the stack
            Y = (y1, y2) (2-tensor tuple): Output of RevNet layer layer_no
            YHat = (y1Hat, y2Hat) (2-tensor tuple): Gradients of that output
        """
        y1, y2 = Y
        y1Hat, y2Hat = YHat
        # Reconstruct the layer's input from its output (no graph needed).
        with t.no_grad():
            z1 = y1  # Do we need to make a copy here? id(z1) == id(y1)
            x2 = y2 - getattr(self, 'G' + str(layer_no))(z1)
            x1 = z1 - getattr(self, 'F' + str(layer_no))(x2)
        # Recompute G(z1) and F(x2) with autograd enabled to obtain the input gradients
        # and accumulate the parameter gradients of G and F.
        z1.requires_grad_()
        y2Part = getattr(self, 'G' + str(layer_no))(z1)
        t.autograd.backward([y2Part], [y2Hat])
        z1.grad += y1Hat
        x2.requires_grad_()
        z1Part = getattr(self, 'F' + str(layer_no))(x2)
        t.autograd.backward([z1Part], [z1.grad])
        x2.grad += y2Hat
        del y1, y2, Y, y2Part
        t.cuda.empty_cache()
        return ((x1, x2), (z1.grad, x2.grad))
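    # Gradient bookkeeping for the coupling above (y1 = z1, y2 = x2 + G(z1), z1 = x1 + F(x2)):
    #     dL/dz1 = dL/dy1 + G'(z1)^T dL/dy2   (hence z1.grad += y1Hat after the G backward)
    #     dL/dx2 = dL/dy2 + F'(x2)^T dL/dz1   (hence x2.grad += y2Hat after the F backward)
    #     dL/dx1 = dL/dz1,
    # while the same two autograd.backward calls accumulate the parameter gradients of F and G.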
    def backward(self, Y, YHat):
        """Memory-efficient backwards pass through the RevNet layers in a stack.

        Assumes self.memory_efficient is True.
        Args:
            Y (2-tensor tuple or tensor): RevNet stack output; format depends on self.whole
            YHat (2-tensor tuple or tensor): Gradients of Y; format depends on self.whole
        """
        if not self.memory_efficient:
            raise Exception('Manual backpropagation called but memory_efficient is False')
        if self.whole:
            Y = (Y[:, 0:self.noChans1, ...].contiguous(), Y[:, self.noChans1:, ...].contiguous())
            YHat = (YHat[:, 0:self.noChans1, ...].contiguous(),
                    YHat[:, self.noChans1:, ...].contiguous())
        Y = Y[0].data, Y[1].data
        for layer_no in reversed(range(self.no_RevNet_layers)):
            (Y, YHat) = self.BackwardsPassLayer(layer_no, Y, YHat)
        Y = t.cat((Y[0], Y[1]), 1)
        YHat = t.cat((YHat[0], YHat[1]), 1)
        return (Y, YHat)
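

if __name__ == '__main__':
    # Smoke-test sketch (illustrative only): the patch sizes, layer count and learning
    # rate below are assumptions, not values taken from the paper or its Main file.
    model = ESPCN_RN(no_RevNet_layers=2)
    optimiser = t.optim.SGD(model.parameters(), lr=1e-3)
    # 6-channel input patch; the two padding-free 3x3x3 convolutions shrink each
    # spatial dimension by 4 in total (11 -> 7), giving a 48-channel output patch.
    lowres = t.randn(2, 6, 11, 11, 11)
    highres = t.randn(2, 48, 7, 7, 7)
    loss = _example_train_step(model, optimiser, lowres, highres)
    print('example loss:', float(loss))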