TypeError: BasicGNN.forward() got an unexpected keyword argument 'return_attention_weights' #9160

Open
Batene opened this issue Apr 6, 2024 · 1 comment

Batene commented Apr 6, 2024

🐛 Describe the bug

In `GATv2Conv`, `return_attention_weights` is expected to be passed as an argument to the `forward` function.
When using `GATv2Conv` indirectly through `GAT` with `v2=True`, I cannot pass `return_attention_weights=True` when calling the `GAT` model (i.e., its `forward` function); this throws the error above.
One can pass `return_attention_weights=True` when initializing `GAT`, but it is then never used: the `forward` function of `GATv2Conv` does not read a `self.return_attention_weights` attribute (no such attribute exists). Instead, it expects `return_attention_weights` as a parameter.

My workaround for now is:

  1. In the `__init__` function of `GATv2Conv`, add:

     `self.return_attention_weights = kwargs["return_attention_weights"]`

  2. In the `forward` function of `GATv2Conv`, replace

     `if isinstance(return_attention_weights, bool):`

     with

     `if isinstance(self.return_attention_weights, bool):`
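The idea behind this workaround, reading a flag stored at construction time instead of a `forward` argument, can be sketched in plain Python. `ToyAttentionLayer` is hypothetical and only mirrors the flag handling, not the GATv2 math; it is also a slight variant that falls back to the stored flag only when the caller passes nothing, so an explicit argument still wins.

```python
from typing import List, Optional, Tuple, Union

Output = Union[List[float], Tuple[List[float], List[float]]]


class ToyAttentionLayer:
    """Hypothetical stand-in for an attention layer (flag handling only)."""

    def __init__(self, return_attention_weights: Optional[bool] = None) -> None:
        # Step 1 of the workaround: remember the flag at construction time.
        self.return_attention_weights = return_attention_weights

    def forward(self, x: List[float],
                return_attention_weights: Optional[bool] = None) -> Output:
        # Step 2 (variant): use the stored flag when no argument is given.
        if return_attention_weights is None:
            return_attention_weights = self.return_attention_weights
        out = [2.0 * v for v in x]        # placeholder for message passing
        alpha = [1.0 / len(x)] * len(x)   # placeholder uniform attention
        # Same check as the quoted line: any bool (even False) returns a tuple.
        if isinstance(return_attention_weights, bool):
            return out, alpha
        return out


print(ToyAttentionLayer(return_attention_weights=True).forward([1.0, 2.0]))
print(ToyAttentionLayer().forward([1.0, 2.0]))
```

Once the flag is set, the layer's return type changes from a single output to a tuple, so any wrapper that feeds one layer's output straight into the next layer has to unpack it.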
    

I am using PyTorch Geometric version 2.5.2.
I assume the same problem occurs with `v2=False` (i.e., with `GATConv`).

Versions

Python version: 3.11.7

Batene added the bug label Apr 6, 2024
rusty1s (Member) commented Apr 8, 2024

Are you referring to `models.GAT`? If you need to return attention weights, I suggest using the GNN layers directly and building your own model on top of them.
