attentionControl.py
import abc
from typing import Union, Tuple

import torch
class AttentionControl(abc.ABC):
    """Base controller that is invoked once per attention layer and tracks denoising steps."""

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def between_steps(self):
        return

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= 0:
            h = attn.shape[0]
            # Only the second half of the batch (typically the conditional half under
            # classifier-free guidance) is forwarded to the concrete controller.
            self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            # All attention layers of the current denoising step have been seen.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn
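
# Calling contract assumed here (a sketch, not part of the original file): the code
# that patches the UNet attention modules is expected to set
# `controller.num_att_layers` to the number of patched layers and then call
# `controller(attn_probs, is_cross, place_in_unet)` once per layer and per
# denoising step, so that `between_steps()` fires exactly once per step.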
class AttentionStore(AttentionControl):
    """Accumulates cross- and self-attention maps from every UNet block across all steps."""

    def __init__(self, res):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.res = res

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= (self.res // 16) ** 2:  # skip large maps to avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            # Add the maps of the current step onto the running totals.
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] = self.step_store[key][i] + self.attention_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        # Divide the accumulated maps by the number of completed steps.
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]]
                             for key in self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
class AttentionControlEdit(AttentionStore, abc.ABC):
    """Stores attention maps and accumulates a self-attention consistency loss between
    the base sample and the edited sample in the batch."""

    def __init__(self, num_steps: int,
                 self_replace_steps: Union[float, Tuple[float, float]], res):
        super(AttentionControlEdit, self).__init__(res)
        self.batch_size = 2
        if type(self_replace_steps) is float:
            self_replace_steps = 0, self_replace_steps
        # Range of denoising steps during which self-attention control is applied.
        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
        self.loss = 0
        self.criterion = torch.nn.MSELoss()

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
            h = attn.shape[0] // self.batch_size
            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
            attn_base, attn_replace = attn[0], attn[1:]
            if not is_cross:
                """
                ==========================================
                ========= Self-Attention Control =========
                ====== See Section 3.4 for details =======
                ==========================================
                """
                # Penalize the deviation of the edited self-attention maps from the base maps.
                self.loss += self.criterion(attn[1:], self.replace_self_attention(attn_base, attn_replace))
            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
        return attn

    def replace_self_attention(self, attn_base, att_replace):
        # Broadcast the base self-attention map to the shape of the edited maps.
        return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
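

if __name__ == "__main__":
    # Minimal smoke test, not part of the original file: drives the controllers with
    # synthetic attention tensors the way a UNet registration hook would. The shapes,
    # layer count, and step loop below are illustrative assumptions only.
    store = AttentionStore(res=512)
    store.num_att_layers = 2  # pretend the UNet exposes two patched attention layers

    for step in range(2):  # two fake denoising steps
        for _ in range(store.num_att_layers):
            # 2 entries on the batch axis, 256 queries (16x16), 77 text tokens
            fake_cross_attn = torch.rand(2, 256, 77)
            store(fake_cross_attn, is_cross=True, place_in_unet="down")

    averaged = store.get_average_attention()
    print(store.cur_step, [a.shape for a in averaged["down_cross"]])

    # Accumulating the self-attention consistency loss with an edit controller.
    edit = AttentionControlEdit(num_steps=2, self_replace_steps=1.0, res=512)
    edit.num_att_layers = 1
    fake_self_attn = torch.rand(4, 256, 256)  # two prompts after the CFG-half slicing
    edit(torch.cat([fake_self_attn, fake_self_attn]), is_cross=False, place_in_unet="up")
    print(float(edit.loss))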