
Added the GeoGNN model #8651

Open · wants to merge 5 commits into master
Conversation

jasona445

Fixes #8626. Added the modified GIN in geognn_conv.py and two-graph GeoGNN in geognn.py from Geometry-enhanced molecular representation learning for property prediction.

Not sure if this is the right place for everything - please let me know if I can move things around / add tests or examples.

@jasona445 changed the title from "Geognn" to "GeoGNN" on Dec 21, 2023

codecov bot commented Dec 21, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison: base (11a29b0) 89.12% vs. head (f8382a6) 89.47%.
Report is 12 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #8651      +/-   ##
==========================================
+ Coverage   89.12%   89.47%   +0.35%     
==========================================
  Files         481      483       +2     
  Lines       30830    30915      +85     
==========================================
+ Hits        27477    27662     +185     
+ Misses       3353     3253     -100     




class GeoGNN(torch.nn.Module):
    """Modified version of GeoGNN from the

Member review comment: Can you shed some light here about its modifications and what that means?

torch_geometric/contrib/nn/models/geognn.py (outdated, resolved)
@rusty1s changed the title from "GeoGNN" to "Added the GeoGNN model" on Dec 21, 2023
jasona445 and others added 2 commits December 23, 2023 14:36
jasona445 commented Dec 23, 2023

Modifications are with respect to the PaddlePaddle implementation, but I believe they follow the main implementation described in the paper. In the PaddlePaddle model, forward expects two arguments, atom_bond_graph and bond_angle_graph.

atom_bond_graph looks like this:

{
  "class": "Graph",
  "num_nodes": 49534,
  "edges_shape": [144176,2],
  "node_feat": [
    {"name": "atomic_num", "shape": [49534], "dtype": "paddle.int64"}
        …
    {"name": "hybridization", "shape": [49534], "dtype": "paddle.int64"}
  ],
  "edge_feat": [
    {"name": "bond_dir", "shape": [144176], "dtype": "paddle.int64"}, 
        …
    {"name": "bond_length", "shape": [144176], "dtype": "paddle.float32"}
  ]
}

And bond_angle_graph looks like this, with only one edge feature present. Its number of nodes equals the number of edges in atom_bond_graph: here the bonds are the vertices and the bond angles are the edges.

{
  "class": "Graph",
  "num_nodes": 144176,
  "edges_shape": [412564,2],
  "node_feat": [],
  "edge_feat": [
    {"name": "bond_angle",  "shape": [412564], "dtype": "paddle.float32"}
  ]
}
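The bonds-as-nodes construction above can be sketched with a small helper. This is a hypothetical illustration, not code from this PR: given the atom-bond graph's edge_index, each directed bond becomes a node of the bond-angle graph, and two bonds are connected whenever one starts at the atom the other ends at (i.e. they form a bond angle).

```python
import torch


def bond_angle_edge_index(edge_index: torch.Tensor) -> torch.Tensor:
    """Derive the bond-angle graph's edges from an atom-bond edge_index.

    Each directed bond (an edge of the atom-bond graph) becomes a node;
    bond j follows bond i when j starts at the atom i ends at, excluding
    the reverse of the same bond. Quadratic loop, for illustration only.
    """
    src, dst = edge_index
    num_bonds = edge_index.size(1)
    pairs = [
        (i, j)
        for i in range(num_bonds)
        for j in range(num_bonds)
        if i != j
        and dst[i] == src[j]
        and not (src[i] == dst[j] and dst[i] == src[j])
    ]
    return torch.tensor(pairs, dtype=torch.long).t()
```

For a directed path 0-1-2 (bonds b0 = 0→1 and b1 = 1→2), this yields a single bond-angle edge b0→b1.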

Inside the forward pass, features for both graphs are embedded using nn.Embedding layers for the discrete features, and RBF + nn.Linear for the continuous features. A single embedding for each node is generated by sum pooling the embeddings of each of its features. The same is done for the edges.
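As a rough sketch of that embedding stage (class names, RBF centers, and gamma below are illustrative assumptions, not the PR's actual API): discrete features go through nn.Embedding, continuous features through an RBF expansion followed by nn.Linear, and the per-feature embeddings are sum-pooled into one vector per node or edge.

```python
import torch
from torch import nn


class RBF(nn.Module):
    """Expands a scalar feature over a fixed set of Gaussian centers."""
    def __init__(self, centers: torch.Tensor, gamma: float):
        super().__init__()
        self.register_buffer('centers', centers)  # shape [num_centers]
        self.gamma = gamma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [N] -> [N, num_centers]
        return torch.exp(-self.gamma * (x.unsqueeze(-1) - self.centers) ** 2)


class FeatureEmbedding(nn.Module):
    """Hypothetical sketch: one nn.Embedding per discrete feature,
    RBF + nn.Linear for a continuous feature, sum-pooled together."""
    def __init__(self, vocab_sizes, num_centers, embed_dim):
        super().__init__()
        self.discrete = nn.ModuleList(
            [nn.Embedding(v, embed_dim) for v in vocab_sizes])
        self.rbf = RBF(torch.linspace(0.0, 5.0, num_centers), gamma=10.0)
        self.continuous = nn.Linear(num_centers, embed_dim)

    def forward(self, discrete_feats, continuous_feat):
        # discrete_feats: list of [N] int tensors; continuous_feat: [N] floats
        out = sum(emb(f) for emb, f in zip(self.discrete, discrete_feats))
        return out + self.continuous(self.rbf(continuous_feat))
```

The same module shape would apply to edges (e.g. bond_dir discrete, bond_length continuous).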

In this PR, I’ve pulled the embedding process out of the GeoGNN class to allow for embeddings to be done separately, and the model can be run via

model = GeoGNN(embedding_size, layer_num, dropout)

edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                           [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]])

ab_graph = torch_geometric.data.Data(
    x=torch.randn(10, embedding_size),
    edge_index=edge_index,
    batch=torch.tensor([0] * 10),
)

ba_graph = torch_geometric.data.Data(
    x=torch.randn(10, embedding_size),
    edge_index=edge_index,
    edge_attr=torch.randn(10, embedding_size),
    batch=torch.tensor([0] * 10),
)

ab_graph_repr, ba_graph_repr, node_repr, edge_repr = model(ab_graph, ba_graph)

Here, the torch.randn(10, embedding_size) tensors used for the node/edge features stand in for preprocessed embeddings. Atom embeddings are specified in ab_graph.x, and bond and angle embeddings in ba_graph.x and ba_graph.edge_attr, respectively. My hope was that pulling the embedding out of the model class would make it more modular and more similar to other torch_geometric.nn models.

The larger change comes somewhat as a consequence of pulling out the embedding stage: in the PaddlePaddle implementation, the bond embeddings fed into each bond-angle GeoGNNBlock layer are computed de novo from the original features rather than from the previous layer's outputs. That is, the encodings for each layer are calculated as in the figure below.

[Screenshot: diagram of the PaddlePaddle scheme, where each layer's bond embeddings are recomputed from the original features]

The atom embedding weights correspond to the nn.Embedding params learned for each of the atom features. This is only done once, as the outputs of the atom-bond GIN (AB-GIN) are used as the atom embeddings for subsequent layers. In contrast, there are multiple sets of bond embedding weights, as the input to each bond-angle layer is calculated by applying its corresponding nn.Embedding and nn.Linear params to the original inputs, instead of just using the outputs of the previous BA-GIN layer. The bond-angles are omitted for clarity, but they are embedded just like the bonds are.

This implementation deviates slightly from my understanding of the original message passing, which looks like this. The authors note that AGGREGATE and COMBINE are those from the original GIN paper.

[Screenshots: the message-passing equations from the paper]

To my understanding, this corresponds to a scheme that looks more like this, where the previous BA-GIN layer’s outputs are used as inputs to the next.

[Screenshot: diagram of the scheme where each BA-GIN layer's outputs feed the next layer]

This is the version that I ended up implementing, as it seemed to match the message passing the paper suggests in equations (3) and (4), and it leaves the preprocessing to the user. I'm not sure which is better; I could implement the other version as well, and might also benchmark both. Let me know what you think.
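To make the contrast concrete, the two stacking schemes can be sketched as plain Python loops. The layer objects here are simple callables standing in for the AB-GIN and BA-GIN blocks, and all names are illustrative assumptions, not this PR's API:

```python
def forward_paddle_style(atom_x, bond_x_per_layer, angle_x, ab_layers, ba_layers):
    """PaddlePaddle scheme (sketch): each BA layer receives bond features
    re-embedded de novo from the raw inputs (one tensor per layer)."""
    for ab_layer, ba_layer, bond_x in zip(ab_layers, ba_layers, bond_x_per_layer):
        atom_x = ab_layer(atom_x, bond_x)
        bond_x = ba_layer(bond_x, angle_x)  # result NOT reused by the next layer
    return atom_x


def forward_this_pr(atom_x, bond_x, angle_x, ab_layers, ba_layers):
    """This PR's scheme (sketch): each BA layer consumes the previous
    BA layer's outputs, matching equations (3) and (4) as I read them."""
    for ab_layer, ba_layer in zip(ab_layers, ba_layers):
        atom_x = ab_layer(atom_x, bond_x)
        bond_x = ba_layer(bond_x, angle_x)  # fed into the next iteration
    return atom_x, bond_x
```

With toy scalar "layers" (e.g. addition), the second variant visibly threads the updated bond state through every iteration, while the first discards it.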


Takaogahara commented Feb 25, 2024

Hey!

I had a chance to take a look at the proposed model, and I must say it's pretty interesting!
Do you know if there have been any updates on the PR?

Looking forward to hearing more about it.

Development

Successfully merging this pull request may close these issues.

Request to Implement GeoGNN
3 participants