In my use case I have a custom dataset in which some node types don't occur in every HeteroData object.
However, torch_geometric.data.separate assumes that the 'x' tensor for each node type in the slice_dict is equally long, and subsequently throws an index-out-of-bounds error.
A quick workaround is to add d2['optional_type'].x = torch.empty(0) whenever the node type does not occur in a sample. I am not sure whether you want to support this use case, but it would be great to add a note to the docs ;)
Minimal reproducible example:
import torch
from torch_geometric.data import HeteroData
from typing import Union, List, Tuple
from torch_geometric.data import InMemoryDataset

d1 = HeteroData()
d1['node'].x = torch.zeros(4)
d1['optional_type'].x = torch.ones(4)  # this node type might not always occur

d2 = HeteroData()
d2['node'].x = torch.ones(4)

import tempfile
tempdir = tempfile.mkdtemp()

class MinDataset(InMemoryDataset):
    def __init__(self) -> None:
        super().__init__(root=tempdir)
        self.load(self.processed_paths[0])

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple[str, ...]]:
        return ["data.pt"]

    def process(self) -> None:
        data_list = [d1, d2]
        self.save(data_list, self.processed_paths[0])

dataset = MinDataset()

# dataset.len() == 2
for i in range(dataset.len()):
    print(dataset[i])  # will throw index out of bounds
...python3.10/site-packages/torch_geometric/data/in_memory_dataset.py:111, in InMemoryDataset.get(self, idx)
    108 elif self._data_list[idx] is not None:
    109     return copy.copy(self._data_list[idx])
--> 111 data = separate(
    112     cls=self._data.__class__,
    113     batch=self._data,
    114     idx=idx,
    115     slice_dict=self.slices,
    116     decrement=False,
    117 )
    119 self._data_list[idx] = copy.copy(data)
    121 return data

.../python3.10/site-packages/torch_geometric/data/separate.py:48, in separate(cls, batch, idx, slice_dict, inc_dict, decrement)
     45 slices = slice_dict[attr]
     46 incs = inc_dict[attr] if decrement else None
---> 48 data_store[attr] = _separate(attr, batch_store[attr], idx, slices,
     49                              incs, batch, batch_store, decrement)
     51 # The `num_nodes` attribute needs special treatment, as we cannot infer
     52 # the real number of nodes from the total number of nodes alone:
     53 if hasattr(batch_store, '_num_nodes'):

.../python3.10/site-packages/torch_geometric/data/separate.py:75, in _separate(key, values, idx, slices, incs, batch, store, decrement)
     73 key = str(key)
     74 cat_dim = batch.__cat_dim__(key, values, store)
---> 75 start, end = int(slices[idx]), int(slices[idx + 1])
     76 value = narrow(values, cat_dim or 0, start, end - start)
     77 value = value.squeeze(0) if cat_dim is None else value

IndexError: index 2 is out of bounds for dimension 0 with size 2
Yeah, since mini-batching would get more expensive if we checked for the existence of attributes across data objects, we currently assume that every data object contains the same set of attributes. Adding torch.empty((0, ...)) is the recommended approach.
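To see why the empty-tensor padding works, here is a sketch in plain torch (not the actual torch_geometric collate code): collation concatenates each attribute along dim 0 and records cumulative slice offsets, and an empty tensor contributes zero rows while still producing one slice entry per sample, so separate can recover every sample:

```python
import torch

x_d1 = torch.ones(4)   # 'optional_type'.x in d1
x_d2 = torch.empty(0)  # padding for the missing node type in d2

# Collation concatenates along dim 0 and records cumulative offsets:
cat = torch.cat([x_d1, x_d2], dim=0)
slices = [0, x_d1.size(0), x_d1.size(0) + x_d2.size(0)]  # [0, 4, 4]

# Every sample now has a valid (possibly empty) slice:
sample1 = cat[slices[0]:slices[1]]  # the 4 original values
sample2 = cat[slices[1]:slices[2]]  # an empty tensor
```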
Versions
PyTorch version: 2.2.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-39-generic-x86_64-with-glibc2.35
...