## 🐛 Describe the bug
In `GATv2Conv`, `return_attention_weights` is expected to be passed as a parameter to the `forward` function. When using `GATv2Conv` indirectly through `GAT` with `v2=True`, I cannot pass `return_attention_weights=True` when calling the `GAT` model (i.e., its `forward` function). This throws the error above.

One can pass `return_attention_weights=True` in the `GAT` initialization, but it is never used: the `forward` function of `GATv2Conv` does not access `self.return_attention_weights` (the attribute does not exist). The `forward` function requires the parameter `return_attention_weights` instead of accessing a class attribute.

My solution for now is:
- add in the `__init__` function of `GATv2Conv`:
- in the `forward` function of `GATv2Conv`:

`pytorch_geometric/torch_geometric/nn/conv/gatv2_conv.py`, line 312 at `38bb5f2`
I am using PyTorch Geometric version 2.5.2.
I assume the same problem occurs when using `v2=False` (with `GATConv`).

## Versions
Python version: 3.11.7