Index out of bounds with custom heterogeneous datasets #9233

Open
jkrude opened this issue Apr 24, 2024 · 2 comments
jkrude commented Apr 24, 2024

🐛 Describe the bug

In my use case I have a custom dataset in which some node types don't occur in every HeteroData object.
However, torch_geometric.data.separate assumes that the 'x' slices for every node type in the slice_dict cover all samples, and consequently throws an index-out-of-bounds error for samples that are missing a node type.

I guess a quick workaround is to add d2['optional_type'].x = torch.empty(0) whenever a node type does not occur in a sample (see the sketch after the traceback below). I am not sure whether you want to support this use case, but it would be great to add a note to the docs ;)

Minimal reproducible example:

import torch
from torch_geometric.data import HeteroData
from typing import Union, List, Tuple
from torch_geometric.data import InMemoryDataset

d1 = HeteroData()
d1['node'].x = torch.zeros(4)
d1['optional_type'].x = torch.ones(4)  # this node type might not always occur
d2 = HeteroData()
d2['node'].x = torch.ones(4)

import tempfile
tempdir = tempfile.mkdtemp()


class MinDataset(InMemoryDataset):

    def __init__(
            self,
    ) -> None:
        super().__init__(root=tempdir)
        self.load(self.processed_paths[0])

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple[str, ...]]:
        return ["data.pt"]

    def process(self) -> None:
        data_list = [d1, d2]
        self.save(data_list, self.processed_paths[0])

dataset = MinDataset()

# dataset.len() == 2

for i in range(dataset.len()):
    print(dataset[i])  # raises IndexError for the sample without 'optional_type'
...python3.10/site-packages/torch_geometric/data/in_memory_dataset.py:111, in InMemoryDataset.get(self, idx)
    108 elif self._data_list[idx] is not None:
    109     return copy.copy(self._data_list[idx])
--> 111 data = separate(
    112     cls=self._data.__class__,
    113     batch=self._data,
    114     idx=idx,
    115     slice_dict=self.slices,
    116     decrement=False,
    117 )
    119 self._data_list[idx] = copy.copy(data)
    121 return data

.../python3.10/site-packages/torch_geometric/data/separate.py:48, in separate(cls, batch, idx, slice_dict, inc_dict, decrement)
     45         slices = slice_dict[attr]
     46         incs = inc_dict[attr] if decrement else None
---> 48     data_store[attr] = _separate(attr, batch_store[attr], idx, slices,
     49                                  incs, batch, batch_store, decrement)
     51 # The `num_nodes` attribute needs special treatment, as we cannot infer
     52 # the real number of nodes from the total number of nodes alone:
     53 if hasattr(batch_store, '_num_nodes'):

.../python3.10/site-packages/torch_geometric/data/separate.py:75, in _separate(key, values, idx, slices, incs, batch, store, decrement)
     73 key = str(key)
     74 cat_dim = batch.__cat_dim__(key, values, store)
---> 75 start, end = int(slices[idx]), int(slices[idx + 1])
     76 value = narrow(values, cat_dim or 0, start, end - start)
     77 value = value.squeeze(0) if cat_dim is None else value

IndexError: index 2 is out of bounds for dimension 0 with size 2
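For reference, here is a minimal sketch of the workaround mentioned above, applied to this example (assuming a fresh root directory so that process() runs again); it simply gives d2 an empty placeholder tensor for the missing node type:

# Sketch of the workaround: add an empty placeholder feature tensor for the
# missing node type, so both samples carry the same set of attributes.
d2['optional_type'].x = torch.empty(0)

dataset = MinDataset()  # assumes a fresh `root`, so process() is executed again
for i in range(dataset.len()):
    print(dataset[i])  # no IndexError anymore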

Versions

PyTorch version: 2.2.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-39-generic-x86_64-with-glibc2.35
...

@jkrude jkrude added the bug label Apr 24, 2024

jkrude commented Apr 25, 2024

I guess this is related to #3984, in which the same problem was raised for varying edge types.

rusty1s commented Apr 26, 2024

Yeah, since mini-batching would get more expensive if we checked for the existence of attributes across data objects, we currently assume that every data object contains the same set of attributes. Adding torch.empty((0, ...)) for missing attributes is the recommended approach.
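A minimal sketch of that approach (assuming 2D node features of shape [num_nodes, num_features]; pad_missing_node_types and feat_dims are illustrative names, not PyG API):

import torch
from torch_geometric.data import HeteroData

def pad_missing_node_types(data: HeteroData, feat_dims: dict) -> HeteroData:
    # For every expected node type, add an empty [0, num_features] feature
    # tensor if the type is absent, so all samples share the same attributes.
    for node_type, num_features in feat_dims.items():
        if node_type not in data.node_types:
            data[node_type].x = torch.empty((0, num_features))
    return data

# Usage (hypothetical feature sizes): pad every sample before calling self.save(...).
# data_list = [pad_missing_node_types(d, {'paper': 16, 'author': 8}) for d in data_list]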
