Inconsistency on edge_index shape for A3TGCN2 module #261

Open · jolenscki opened this issue Dec 13, 2023 · 0 comments

jolenscki commented Dec 13, 2023

Hello! First of all, I want to thank the maintainers of this repository for the great work! It has been helping me a lot in my projects!
I'm trying to use A3TGCN2 for traffic prediction, but I've had a hard time understanding how to set it up. I thought that defining this module would follow the same logic as other modules from this library, such as GConvLSTM, but apparently the A3TGCN2 module doesn't accept batched edge_index tensors.

The documentation doesn't mention it, but to instantiate the A3TGCN2 module you have to pass batch_size as a parameter, and when calling its forward method, X must have shape (batch_size, num_nodes, features, seq_len). When doing this, however, the original edge_index tensor (which is still batched) causes the following error:

C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\ScatterGatherKernel.cu:145: block: [104,0,0], thread: [127,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
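For reference, here is a minimal sketch of the shapes that worked for me; the node/edge counts and the out_channels value are made up for illustration:

    import torch
    from torch_geometric_temporal.nn.recurrent import A3TGCN2

    batch_size, num_nodes, features, periods = 4, 10, 2, 12

    tgnn = A3TGCN2(in_channels=features, out_channels=32,
                   periods=periods, batch_size=batch_size)

    # X is batched along dim 0, but edge_index must describe a single graph
    # (node indices in [0, num_nodes)), not the block-diagonal batched graph
    # that a DataLoader produces.
    X = torch.rand(batch_size, num_nodes, features, periods)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # shape (2, num_edges)

    H = tgnn(X, edge_index)  # -> (batch_size, num_nodes, 32)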

The error above happens because the edge_index being passed is too big: its node indices refer to the block-diagonal batched graph, so they exceed the number of nodes of a single graph. A solution I found (which allows smooth batched calls) is to create a method that unbatches the edge_index tensor, assuming the graph is static and the same for every item in the batch.

import torch
import torch.nn as nn
from torch_geometric_temporal.nn.recurrent import A3TGCN2


class MyModel(nn.Module):
    def __init__(self,
                 features: int,
                 out_dim: int,
                 batch_size: int,
                 periods: int,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 
                ):
        super().__init__()
        self.features = features
        self.out_dim = out_dim
        self.batch_size = batch_size
        self.device = device
        self.periods = periods
        
        self.tgnn = A3TGCN2(
            in_channels=self.features,
            out_channels=self.out_dim,
            periods=self.periods,
            batch_size=self.batch_size,
        )
        
    def forward(self, x, edge_index, batch):
        '''
        Parameters
        ------------
        x: torch.Tensor
            node features, of shape (seq_len, batch_size*num_nodes, features)
        edge_index: torch.Tensor
            edge indices of the batched (block-diagonal) graph,
            of shape (2, batch_size*num_edges)
        batch: torch.Tensor
            batch vector assigning each node to its graph in the batch,
            of shape (batch_size*num_nodes,)
        '''

        seq_len, batched_nodes, features = x.shape  # batched_nodes == batch_size * num_nodes
        x = torch.movedim(x, 0, -1)
        x = x.reshape(self.batch_size, -1, features, seq_len)
        
        # now x is shaped (batch_size, num_nodes, features, seq_len)

        # now we unbatch the edge_index tensor
        edge_index = self.unbatch_edge_index(edge_index, batch)

        H = self.tgnn(X=x, edge_index=edge_index)

        return H
    
    def unbatch_edge_index(self, edge_index, batch):
        # Assumes the graph is static and identical for every sample in the
        # batch: keep only the edges of the first graph and re-index them.
        # Calculate the number of nodes in each graph
        num_nodes_per_graph = torch.bincount(batch)

        # Calculate the cumulative sum of nodes to determine the boundaries
        cum_nodes = torch.cumsum(num_nodes_per_graph, dim=0)
        cum_nodes = torch.cat([cum_nodes.new_zeros(1), cum_nodes])

        # Split the edge_index for each graph
        mask = (edge_index[0] >= cum_nodes[0]) & (edge_index[0] < cum_nodes[1])
        edge_subset = edge_index[:, mask]
        # Adjust node indices to start from 0 for each graph
        edge_subset[0] -= cum_nodes[0]
        edge_subset[1] -= cum_nodes[0]

        return edge_subset

This seems to work with the batched edge_index tensor that a torch_geometric.loader.DataLoader yields.
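As a side note, recent PyG versions also ship torch_geometric.utils.unbatch_edge_index, which performs essentially the same split; if it is available in your version, the helper above could probably be reduced to something like this (untested sketch):

    from torch_geometric.utils import unbatch_edge_index

    # batch assigns each node to its graph; the first entry of the returned
    # list is the edge_index of the first graph, already re-indexed to start
    # at 0, which is what A3TGCN2 needs when the graph is shared across the batch.
    edge_index_single = unbatch_edge_index(edge_index, batch)[0]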

I also wrote a Stack Overflow Q&A with the full problem I was facing, and I'm posting it here so that people who search for this specific issue can find an answer.
