I am trying to get the attention weights of my GATConv network. However, my current approach fails to retrieve the weights and throws an error that I do not understand. I created a minimal example that reproduces the same error:
The above code produces the following error:
However, when I comment out the edge connections and edge attributes between "type2" and "type3" nodes (marked in the code above), the error disappears. I am not sure what's happening. Any help would be greatly appreciated. Thanks in advance!
I am sorry to inform you that `return_attention_weights=True` and `to_hetero` are currently not compatible, and I don't see a good way to support this long-term. The only workaround is to manually implement what `to_hetero` does internally, i.e., specify a `GATConv` layer for every edge type and then call these layers in a loop:

```python
# One GATConv per edge type; torch.nn.ModuleDict requires string keys, so join each (src, rel, dst) tuple
self.convs = ModuleDict({'__'.join(edge_type): GATConv((-1, -1), output_dim, ...) for edge_type in edge_types})

def forward(self, x_dict, edge_index_dict):
    out_dict, alpha_dict = {}, {}
    for edge_type, edge_index in edge_index_dict.items():
        src, _, dst = edge_type
        out, (_, alpha) = self.convs['__'.join(edge_type)]((x_dict[src], x_dict[dst]), edge_index, return_attention_weights=True)
        out_dict[dst] = out_dict[dst] + out if dst in out_dict else out  # sum results per destination node type
        alpha_dict[edge_type] = alpha
    return out_dict, alpha_dict
```