
What is the purpose of bool_matrix1024 and bool_matrix4096 in the return value of the cal_attn_mask_xl function? #97

Open
parryppp opened this issue May 16, 2024 · 6 comments

Comments

@parryppp

The images of bool_matrix1024 and bool_matrix4096 are shown below.
[images: bool_matrix1024, bool_matrix4096]

@Z-YuPeng
Collaborator

Our method samples tokens so that images can interact with each other through the attention operation. The mask identifies which tokens are sampled. To reduce memory consumption, we switched from using the mask to using indices.
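As a rough illustration of why indices can be cheaper than a dense mask, the sketch below counts only the storage for the mask itself versus an index list. It assumes a hypothetical batch of 4 images with 64 × 64 = 4096 latent tokens each and half of the tokens sampled; the function names are hypothetical, and the count ignores the attention scores themselves and any per-image duplication of the index list.

```python
def dense_mask_bytes(total_length: int, tokens_per_image: int) -> int:
    # Dense boolean mask: one entry per (query token, key token) pair over the
    # whole batch of images; a torch.bool element takes 1 byte.
    n = total_length * tokens_per_image
    return n * n


def index_bytes(total_length: int, tokens_per_image: int, sample_rate: float) -> int:
    # Index list: one int64 (8 bytes) per sampled key token.
    n_sampled = int(total_length * tokens_per_image * sample_rate)
    return 8 * n_sampled


# Hypothetical setting: 4 images, 4096 tokens each, half of the tokens sampled.
print(dense_mask_bytes(4, 4096) // 2**20, "MiB")  # 256 MiB
print(index_bytes(4, 4096, 0.5) // 2**10, "KiB")  # 64 KiB
```

The dense mask grows with the square of the total token count, while the index list grows linearly with the number of sampled tokens.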

@parryppp
Author

[image: reshaped attention mask]
The reshaped attention mask is shown above. Do you mean that, for example, if I want to generate 4 consistent images, the yellow zone in the attention map would not be masked? If so, what does 'randsample' in the paper mean?
[image]

@Z-YuPeng
Collaborator

The generated random mask is exactly how random sampling is implemented. Once a random mask has been drawn, only the tokens where mask = 1 are considered.
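To make "only the tokens where mask = 1 are considered" concrete, here is a minimal sketch of plain scaled dot-product attention in PyTorch, assuming one image's queries and a hypothetical pool of 10 candidate key tokens. Masked-out keys get a score of -inf, so their softmax weight is exactly zero, and the output matches attention computed over the sampled keys alone. The sizes and sample_rate are illustrative, not values from the repository.

```python
import torch

torch.manual_seed(0)

n_query, n_key, dim = 4, 10, 8  # hypothetical sizes
sample_rate = 0.5               # hypothetical keep probability

q = torch.randn(n_query, dim)
k = torch.randn(n_key, dim)
v = torch.randn(n_key, dim)

# Draw the random mask: True (= 1) means the key token is sampled.
mask = torch.rand(n_key) < sample_rate
mask[0] = True  # keep at least one key so every softmax row is defined

scores = q @ k.T / dim ** 0.5
scores = scores.masked_fill(~mask, float("-inf"))  # unsampled keys -> -inf
weights = scores.softmax(dim=-1)                   # exp(-inf) == 0
out_masked = weights @ v

# Same result as physically keeping only the sampled keys.
k_s, v_s = k[mask], v[mask]
out_sampled = (q @ k_s.T / dim ** 0.5).softmax(dim=-1) @ v_s
```
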

@Z-YuPeng
Collaborator

The yellow zone corresponds to the concatenation operation below. We found that we cannot drop an image's own tokens, as doing so leads to a significant decline in image quality.
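A hypothetical sketch of how per-image sampling rows can combine random sampling with keeping each image's own tokens, written in the spirit of the cal_attn_mask_xl lines quoted later in this thread rather than as the repository's actual code. The names build_sampling_rows, expand_to_attention_mask and tokens_per_image are illustrative; the assumption that only the first id_length images act as references follows the id_length*nums_1024: slice quoted later, and drawing a separate random row per image is also an assumption.

```python
import torch


def build_sampling_rows(total_length, id_length, tokens_per_image, sample_rate, generator=None):
    """Hypothetical sketch: row i marks which key tokens image i may attend to."""
    n_keys = total_length * tokens_per_image
    # Random sampling over the pooled tokens of all images (one draw per row).
    rows = torch.rand(total_length, n_keys, generator=generator) < sample_rate
    # Assumption: only the first id_length images serve as references.
    rows[:, id_length * tokens_per_image:] = False
    for i in range(total_length):
        # Never drop an image's own tokens: these become the diagonal blocks.
        rows[i, i * tokens_per_image:(i + 1) * tokens_per_image] = True
    return rows


def expand_to_attention_mask(rows, tokens_per_image):
    # Every query token of image i reuses image i's row, giving an
    # (n_keys, n_keys) mask whose diagonal blocks are all True.
    return rows.repeat_interleave(tokens_per_image, dim=0)


g = torch.Generator().manual_seed(0)
rows = build_sampling_rows(total_length=4, id_length=2, tokens_per_image=3,
                           sample_rate=0.5, generator=g)
mask = expand_to_attention_mask(rows, tokens_per_image=3)
print(mask.int())
```

In this sketch the always-True diagonal blocks play the role of concatenating each image's own tokens, while the random entries in the reference columns play the role of sampling.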

@Z-YuPeng
Collaborator

The mask operation combines random sampling and concatenation into a single step. We initially used it because we found it faster than, and equivalent to, explicit random sampling, but it also used more memory. Later, because of the memory-consumption concerns raised in issues, we reverted to the original approach of sampling indices and concatenating:
https://github.com/HVision-NKU/StoryDiffusion/blob/main/utils/gradio_utils.py#L258
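A minimal sketch, under hypothetical sizes, of the equivalence described above: gathering sampled token indices and concatenating them with an image's own tokens gives the same attention output as masking the full set of pooled tokens, while the index version only scores the kept keys. For simplicity one tensor stands in for queries, keys and values, and attend_by_concat and attend_by_mask are illustrative names, not functions from the repository.

```python
import torch

torch.manual_seed(0)
total_length, id_length, n_tok, dim = 4, 2, 5, 8  # hypothetical sizes
sample_rate = 0.5

x = torch.randn(total_length, n_tok, dim)  # stand-in hidden states (q = k = v)


def attention(q, k, v, mask=None):
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))
    return scores.softmax(dim=-1) @ v


# Sample once, as indices into the pooled tokens of the reference images.
ref = x[:id_length].reshape(-1, dim)  # (id_length * n_tok, dim)
keep_idx = torch.nonzero(torch.rand(ref.shape[0]) < sample_rate).squeeze(1)


def attend_by_concat(i):
    # Index approach: own tokens + sampled reference tokens of OTHER images.
    owner = torch.div(keep_idx, n_tok, rounding_mode="floor")
    others = keep_idx[owner != i]
    kv = torch.cat([x[i], ref[others]], dim=0)
    return attention(x[i], kv, kv)


def attend_by_mask(i):
    # Mask approach: score all tokens of all images, then hide the rest.
    pooled = x.reshape(-1, dim)
    allowed = torch.zeros(pooled.shape[0], dtype=torch.bool)
    allowed[keep_idx] = True                     # sampled reference tokens
    allowed[i * n_tok:(i + 1) * n_tok] = True    # own tokens always kept
    return attention(x[i], pooled, pooled, allowed)
```

The two paths attend over the same key set, so their outputs agree up to floating-point error; the mask path computes scores against every pooled token, which is where the extra memory goes.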

@parryppp
Author

Thank you for the explanation; it is much clearer to me now. But I still have a question about the shape of the attention mask. Why does the attention mask ensure that the squares on the diagonal remain set to 1, as shown in the figure below? Is that the purpose of the code in L249-L252?
[image: attention mask with the diagonal squares set to 1]

```python
bool_matrix1024[i:i+1,id_length*nums_1024:] = False
```

Why not simply generate a random attention mask like the one in the figure below?
```python
bool_matrix1024 = torch.rand((total_length,nums_1024),device = device,dtype = dtype) < sa32
```

[image: plain random attention mask]
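One way to see the difference between the proposed plain random mask and the diagonal-forced one is to measure how many of its own tokens each image is still allowed to attend to. This sketch assumes rows over all tokens of all images, with shape (total_length, total_length * n_tok), rather than the (total_length, nums_1024) shape in the line above; the sizes are hypothetical. With a plain draw each image keeps only about sample_rate of its own tokens, while forcing the diagonal blocks keeps all of them, which is the property the collaborator's comment about not dropping an image's own tokens relies on.

```python
import torch

torch.manual_seed(0)
total_length, n_tok, sample_rate = 4, 1024, 0.5  # hypothetical values

# Plain random rows (assumed here to span all tokens of all images).
plain = torch.rand(total_length, total_length * n_tok) < sample_rate

# Same rows with each image's own block forced to True.
forced = plain.clone()
for i in range(total_length):
    forced[i, i * n_tok:(i + 1) * n_tok] = True


def own_token_fraction(rows):
    # Fraction of each image's own tokens that it may still attend to.
    return torch.stack([rows[i, i * n_tok:(i + 1) * n_tok].float().mean()
                        for i in range(total_length)])


print(own_token_fraction(plain))   # roughly sample_rate for every image
print(own_token_fraction(forced))  # exactly 1.0 for every image
```
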
