Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hetero Data Dimension Mismatch #9087

Open
xuzhang0112 opened this issue Mar 21, 2024 · 3 comments
Open

Hetero Data Dimension Mismatch #9087

xuzhang0112 opened this issue Mar 21, 2024 · 3 comments
Labels

Comments

@xuzhang0112
Copy link

🐛 Describe the bug

Hello, I am using the to_hetero method to implement HGT for a subgraph classification problem and have encountered some dimension mismatch issues. However, the to_hetero method encapsulates many operations, preventing me from printing tensor dimensions within the GNN. When I print dimensions externally, my node dimensions and edge indices seem normal. How can I obtain more detailed error information? Or could you offer some insights on how to solve this problem?

Here is a simple demo to reproduce my error:

import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, Linear, to_hetero
from torch_geometric.data import HeteroData

class GAT(torch.nn.Module):
    """Two-layer GAT in which each attention layer is summed with a parallel Linear skip.

    Input sizes are given as ``-1`` so PyG infers them on the first call (lazy
    initialization). Self-loops are disabled on both attention layers, presumably
    because the ``(-1, -1)`` bipartite form used under ``to_hetero`` cannot add them.

    Args:
        hidden_channels: width of the first layer's output.
        out_channels: width of the final output.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # Registration order (conv1, lin1, conv2, lin2) is kept so parameter order and
        # state_dict keys stay stable.
        self.conv1 = GATConv((-1, -1), hidden_channels, add_self_loops=False)
        self.lin1 = Linear(-1, hidden_channels)
        self.conv2 = GATConv((-1, -1), out_channels, add_self_loops=False)
        self.lin2 = Linear(-1, out_channels)

    def forward(self, x, edge_index):
        # Parameter names must stay `x` / `edge_index`: to_hetero traces this method.
        hidden = self.conv1(x, edge_index) + self.lin1(x)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index) + self.lin2(hidden)

class Network(nn.Module):
    """Embed token-id node features and classify 'patient' nodes with a heterogeneous GAT.

    Each node type's ``x`` holds token ids of shape ``(num_nodes, seq_len)``. They are
    embedded, max-pooled over the sequence axis and passed through ``to_hetero(GAT)``.

    Args:
        config: dict with keys ``'device'``, ``'node_types'`` and ``'edge_types'``.
    """

    def __init__(self, config):
        super().__init__()
        self.device = config['device']
        self.hidden_size, self.num_labels = 256, 684
        self.embed = nn.Embedding(26008, self.hidden_size)
        # self.embed.load_state_dict(torch.load('bert/embedding-pth'))
        self.dropout = nn.Dropout(p=0.1)
        # Derive the GAT widths from the attributes above rather than repeating the
        # literals, so changing hidden_size / num_labels keeps every component in sync.
        self.homo_gat = GAT(self.hidden_size, self.num_labels)
        self.pooling = nn.AdaptiveMaxPool1d(1)
        self.node_types = config['node_types']
        self.edge_types = config['edge_types']
        self.graph_metadata = (self.node_types, self.edge_types)
        # NOTE(review): homo_gat stays registered as a submodule. If to_hetero works on
        # per-type copies of it, homo_gat's lazy (-1) parameters are never materialized
        # yet still appear in self.parameters() -- verify before counting parameters.
        self.hetero_gat = to_hetero(self.homo_gat, self.graph_metadata)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input_list, graph, label_list):
        """Return ``([logits], [loss])`` for the 'patient' nodes of ``graph``.

        Args:
            input_list: unused (kept for interface compatibility with callers).
            graph: HeteroData whose ``x`` entries are 2-D token-id tensors and whose
                ``edge_index`` entries are index tensors per edge type.
            label_list: sequence whose first element holds the 'patient' targets.
        """
        graph.to(self.device)
        x_dict, edge_index_dict = graph.x_dict, graph.edge_index_dict
        for node_type, tokens in x_dict.items():
            if tokens.nelement() != 0:
                # (N, L) ids -> (N, L, H) -> (N, H, L) -> max over L -> (N, H)
                feats = self.embed(tokens).permute(0, 2, 1)
                feats = self.pooling(feats).squeeze(-1)
                x_dict[node_type] = self.dropout(feats)
            else:
                # Empty node type: a (0, hidden_size) float placeholder, allocated directly
                # on the target device, keeps the feature width consistent for the GAT.
                x_dict[node_type] = torch.empty(
                    (0, self.hidden_size), dtype=torch.float, device=self.device)
        for k, v in x_dict.items():
            print(k, v.shape, v.dtype)
        for k, v in edge_index_dict.items():
            if v.nelement() != 0:
                print(k, v, v.dtype)
        x_dict = self.hetero_gat(x_dict, edge_index_dict)
        logits = x_dict['patient']
        loss = self.loss(logits, label_list[0].to(self.device))
        return [logits], [loss]
    
if __name__ == '__main__':
    # Every ordered pair of node types is declared as an edge type with relation '_'.
    node_types = ['patient', 'symptom', 'part', 'thing',
                  'time', 'reason', 'history',
                  'check1', 'checkup1', 'check2', 'checkup2',
                  'positive', 'positive2', 'attributes',
                  'change', "change2", 'reason2', 'procedure']
    edge_types = [(src, '_', dst) for src in node_types for dst in node_types]
    config = {'device': 'cpu',
              'node_types': node_types,
              'edge_types': edge_types}
    model = Network(config).to(config['device'])

    data = HeteroData()
    # Start every declared node/edge type out empty ...
    for node_type in node_types:
        data[node_type].x = torch.empty((0, 20), dtype=torch.long)
    for edge_type in edge_types:
        data[edge_type].edge_index = torch.empty((2, 0), dtype=torch.long)

    # ... then fill a few: node type -> node count, each node gets 20 random ids.
    node_counts = {
        'patient': 1, 'symptom': 3, 'part': 1, 'time': 2, 'reason': 1,
        'positive': 1, 'attributes': 1, 'change': 2, 'reason2': 2,
    }
    for node_type, count in node_counts.items():
        data[node_type].x = torch.randint(0, 10000, (count, 20), dtype=torch.long)

    # (src, dst) -> [source indices, target indices]; relation is always '_'.
    edges = {
        ('patient', 'symptom'): [[0, 0, 0], [0, 1, 2]],
        ('patient', 'positive'): [[0], [0]],
        ('patient', 'change'): [[0], [1]],
        ('symptom', 'part'): [[1, 2], [0, 0]],
        ('symptom', 'time'): [[0, 1], [0, 1]],
        ('symptom', 'reason'): [[1], [0]],
        ('symptom', 'attributes'): [[1], [0]],
        ('symptom', 'change'): [[2], [0]],
        ('reason', 'time'): [[0], [1]],
        ('change', 'reason2'): [[0, 1], [0, 1]],
    }
    for (src, dst), pairs in edges.items():
        data[src, '_', dst].edge_index = torch.tensor(pairs, dtype=torch.long)

    out = model([], data, [torch.tensor([0], dtype=torch.long)])
    print(out)

and then you can see the error:
image

Versions

[pip3] numpy==1.23.5
[pip3] pytorch-ignite==0.4.10
[pip3] pytorch-pretrained-vit==0.0.7
[pip3] torch==1.13.1
[pip3] torch-cluster==1.6.1+pt113cu117
[pip3] torch-geometric==2.3.1
[pip3] torch-scatter==2.1.1+pt113cu117
[pip3] torch-sparse==0.6.17+pt113cu117
[pip3] torch-spline-conv==1.2.2+pt113cu117
[pip3] torch-tb-profiler==0.4.1
[pip3] torchvision==0.14.1
[pip3] triton==2.0.0

@xuzhang0112
Copy link
Author

The bug disappears when I remove the digits from the node types and edge types, e.g. 'reason2' -> 'reasonb'. Maybe digits are not allowed? Is '_' allowed? I didn't see either mentioned in the docs.

@rusty1s
Copy link
Member

rusty1s commented Mar 25, 2024

Mh, I don't think this is related to the `2` or the `_`. I still hit an error even after changing those. However, I could make the model work by defining the model as

x = self.conv1(x, edge_index) + self.lin1(x)
        x = x.relu()
        x = self.conv2(x, edge_index) + self.lin2(x)
        return x

Will take a look.

@xuzhang0112
Copy link
Author

xuzhang0112 commented Mar 25, 2024

Thank you for replying. Have you removed all the digits in the code, e.g. '1'->'a', '2'->'b'? I can confirm it works after I change them.
Here is my revised version, and you can just run it.

import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, Linear, to_hetero
from torch_geometric.data import HeteroData

class GAT(torch.nn.Module):
    """Two-layer GAT in which each attention layer is summed with a parallel Linear skip.

    Input sizes are given as ``-1`` so PyG infers them on the first call (lazy
    initialization). Self-loops are disabled on both attention layers, presumably
    because the ``(-1, -1)`` bipartite form used under ``to_hetero`` cannot add them.

    Args:
        hidden_channels: width of the first layer's output.
        out_channels: width of the final output.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # Registration order (conv1, lin1, conv2, lin2) is kept so parameter order and
        # state_dict keys stay stable.
        self.conv1 = GATConv((-1, -1), hidden_channels, add_self_loops=False)
        self.lin1 = Linear(-1, hidden_channels)
        self.conv2 = GATConv((-1, -1), out_channels, add_self_loops=False)
        self.lin2 = Linear(-1, out_channels)

    def forward(self, x, edge_index):
        # Parameter names must stay `x` / `edge_index`: to_hetero traces this method.
        hidden = self.conv1(x, edge_index) + self.lin1(x)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index) + self.lin2(hidden)

class Network(nn.Module):
    """Embed token-id node features and classify 'patient' nodes with a heterogeneous GAT.

    Each node type's ``x`` holds token ids of shape ``(num_nodes, seq_len)``. They are
    embedded, max-pooled over the sequence axis and passed through ``to_hetero(GAT)``.

    Args:
        config: dict with keys ``'device'``, ``'node_types'`` and ``'edge_types'``.
    """

    def __init__(self, config):
        super().__init__()
        self.device = config['device']
        self.hidden_size, self.num_labels = 256, 684
        self.embed = nn.Embedding(26008, self.hidden_size)
        # self.embed.load_state_dict(torch.load('bert/embedding-pth'))
        self.dropout = nn.Dropout(p=0.1)
        # Derive the GAT widths from the attributes above rather than repeating the
        # literals, so changing hidden_size / num_labels keeps every component in sync.
        self.homo_gat = GAT(self.hidden_size, self.num_labels)
        self.pooling = nn.AdaptiveMaxPool1d(1)
        self.node_types = config['node_types']
        self.edge_types = config['edge_types']
        self.graph_metadata = (self.node_types, self.edge_types)
        # NOTE(review): homo_gat stays registered as a submodule. If to_hetero works on
        # per-type copies of it, homo_gat's lazy (-1) parameters are never materialized
        # yet still appear in self.parameters() -- verify before counting parameters.
        self.hetero_gat = to_hetero(self.homo_gat, self.graph_metadata)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input_list, graph, label_list):
        """Return ``([logits], [loss])`` for the 'patient' nodes of ``graph``.

        Args:
            input_list: unused (kept for interface compatibility with callers).
            graph: HeteroData whose ``x`` entries are 2-D token-id tensors and whose
                ``edge_index`` entries are index tensors per edge type.
            label_list: sequence whose first element holds the 'patient' targets.
        """
        graph.to(self.device)
        x_dict, edge_index_dict = graph.x_dict, graph.edge_index_dict
        for node_type, tokens in x_dict.items():
            if tokens.nelement() != 0:
                # (N, L) ids -> (N, L, H) -> (N, H, L) -> max over L -> (N, H)
                feats = self.embed(tokens).permute(0, 2, 1)
                feats = self.pooling(feats).squeeze(-1)
                x_dict[node_type] = self.dropout(feats)
            else:
                # Empty node type: a (0, hidden_size) float placeholder, allocated directly
                # on the target device, keeps the feature width consistent for the GAT.
                x_dict[node_type] = torch.empty(
                    (0, self.hidden_size), dtype=torch.float, device=self.device)
        x_dict = self.hetero_gat(x_dict, edge_index_dict)
        logits = x_dict['patient']
        loss = self.loss(logits, label_list[0].to(self.device))
        return [logits], [loss]
    
if __name__ == '__main__':
    # Every ordered pair of node types is declared as an edge type with relation '_'.
    node_types = ['patient', 'symptom', 'part', 'thing',
                  'time', 'reason', 'history',
                  'checka', 'checkupa', 'checkb', 'checkupb',
                  'positive', 'positiveb', 'attributes',
                  'change', "changeb", 'reasonb', 'procedure']
    edge_types = [(src, '_', dst) for src in node_types for dst in node_types]
    config = {'device': 'cpu',
              'node_types': node_types,
              'edge_types': edge_types}
    model = Network(config).to(config['device'])

    data = HeteroData()
    # Start every declared node/edge type out empty ...
    for node_type in node_types:
        data[node_type].x = torch.empty((0, 20), dtype=torch.long)
    for edge_type in edge_types:
        data[edge_type].edge_index = torch.empty((2, 0), dtype=torch.long)

    # ... then fill a few: node type -> node count, each node gets 20 random ids.
    node_counts = {
        'patient': 1, 'symptom': 3, 'part': 1, 'time': 2, 'reason': 1,
        'positive': 1, 'attributes': 1, 'change': 2, 'reasonb': 2,
    }
    for node_type, count in node_counts.items():
        data[node_type].x = torch.randint(0, 10000, (count, 20), dtype=torch.long)

    # (src, dst) -> [source indices, target indices]; relation is always '_'.
    edges = {
        ('patient', 'symptom'): [[0, 0, 0], [0, 1, 2]],
        ('patient', 'positive'): [[0], [0]],
        ('patient', 'change'): [[0], [1]],
        ('symptom', 'part'): [[1, 2], [0, 0]],
        ('symptom', 'time'): [[0, 1], [0, 1]],
        ('symptom', 'reason'): [[1], [0]],
        ('symptom', 'attributes'): [[1], [0]],
        ('symptom', 'change'): [[2], [0]],
        ('reason', 'time'): [[0], [1]],
        ('change', 'reasonb'): [[0, 1], [0, 1]],
    }
    for (src, dst), pairs in edges.items():
        data[src, '_', dst].edge_index = torch.tensor(pairs, dtype=torch.long)

    out = model([], data, [torch.tensor([0], dtype=torch.long)])
    print(out)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

No branches or pull requests

2 participants