Wrong return values for group_argsort #9209

Open
clathe opened this issue Apr 17, 2024 · 0 comments
clathe commented Apr 17, 2024

🐛 Describe the bug

I believe the function group_argsort returns the wrong values. When every element belongs to the same group (a single index value), its behaviour should be equivalent to torch.argsort, but this is not the case.

import torch
from torch_geometric.utils import group_argsort

src = torch.tensor([0, 1, 10, 2, 7, 11, 3, 8, 15, 12, 4, 14, 9, 17, 16, 13, 20, 5, 19, 6, 18])
index = torch.zeros_like(src)

out = group_argsort(src, index)
expected = src.argsort()
print(src[out])
# tensor([ 0,  1,  4, 10,  8, 14,  2, 15, 13,  9,  7, 16, 12,  5, 20, 17, 18, 11, 6,  3, 19])
print(src[expected])
# tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20])
assert torch.allclose(out, expected)

I believe the error lies in this line:

out[perm] = torch.arange(index.numel(), device=index.device)

which should be removed if the return values are the indices that sort the tensor src.
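A minimal sketch of what that scatter computes, in plain PyTorch without calling group_argsort (the names perm and rank are only illustrative): writing 0..n-1 into the positions given by the sorting permutation yields the inverse permutation, so entry i holds the rank of src[i] rather than the index of the i-th smallest element.

```python
import torch

src = torch.tensor([0, 1, 10, 2, 7, 11, 3, 8, 15, 12, 4, 14,
                    9, 17, 16, 13, 20, 5, 19, 6, 18])

# Indices that sort src: src[perm] is ascending.
perm = src.argsort()

# Scatter 0..n-1 into the positions listed in perm.
# The result is the inverse permutation: rank[i] is the rank of src[i].
rank = torch.empty_like(perm)
rank[perm] = torch.arange(src.numel())

print(torch.equal(src[perm], torch.arange(21)))  # True
print(torch.equal(rank, src))                    # True, src is a permutation of 0..20
print(src[rank])                                 # same tensor as src[out] reported above
```

Because src here is a permutation of 0..20, each element's rank equals its value, which would explain why src[out] in the report above looks like src indexed by itself.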

However, if I do this and replace the line with

out = perm

the multi-index case breaks:

src = torch.tensor([0, 1, 10, 2, 7, 11, 3, 8, 15, 12, 4, 14, 9, 17, 16, 13, 20, 5, 19, 6, 18])
index = torch.tensor([0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0])
out = group_argsort(src, index)
expected = torch.empty_like(out)
for i in range(index.max() + 1):
    mask = index == i
    expected[mask] = src[mask].argsort()

print(src[index == 0][expected[index == 0]])
# tensor([ 0,  4,  6,  7, 11, 13, 14, 15, 18, 19])
print(src[index == 0][out[index == 0]])
# IndexError: index 15 is out of bounds for dimension 0 with size 10
assert torch.allclose(out, expected)
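For comparison, here is one way the expected tensor could be computed without the Python loop. The helper name grouped_argsort_local is hypothetical and not part of torch_geometric; this is only a sketch, assuming 1-D inputs and non-negative integer group ids. It sorts by (index, src) and then translates global positions into positions within each group.

```python
import torch


def grouped_argsort_local(src: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-group argsort that returns group-local indices.

    For every group g, out[index == g] equals src[index == g].argsort(stable=True).
    Assumes 1-D inputs and non-negative integer group ids.
    """
    n = src.numel()
    device = src.device

    # Lexicographic sort by (index, src) using two stable sorts.
    by_src = torch.sort(src, stable=True).indices
    perm = by_src[torch.sort(index[by_src], stable=True).indices]

    # Positions grouped by index, keeping the original order inside each group.
    order = torch.sort(index, stable=True).indices

    # Start offset of each group in the grouped layout.
    count = torch.bincount(index)
    ptr = torch.cumsum(count, dim=0) - count

    # local[i] is the position of element i inside its own group.
    local = torch.empty(n, dtype=torch.long, device=device)
    local[order] = torch.arange(n, device=device) - ptr[index[order]]

    # Write the local index of the k-th smallest element of each group
    # to the k-th slot of that group.
    out = torch.empty(n, dtype=torch.long, device=device)
    out[order] = local[perm]
    return out


src = torch.tensor([0, 1, 10, 2, 7, 11, 3, 8, 15, 12, 4, 14,
                    9, 17, 16, 13, 20, 5, 19, 6, 18])
index = torch.tensor([0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0,
                      1, 1, 1, 0, 1, 1, 0, 0, 0])

local_out = grouped_argsort_local(src, index)
print(src[index == 0][local_out[index == 0]])  # group 0 in ascending order
```

The two stable sorts stand in for a lexicographic sort; with tied values the result may differ from a non-stable per-group argsort.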

Overall, I am wondering why this line

src = src - 2 * index if descending else src + 2 * index
is sufficient to compute a grouped argsort.
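A plausible reading is that the offset can only separate groups if src has already been rescaled into [0, 1], so that adding 2 * index places group g inside the interval [2g, 2g + 1] and no two groups overlap. The rescaling step in the sketch below is an assumption (it is not quoted in this issue), and the code does not call group_argsort:

```python
import torch

src = torch.tensor([0, 1, 10, 2, 7, 11, 3, 8, 15, 12, 4, 14,
                    9, 17, 16, 13, 20, 5, 19, 6, 18])
index = torch.tensor([0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0,
                      1, 1, 1, 0, 1, 1, 0, 0, 0])

# Assumed preprocessing: rescale the values into [0, 1].
scaled = (src - src.min()) / (src.max() - src.min())

# Group g now occupies [2g, 2g + 1], so groups cannot interleave.
key = scaled + 2 * index

# A single global argsort orders by group first, then by value.
perm = key.argsort()
print(index[perm])  # all zeros first, then all ones
print(src[perm])    # ascending within each group

# Without the rescaling, the raw values span far more than 2,
# so the same offset no longer keeps the groups apart.
raw_perm = (src + 2 * index).argsort()
print(index[raw_perm])  # group ids interleave
```

Note that perm holds global positions in the full tensor, which is why indexing a per-group slice with it goes out of bounds.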

Versions

PyTorch version: 2.2.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.2.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:54:21) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-14.2.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2 Pro

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.0
[pip3] torch_cluster==1.6.3
[pip3] torch_geometric==2.5.2
[pip3] torch_scatter==2.1.2
[pip3] torch_sparse==0.6.18
[pip3] torch_spline_conv==1.2.2
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.2.0 pypi_0 pypi
[conda] torch-cluster 1.6.3 pypi_0 pypi
[conda] torch-geometric 2.5.2 pypi_0 pypi
[conda] torch-scatter 2.1.2 pypi_0 pypi
[conda] torch-sparse 0.6.18 pypi_0 pypi
[conda] torch-spline-conv 1.2.2 pypi_0 pypi

clathe added the bug label Apr 17, 2024