-
Notifications
You must be signed in to change notification settings - Fork 3
/
models_adapter.py
271 lines (229 loc) · 8.63 KB
/
models_adapter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
# modified from: https://github.com/openai/CLIP/blob/a9b1bf5920416aaeaec965c25dd9e8f98c864f16/clip/model.py
from typing import Tuple
from collections import OrderedDict
import math
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from configs import (
CLIP_VIT_B16_PATH,
CLIP_VIT_L14_PATH,
DWCONV3D_DISABLE_CUDNN,
)
class Adapter(nn.Module):
    """Bottleneck adapter with a depthwise 3D convolution over (T, H, W).

    Tokens are projected down to ``adapter_channels`` (``fc1``), mixed by a
    depthwise ``Conv3d`` (``groups == adapter_channels``), projected back up
    (``fc2``) and added to the input as a residual. The token at index 0 of
    the sequence dimension is skipped and passed through unchanged.

    The convolution weight/bias and the ``fc2`` bias are zero-initialised, so
    a freshly constructed adapter adds exactly zero to its input.

    Args:
        in_channels: Channel width of the incoming tokens.
        adapter_channels: Bottleneck width used inside the adapter.
        kernel_size: 3-tuple kernel size for the depthwise convolution;
            padding is ``k // 2`` per dimension.
    """

    def __init__(self, in_channels, adapter_channels, kernel_size):
        super().__init__()
        self.fc1 = nn.Linear(in_channels, adapter_channels)
        self.conv = nn.Conv3d(
            adapter_channels, adapter_channels,
            kernel_size=kernel_size,
            stride=(1, 1, 1),
            padding=tuple(x // 2 for x in kernel_size),
            groups=adapter_channels,
        )
        self.fc2 = nn.Linear(adapter_channels, in_channels)
        # Zero conv + zero biases: the adapter branch outputs zeros at init.
        nn.init.constant_(self.conv.weight, 0.)
        nn.init.constant_(self.conv.bias, 0.)
        nn.init.constant_(self.fc1.bias, 0.)
        nn.init.constant_(self.fc2.bias, 0.)

    def forward(self, x, T):
        """Apply the adapter to a batch of frame token sequences.

        Args:
            x: Tensor of shape ``(B * T, L, C)`` where ``L - 1`` must be a
                perfect square (the ``H x W`` patch grid).
            T: Number of frames per clip; ``B`` is derived as ``BT // T``.

        Returns:
            The *same* tensor object as ``x``: the adapter output is added
            in place to ``x[:, 1:, :]`` and ``x`` is returned.
        """
        BT, L, C = x.size()
        B = BT // T
        Ca = self.conv.in_channels
        H = W = round(math.sqrt(L - 1))
        assert L - 1 == H * W, \
            'number of non-leading tokens must form a square grid'
        x_id = x
        x = x[:, 1:, :]
        x = self.fc1(x)
        # (BT, HW, Ca) -> (B, Ca, T, H, W), the layout Conv3d expects.
        x = x.view(B, T, H, W, Ca).permute(0, 4, 1, 2, 3).contiguous()

        # cuDNN is switched off for the depthwise conv only when the config
        # flag asks for it *and* it was on to begin with. The previous global
        # state is always restored, even if the convolution raises.
        cudnn_enabled = torch.backends.cudnn.enabled
        torch.backends.cudnn.enabled = \
            cudnn_enabled and not DWCONV3D_DISABLE_CUDNN
        try:
            x = self.conv(x)
        finally:
            torch.backends.cudnn.enabled = cudnn_enabled

        # (B, Ca, T, H, W) -> (BT, HW, Ca)
        x = x.permute(0, 2, 3, 4, 1).contiguous().view(BT, L - 1, Ca)
        x = self.fc2(x)
        # Residual add, in place on the caller's tensor; token 0 is untouched.
        x_id[:, 1:, :] += x
        return x_id
class LayerNorm(nn.LayerNorm):
    """LayerNorm that runs in float32 and returns the caller's dtype.

    The input is cast to ``torch.float32`` before the parent normalization
    and the result is cast back, so half-precision activations can be fed
    through directly.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        # Normalize on an fp32 copy, then hand back the original precision.
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block with optional adapters.

    Layout of one forward pass::

        [adapter_pre_attn] -> x + attention(ln_1(x))
        [adapter_pre_mlp]  -> x + mlp(ln_2(x))

    Each adapter is only instantiated when its flag is truthy; otherwise the
    attribute is ``None`` and that step is skipped.

    Args:
        d_model: Token channel width.
        n_head: Number of attention heads.
        adapter_width: Bottleneck width passed to ``Adapter``.
        adapter_kernel_size: Depthwise conv kernel size passed to ``Adapter``.
        adapter_pre_attn: Whether to place an adapter before attention.
        adapter_pre_mlp: Whether to place an adapter before the MLP.
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 adapter_width: int,
                 adapter_kernel_size: Tuple[int, int, int],
                 adapter_pre_attn: bool,
                 adapter_pre_mlp: bool,
                 ) -> None:
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)

        hidden_dim = d_model * 4
        feed_forward = OrderedDict()
        feed_forward["c_fc"] = nn.Linear(d_model, hidden_dim)
        feed_forward["gelu"] = QuickGELU()
        feed_forward["c_proj"] = nn.Linear(hidden_dim, d_model)
        self.mlp = nn.Sequential(feed_forward)
        self.ln_2 = LayerNorm(d_model)

        def build_adapter():
            # Both adapter slots share the same configuration.
            return Adapter(
                in_channels=d_model,
                adapter_channels=adapter_width,
                kernel_size=adapter_kernel_size,
            )

        if adapter_pre_attn:
            self.adapter_pre_attn = build_adapter()
        else:
            self.adapter_pre_attn = None

        if adapter_pre_mlp:
            self.adapter_pre_mlp = build_adapter()
        else:
            self.adapter_pre_mlp = None

    def attention(self, x: torch.Tensor) -> torch.Tensor:
        """Self-attention over the token axis of a ``(B, L, C)`` tensor.

        Only the projection weights of ``self.attn`` are used; the attention
        itself is computed with ``F.scaled_dot_product_attention``.
        """
        batch, seq_len, channels = x.size()
        heads = self.attn.num_heads

        # Fused projection; the last dim is laid out as [q | k | v].
        projected = F.linear(
            x, weight=self.attn.in_proj_weight, bias=self.attn.in_proj_bias
        )
        # (B, L, 3C) -> (B, L, 3, heads, head_dim) -> three (B, heads, L, head_dim)
        projected = projected.view(batch, seq_len, 3, heads, -1)
        q, k, v = (t.transpose(1, 2) for t in projected.unbind(dim=2))

        attended = F.scaled_dot_product_attention(q, k, v)
        # Merge the heads back: (B, heads, L, head_dim) -> (B, L, C)
        merged = attended.transpose(1, 2).reshape(batch, seq_len, channels)
        return self.attn.out_proj(merged)

    def forward(self,
                x: torch.Tensor,
                num_frames: int
                ) -> torch.Tensor:
        """Run the block; ``num_frames`` is forwarded to the adapters."""
        pre_attn = self.adapter_pre_attn
        pre_mlp = self.adapter_pre_mlp

        if pre_attn is not None:
            x = pre_attn(x, num_frames)
        attn_out = self.attention(self.ln_1(x))
        x = x + attn_out

        if pre_mlp is not None:
            x = pre_mlp(x, num_frames)
        mlp_out = self.mlp(self.ln_2(x))
        x = x + mlp_out
        return x
class Transformer(nn.Module):
    """Stack of ``ResidualAttentionBlock`` layers.

    Adapters are enabled only in the last ``adapter_layers`` blocks, i.e.
    those with index ``i >= layers - adapter_layers``, and only for the
    positions whose ``adapter_pre_*`` flag is set.

    Args:
        width: Token channel width.
        layers: Number of blocks.
        heads: Attention heads per block.
        adapter_width: Adapter bottleneck width.
        adapter_layers: How many trailing blocks receive adapters.
        adapter_kernel_size: Adapter depthwise conv kernel size.
        adapter_pre_attn: Enable the pre-attention adapter in adapted blocks.
        adapter_pre_mlp: Enable the pre-MLP adapter in adapted blocks.
    """

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 adapter_width: int,
                 adapter_layers: int,
                 adapter_kernel_size: Tuple[int, int, int],
                 adapter_pre_attn: bool,
                 adapter_pre_mlp: bool,
                 ):
        super().__init__()
        self.width = width
        self.layers = layers

        first_adapted = layers - adapter_layers
        blocks = []
        for i in range(layers):
            in_adapted_range = i >= first_adapted
            blocks.append(
                ResidualAttentionBlock(
                    d_model=width,
                    n_head=heads,
                    adapter_width=adapter_width,
                    adapter_kernel_size=adapter_kernel_size,
                    adapter_pre_attn=adapter_pre_attn and in_adapted_range,
                    adapter_pre_mlp=adapter_pre_mlp and in_adapted_range,
                )
            )
        self.resblocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor, num_frames: int) -> torch.Tensor:
        """Feed ``x`` through every block in order."""
        hidden = x
        for layer in self.resblocks:
            hidden = layer(hidden, num_frames)
        return hidden
class VisionTransformer(nn.Module):
    """ViT video classifier with a frozen half-precision backbone and adapters.

    Every parameter whose name does not contain ``'adapter'`` and that exists
    when the backbone has been built (patch embedding, class/positional
    embeddings, layer norms, transformer weights) is frozen and converted to
    fp16. The dropout and classification head ``fc`` are created afterwards
    and therefore stay trainable in their default dtype.

    Args:
        input_resolution: Side length of the (square) input frames.
        patch_size: Side length of each patch.
        width: Token channel width.
        layers: Number of transformer blocks.
        heads: Attention heads per block.
        num_classes: Output size of the classification head.
        adapter_width: Adapter bottleneck width.
        adapter_layers: How many trailing blocks receive adapters.
        adapter_kernel_size: Adapter depthwise conv kernel size.
        adapter_pre_attn: Enable pre-attention adapters.
        adapter_pre_mlp: Enable pre-MLP adapters.
    """

    def __init__(self,
                 input_resolution: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 num_classes: int,
                 adapter_width: int,
                 adapter_layers: int,
                 adapter_kernel_size: Tuple[int, int, int],
                 adapter_pre_attn: bool,
                 adapter_pre_mlp: bool,
                 ):
        super().__init__()
        self.input_resolution = input_resolution

        # Patch embedding: non-overlapping patch_size x patch_size patches.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        scale = width ** -0.5
        grid = input_resolution // patch_size
        num_tokens = grid ** 2 + 1  # patches plus one class token
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(num_tokens, width)
        )

        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(
            width, layers, heads,
            adapter_width, adapter_layers, adapter_kernel_size,
            adapter_pre_attn, adapter_pre_mlp,
        )
        self.ln_post = LayerNorm(width)

        # Freeze and halve everything built so far except adapter parameters.
        for name, param in self.named_parameters():
            if 'adapter' in name:
                continue
            param.requires_grad_(False)
            param.data = param.data.half()

        # Head is registered after the freeze loop, so it is left untouched.
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(width, num_classes)
        nn.init.normal_(self.fc.weight, std=0.02)
        nn.init.constant_(self.fc.bias, 0.)

    def forward(self, x: torch.Tensor):
        """Classify a batch of clips.

        Args:
            x: 5-D tensor indexed as ``(B, channels, T, H, W)``; dims 1 and 2
                are swapped and frames are folded into the batch before the
                3-channel patch embedding.

        Returns:
            Logits of shape ``(B, num_classes)``, computed from the class
            token averaged over the ``T`` frames.
        """
        B, T = x.size(0), x.size(2)

        frames = x.permute(0, 2, 1, 3, 4).flatten(0, 1)
        patches = self.conv1(frames)  # shape = [*, width, grid, grid]
        grid_h, grid_w = patches.size()[2:]

        tokens = patches.flatten(-2).permute(0, 2, 1)
        cls_token = self.class_embedding.view(1, 1, -1).expand(
            tokens.shape[0], -1, -1
        )
        tokens = torch.cat([cls_token, tokens], dim=1)  # [*, grid ** 2 + 1, width]
        tokens = tokens + self.positional_embedding.to(tokens.dtype)
        tokens = self.ln_pre(tokens)

        tokens = tokens.view(B, T, tokens.size(1), tokens.size(2))
        tokens = tokens.flatten(0, 1)  # BT, L, D
        tokens = self.transformer(tokens, T)

        tokens = tokens.contiguous().view(
            B, T, grid_h * grid_w + 1, tokens.size(-1)
        )
        # Class token of every frame, averaged over time.
        clip_feature = tokens[:, :, 0, :].mean(dim=1)

        clip_feature = self.ln_post(clip_feature)
        clip_feature = self.dropout(clip_feature)
        return self.fc(clip_feature)
def clip_vit_base_patch16_adapter24x384(**kwargs):
    """Build a ViT-B/16 with width-384 adapters at both positions of all 12 blocks.

    Extra keyword arguments are forwarded to ``VisionTransformer``
    (``num_classes`` is not set here, so it has to be supplied by the caller).
    After construction, the visual weights of the TorchScript CLIP checkpoint
    at ``CLIP_VIT_B16_PATH`` are loaded non-strictly and the load report is
    printed.

    Raises:
        AssertionError: If ``CLIP_VIT_B16_PATH`` is ``None``.
    """
    architecture = dict(
        input_resolution=224,
        patch_size=16,
        width=768,
        layers=12,
        heads=12,
        adapter_width=384,
        adapter_layers=12,
        adapter_kernel_size=(3, 1, 1),
        adapter_pre_attn=True,
        adapter_pre_mlp=True,
    )
    model = VisionTransformer(**architecture, **kwargs)

    assert CLIP_VIT_B16_PATH is not None, \
        'Please set CLIP_VIT_B16_PATH in configs.py.'

    clip_model = torch.jit.load(CLIP_VIT_B16_PATH, map_location='cpu')
    pretrained_weights = clip_model.visual.state_dict()
    # strict=False: keys that differ between checkpoint and model are
    # reported in the printed result rather than raising.
    load_report = model.load_state_dict(pretrained_weights, strict=False)
    print(load_report)
    return model
def clip_vit_base_patch16_adapter12x384(**kwargs):
    """Build a ViT-B/16 with one width-384 adapter (pre-MLP) in each of the 12 blocks.

    Extra keyword arguments are forwarded to ``VisionTransformer``
    (``num_classes`` is not set here, so it has to be supplied by the caller).
    After construction, the visual weights of the TorchScript CLIP checkpoint
    at ``CLIP_VIT_B16_PATH`` are loaded non-strictly and the load report is
    printed.

    Raises:
        AssertionError: If ``CLIP_VIT_B16_PATH`` is ``None``.
    """
    architecture = dict(
        input_resolution=224,
        patch_size=16,
        width=768,
        layers=12,
        heads=12,
        adapter_width=384,
        adapter_layers=12,
        adapter_kernel_size=(3, 1, 1),
        adapter_pre_attn=False,
        adapter_pre_mlp=True,
    )
    model = VisionTransformer(**architecture, **kwargs)

    assert CLIP_VIT_B16_PATH is not None, \
        'Please set CLIP_VIT_B16_PATH in configs.py'

    clip_model = torch.jit.load(CLIP_VIT_B16_PATH, map_location='cpu')
    pretrained_weights = clip_model.visual.state_dict()
    # strict=False: keys that differ between checkpoint and model are
    # reported in the printed result rather than raising.
    load_report = model.load_state_dict(pretrained_weights, strict=False)
    print(load_report)
    return model