FLOPs calculation for LSHSelfAttention in LSH mode and Full attention mode #149

zaidilyas89 commented May 27, 2022

As per my understanding, FLOPs calculation is usually done on the complete model, but I am trying to compare the computational cost of only the LSH attention module of Reformer by feeding it random input vectors. This LSH attention module switches between LSH hashing and full dot-product attention via the flag use_full_attn=False / use_full_attn=True.

The problem is that no matter what size of input vectors I set for qk and v, the number of FLOPs comes out the same in both cases.

Setting use_full_attn=False or use_full_attn=True does switch the model between LSH-based attention and full attention; I have verified this in debug mode in the Spyder IDE.

Am I missing something?

How can I verify this? I would be grateful if someone can help me.
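One rough sanity check I am considering (my own idea, not a FLOPs measurement) is to simply time a forward pass in both modes, as sketched below. Wall-clock time is of course not the same thing as FLOPs, but it should at least differ between the two modes if they really take different code paths:

```python
import time
import torch
from reformer_pytorch import LSHSelfAttention

def build(use_full_attn):
    # same settings as in the snippet further down, only the attention mode changes
    return LSHSelfAttention(
        dim = 128,
        heads = 8,
        bucket_size = 64,
        n_hashes = 16,
        causal = True,
        use_full_attn = use_full_attn
    ).eval()

x = torch.randn(1, 1024, 128)

for flag in (False, True):
    model = build(flag)
    with torch.no_grad():
        model(x)  # warm-up
        start = time.perf_counter()
        for _ in range(10):
            model(x)
    print(f'use_full_attn={flag}: {(time.perf_counter() - start) / 10:.4f} s per forward')
```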

Code (from the Reformer GitHub README):


```python
import torch
from reformer_pytorch import LSHSelfAttention

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = LSHSelfAttention(
    dim = 128,
    heads = 8,
    bucket_size = 64,
    n_hashes = 16,
    causal = True,
    use_full_attn = False,  # set to True for the full-attention run
    return_attn = False
).to(device)

# qk and v are only needed by the lower-level LSHAttention module;
# LSHSelfAttention takes a single input tensor x
qk = torch.randn(10, 1024, 128)
v = torch.randn(10, 1024, 128)

x = torch.randn(1, 1024, 128).to(device)
y = model(x)  # (1, 1024, 128)
```

Code for FLOPs calculation (from https://github.com/cszn/KAIR/blob/master/utils/utils_modelsummary.py):

```python
# get_model_flops comes from KAIR's utils/utils_modelsummary.py (linked above)
from utils.utils_modelsummary import get_model_flops

with torch.no_grad():
    input_dim = (1, 16384, 128)  # set the input dimension

    flops = get_model_flops(model, input_dim, False)

    print('{:>16s} : {:<.4f} [G]'.format('FLOPs', flops / 10**9))
```

Result in both cases:

```
FLOPs : 0.8053 [G]
```
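For completeness, here is a sketch of an alternative cross-check I am considering, using torch.profiler with with_flops=True. I am not certain it attributes FLOPs to every operation used inside LSHSelfAttention, so this is only a sketch and not something I have validated:

```python
import torch
from torch.profiler import profile, ProfilerActivity
from reformer_pytorch import LSHSelfAttention

model = LSHSelfAttention(
    dim = 128, heads = 8, bucket_size = 64,
    n_hashes = 16, causal = True, use_full_attn = False  # and True for the second run
).eval()

x = torch.randn(1, 1024, 128)

with torch.no_grad(), profile(activities=[ProfilerActivity.CPU], with_flops=True) as prof:
    model(x)

# FLOPs are reported per operator; ops the profiler cannot count show no FLOPs
print(prof.key_averages().table(sort_by="flops", row_limit=10))
```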