Index out of range in SchNet on a modification of QM9 dataset. #9299

Open
CalmScout opened this issue May 6, 2024 · 1 comment

🐛 Describe the bug

Hi!

The idea of the code below is to run a custom version of SchNet on SMILES representations of molecules. Code:

print("Importing packages...")
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import QM9
from torch_geometric.nn import SchNet
from tqdm import tqdm
import pickle
import os

print("Defining functions...")
# Define a function to convert SMILES to PyG data objects
def smiles_to_pyg_graph(smiles):
    from rdkit import Chem
    from rdkit.Chem import AllChem
    from torch_geometric.data import Data

    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception:
        return None
    
    if mol is None:
        return None

    # Add Hydrogens to the molecule
    mol = Chem.AddHs(mol)
    AllChem.EmbedMolecule(mol)

    # Convert the molecule to a graph
    node_features = []
    for atom in mol.GetAtoms():
        node_features.append(atom_feature(atom))
    # node_features = torch.tensor(node_features, dtype=torch.float)
    node_features = torch.tensor(node_features, dtype=torch.long)

    edge_indices = []
    edge_features = []

    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_indices.append((start, end))
        edge_indices.append((end, start))
        edge_features.append(bond_feature(bond))
        edge_features.append(bond_feature(bond))

    edge_indices = torch.tensor(edge_indices).t().to(torch.long)
    # edge_features = torch.tensor(edge_features, dtype=torch.float)
    edge_features = torch.tensor(edge_features, dtype=torch.long)

    return Data(x=node_features, edge_index=edge_indices, edge_attr=edge_features)

# Helper functions for node and edge features
def atom_feature(atom):
    return [atom.GetAtomicNum(), atom.GetFormalCharge()]

def bond_feature(bond):
    return [int(bond.GetBondTypeAsDouble())]

# Load dataset and convert SMILES to PyG data objects
print("Creating dataset...")
# if we have cached data, load it
if os.path.exists('data/qm9_pyg_data.pkl'):
    print("Loading data from cache...")
    with open('data/qm9_pyg_data.pkl', 'rb') as f:
        data_list = pickle.load(f)
else:
    print("Creating dataset from scratch...")
    dataset = QM9(root='data')
    data_list = []
    # for i in tqdm(range(len(dataset))):
    for i in tqdm(range(1000)):
        smiles = dataset[i]['smiles']
        data = smiles_to_pyg_graph(smiles)
        if data is not None:
            data_list.append(data)
    # Save data_list to a pickle file
    with open('data/qm9_pyg_data.pkl', 'wb') as f:
        pickle.dump(data_list, f)

print(f"Example data entry in the data_list: {data_list[0]}")

# Define a SchNet model
class MySchNet(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_targets):
        super(MySchNet, self).__init__()
        self.schnet = SchNet(hidden_channels, num_features)
        self.lin = torch.nn.Linear(hidden_channels, num_targets)

    def forward(self, data):
        print(f'pirnt from forward: data.x.shape: {data.x.shape}')
        print(f'pirnt from forward: data.edge_index.shape: {data.edge_index.shape}')
        print(f'pirnt from forward: data.edge_attr.shape: {data.edge_attr.shape}')
        out = self.schnet(data.x, data.edge_index, data.edge_attr)
        out = self.lin(out)
        return out

# Instantiate the model and define other training parameters
print("Defining model...")
model = MySchNet(num_features=2, hidden_channels=64, num_targets=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

The corresponding output before the exception:

Training...
Batch size: 32
type(batch.x): <class 'torch.Tensor'>
batch.x.dtype: torch.int64
Batch edge_index shape: torch.Size([2, 834])
Batch edge_index dtype: torch.int64
Batch edge_attr shape: torch.Size([834, 1])
Batch edge_attr dtype: torch.int64
pirnt from forward: data.x.shape: torch.Size([419, 2])
pirnt from forward: data.edge_index.shape: torch.Size([2, 834])
pirnt from forward: data.edge_attr.shape: torch.Size([834, 1])
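The shapes printed above already show the mismatch: `data.x` is `[419, 2]`, while SchNet's `z` should be a 1-D tensor of atomic numbers, and `edge_index` (`[2, 834]`) ends up in the slot for `pos`, which should be `[num_atoms, 3]` floats. A hypothetical helper like the one below (not part of PyG) can catch this before the model call. It assumes SchNet's default atom embedding has 100 rows, so valid atomic numbers lie in `[0, 100)`.

```python
from typing import Optional

import torch


def check_schnet_inputs(z: torch.Tensor, pos: torch.Tensor,
                        batch: Optional[torch.Tensor] = None) -> None:
    """Hypothetical guard: raise ValueError if tensors don't fit SchNet."""
    if z.dim() != 1:
        raise ValueError(f"z must be 1-D [num_atoms], got {tuple(z.shape)}")
    if z.dtype != torch.long:
        raise ValueError(f"z must be torch.long atomic numbers, got {z.dtype}")
    # Assumes the default 100-row atom embedding.
    if z.numel() > 0 and (int(z.min()) < 0 or int(z.max()) >= 100):
        raise ValueError("atomic numbers must lie in [0, 100)")
    if pos.dim() != 2 or pos.size(0) != z.size(0) or pos.size(1) != 3:
        raise ValueError(f"pos must be [num_atoms, 3], got {tuple(pos.shape)}")
    if not pos.is_floating_point():
        raise ValueError(f"pos must be floating point, got {pos.dtype}")
    if batch is not None and batch.shape != z.shape:
        raise ValueError(f"batch must match z's shape, got {tuple(batch.shape)}")
```

Called with the tensors from the failing batch, e.g. `check_schnet_inputs(data.x, data.edge_index, data.edge_attr)`, it rejects `x` immediately because it is 2-D, instead of failing deep inside the embedding lookup.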

And the exception message:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[5], line 17
     15 print(f'Batch edge_attr dtype: {batch.edge_attr.dtype}')
     16 optimizer.zero_grad()
---> 17 output = model(batch)
     18 loss = criterion(output, batch.y.view(-1, 1))  # Assuming targets are stored in batch.y
     19 loss.backward()

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

Cell In[4], line 14
     12 print(f'pirnt from forward: data.edge_index.shape: {data.edge_index.shape}')
     13 print(f'pirnt from forward: data.edge_attr.shape: {data.edge_attr.shape}')
---> 14 out = self.schnet(data.x, data.edge_index, data.edge_attr)
     15 out = self.lin(out)
     16 return out

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch_geometric/nn/models/schnet.py:284, in SchNet.forward(self, z, pos, batch)
    271 r"""Forward pass.
    272 
    273 Args:
   (...)
    280         (default: :obj:`None`)
    281 """
    282 batch = torch.zeros_like(z) if batch is None else batch
--> 284 h = self.embedding(z)
    285 edge_index, edge_weight = self.interaction_graph(pos, batch)
    286 edge_attr = self.distance_expansion(edge_weight)

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/modules/sparse.py:163, in Embedding.forward(self, input)
    162 def forward(self, input: Tensor) -> Tensor:
--> 163     return F.embedding(
    164         input, self.weight, self.padding_idx, self.max_norm,
    165         self.norm_type, self.scale_grad_by_freq, self.sparse)

File ~/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/nn/functional.py:2264, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2258     # Note [embedding_renorm set_grad_enabled]
   2259     # XXX: equivalent to
   2260     # with torch.no_grad():
   2261     #   torch.embedding_renorm_
   2262     # remove once script supports set_grad_enabled
   2263     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2264 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

IndexError: index out of range in self
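The failing frame is `h = self.embedding(z)`, an `nn.Embedding` lookup that runs before `pos` or `batch` are used. `nn.Embedding` accepts integer tensors of any shape, so the 2-D `x` alone does not raise; "index out of range in self" means some entry falls outside `[0, num_embeddings)`. Since the second column of `x` holds `atom.GetFormalCharge()`, one plausible (unconfirmed) trigger is an atom with formal charge -1. The sketch below reproduces this with a stand-alone embedding; the 100-row size and the example feature rows are assumptions chosen to mirror SchNet's atom embedding and `atom_feature()`.

```python
import torch

# Stand-alone lookup table sized like SchNet's atom embedding (assumed 100 rows).
embedding = torch.nn.Embedding(100, 8)

# Two atoms featurized as [atomic_number, formal_charge], as atom_feature() does,
# e.g. an N+ and an O- in a zwitterion.
x_2d = torch.tensor([[7, 1], [8, -1]])

try:
    embedding(x_2d)
    failed = False
except IndexError as err:
    # The negative formal charge is not a valid row index.
    failed = True
    print(f"IndexError: {err}")

# A 1-D tensor of atomic numbers is what SchNet expects for z.
z = x_2d[:, 0]
h = embedding(z)
print(h.shape)
```

Even when every index happens to be valid, a 2-D `z` would give per-atom embeddings of shape `[num_atoms, 2, hidden_channels]`, and `edge_index` would still be used as coordinates, so the featurization itself has to change rather than just the charge values.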

Thanks for reading! I appreciate any feedback regarding the issue.

Best regards,
Anton.

Versions

Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:38:13) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1058-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 10.1.243
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla V100-SXM2-16GB
Nvidia driver version: 535.171.04
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 8
On-line CPU(s) list: 0-7
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 79
Model name: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
Stepping: 1
CPU MHz: 3000.000
CPU max MHz: 3000.0000
CPU min MHz: 1200.0000
BogoMIPS: 4600.02
Hypervisor vendor: Xen
Virtualization type: full
L1d cache: 128 KiB
L1i cache: 128 KiB
L2 cache: 1 MiB
L3 cache: 45 MiB
NUMA node0 CPU(s): 0-7
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single pti fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx xsaveopt

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.2.3
[pip3] torch==2.3.0
[pip3] torch_cluster==1.6.3+pt22cu121
[pip3] torch-ema==0.3
[pip3] torch_geometric==2.5.3
[pip3] torch_scatter==2.1.2+pt22cu121
[pip3] torch_sparse==0.6.18+pt22cu121
[pip3] torch_spline_conv==1.2.2+pt22cu121
[pip3] torchaudio==2.3.0
[pip3] torchmetrics==1.0.1
[pip3] torchvision==0.18.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] pytorch-lightning 2.2.3 pypi_0 pypi
[conda] torch 2.3.0 pypi_0 pypi
[conda] torch-cluster 1.6.3+pt22cu121 pypi_0 pypi
[conda] torch-ema 0.3 pypi_0 pypi
[conda] torch-geometric 2.5.3 pypi_0 pypi
[conda] torch-scatter 2.1.2+pt22cu121 pypi_0 pypi
[conda] torch-sparse 0.6.18+pt22cu121 pypi_0 pypi
[conda] torch-spline-conv 1.2.2+pt22cu121 pypi_0 pypi
[conda] torchaudio 2.3.0 pypi_0 pypi
[conda] torchmetrics 1.0.1 pypi_0 pypi
[conda] torchvision 0.18.0 pypi_0 pypi

CalmScout added the bug label May 6, 2024
rusty1s (Member) commented May 13, 2024

Currently, PyG's SchNet expects an input feature vector of shape [num_atoms], while it looks like your input is two-dimensional.
