Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CMPNN model #9223

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open

Add CMPNN model #9223

wants to merge 10 commits into from

Conversation

devanshamin
Copy link

@devanshamin devanshamin commented Apr 21, 2024

Paper Summary

  • Introduces a new model, CMPNN, for Communicative Representation Learning on Attributed Molecular Graphs
  • CMPNN improves molecular embedding by:
    • Following the edge-based message passing in DMPNN
    • Introducing node-edge message communication modules such as Inner Product Kernel, Gated Graph Kernel, and Multilayer Perception
    • Updating both bond and atom embeddings during training
    • Including a message booster to enrich the message generation process

Motivation

  • Adding an alternative to DMPNN (Chemprop uses this GNN architecture) that improves DMPNN by adding:
    • Node-edge message communication modules
    • Message booster
    • Updating bond embeddings
  • Adding a GNN model that updates edge embeddings during training

Benchmark Results

I have benchmarked CMPNN on a subset of datasets from the TDC ADMET Benchmark Group. I have used LitGNN to perform this benchmark and results can be found in a W&B report.

Task Classification (AUROC ↑) Regression (MAE ↓)
Dataset BBB_Martins AMES Solubility_AqSolDB Lipophilicity_AstraZeneca LD50_Zhu
AttentiveFP 0.855 ± 0.011 0.814 ± 0.008 0.776 ± 0.008 0.572 ± 0.007 0.678 ± 0.012
Chemprop 0.821 ± 0.112 0.842 ± 0.014 0.829 ± 0.022 0.470 ± 0.009 0.606 ± 0.024
Chemprop-RDKit* 0.869 ± 0.027 0.850 ± 0.004 0.761 ± 0.025 0.467 ± 0.006 0.625 ± 0.022
CMPNN 0.89 ± 0.016
CMPNN-GRU
0.843 ± 0.009
CMPNN-MLP
0.796 ± 0.038
CMPNN-GRU
0.515 ± 0.008
CMPNN-MLP
0.631 ± 0.021
CMPNN-Additive

Table: Prediction results of CMPNN on five chemical graph datasets. The datasets were used from the TDC ADMET Benchmark group that provides train_val/test scaffold splits. The model was trained and tested for each task for five times, and reported the mean and standard deviation of AUROC or MAE values. *Chemprop-RDKIT utilizes a hybrid approach where it combines the learned molecule embeddings with 200 global molecule features (descriptors).

Implementation Details

Note

Here is a fork of the original code with some cleanups, addition of poetry for dependency management etc.

Below are the places where improvements have been made:

  • Node-edge message communication modules
    • In the paper, the authors mention different communicators such as Inner Product Kernel, Gated Graph Kernel and Multilayer Perception. However, within their code, they don't use such different communicators. Instead, they only use an additive communicator (not mentioned in the paper; I came up with the name 'additive').
    • I have implemented 4 communicators proposed in the paper and their code:
      • Additive
      • Inner product
      • GRU
      • MLP
    • The user can choose different communicator according to their dataset.
    • These communication modules are applied during the convolution layers.
  • Final communication module
    • In the paper, the authors mention applying the same communication module to the message from incoming bonds [m(v)], current atom's representation [hK(v)] and atom's initial representation [x(v)].
    • In the section 3.3 of the paper where they describe the different node-edge message communication modules, all the communication modules operate on m(v) and hK-1(v).
    • In their code, they use MLP communication module.
    • I have kept this part the same and hardcoded the MLP communication module.

Checklist

Note

For CMPNN PyG implementation, I have used AttentiveFP as a template.

  • Add torch_geometric/nn/models/cmpnn.py
  • Add CMPNN to torch_geometric/nn/models/__init__.py
  • Add test/nn/models/test_cmpnn.py
  • Add message_booster mode to the torch_geometric/nn/aggr/multi.py:MultiAggregation class
  • Add a test for message_booster mode in test/nn/aggr/test_multi.py
  • Add examples/cmpnn.py
  • Add CMPNN to the 'Implemented GNN Models' section of the README
  • Update CHANGELOG
  • Support torch.compile
    • 2 graph breaks caused by forward method of torch_geometric/nn/models/cmpnn.py:BatchGRU class,
       > Graph break 1
       Line 298: unique_values, counts = torch.unique(batch, return_counts=True)
       GraphCompileReason(reason='dynamic shape operator: aten._unique2.default')
      
       > Graph break 2
       Line 301: dim_2 = counts.max().item()
       GraphCompileReason(reason='Tensor.item')
      
    • 1 graph break caused by scatter method of torch_geometric/utils/_scatter.py,
       Line 53: dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
       GraphCompileReason(reason='Tensor.item')
      

For CMPNN model, do you want torch.compile support within this PR or should I open a separate PR? - @rusty1s

Thank you! Please let me know if any changes are required.

@AzureLeon1
Copy link

Thank you for your efforts in implementing CMPNN. Based on my tests, this pull request correctly implements the key modules described in the original paper. One question I have is whether it would be more appropriate to separate BatchGRU from CMPNN, considering that BatchGRU serves the role of a readout function.

@devanshamin
Copy link
Author

Thank you @AzureLeon1 for taking the time to review the PR. Regarding the BatchGRU, I have few thoughts -

  1. If you look at the forward method of the AttentiveFP, it utilizes a PyTorch GRUCell while CMPNN implements BatchGRU using the PyTorch GRU.
  2. Another option would be to move it under nn.pool.

I feel BatchGRU is specific to CMPNN, but if you (or the maintainers of PyG) see a utility of having it in a separate file, I'm happy to make the necessary changes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants