Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OGB graph prop pred examples torch-mlir compatibility #436

Open
dvlp-r opened this issue Apr 18, 2023 · 0 comments
Open

OGB graph prop pred examples torch-mlir compatibility #436

dvlp-r opened this issue Apr 18, 2023 · 0 comments

Comments

@dvlp-r
Copy link

dvlp-r commented Apr 18, 2023

Hi, I am opening this issue to ask a question. I am trying to use one of your examples about graph classification (https://github.com/snap-stanford/ogb/tree/master/examples/graphproppred/mol).

I recently modified the example so that a TorchScript version of the model can be created successfully. I did this because I am trying to lower the model to torch-mlir. When I try to do so, I encounter an error which, according to the torch-mlir developers, means that my model contains a tuple such as x=(0,0). They suggested that I try replacing this tuple with a list, such as x=[0,0].
Unfortunately, I am new to this and have not been able to find the problem. I have included the torch-mlir error below for completeness.

Exception: 
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
### Importer C++ Exception:
see diagnostics
### Importer Diagnostics:
error: unhandled prim::Constant node: %37 : (int, int) = prim::Constant[value=(0, 0)]()

Could you please help me find this tuple, so that your models can be made compatible with torch-mlir?
Below are the files of your model that I am using. The only changes made were those needed for TorchScript compatibility (and, for simplicity, only the code for the GIN model has been kept).

Thank you in advance for your help.

main.py

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import sys

from PIL import Image
import requests

import torch
from torchvision import transforms
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from gnn import GNN
import torch.optim as optim

from ogb.graphproppred import PygGraphPropPredDataset, Evaluator

import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend


def train(model, device, loader, optimizer, task_type):
    """Run one training epoch of ``model`` over ``loader``.

    Args:
        model: the GNN; called as ``model(x, edge_index, edge_attr, batch)``.
        device: device every batch is moved to before the forward pass.
        loader: iterable of PyG batches (``x``, ``edge_index``, ``edge_attr``,
            ``batch`` and ``y`` attributes are read).
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        task_type: dataset task type; if it contains "classification" the
            module-level ``cls_criterion`` is used, otherwise ``reg_criterion``.
    """
    model.train()

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        # Skip batches that hold a single node, or whose last node belongs to
        # graph 0 (presumably meaning the batch holds a single graph, since PyG
        # batch vectors are non-decreasing). NOTE: when the loader yields one
        # graph per batch (e.g. batch_size=1) this skips *every* batch.
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            continue

        # The model's forward takes unpacked tensors (same convention as
        # ``eval``), not the Batch object itself.
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        optimizer.zero_grad()
        # Ignore NaN (unlabeled) targets: NaN != NaN, so the mask is False
        # exactly where y is NaN.
        is_labeled = batch.y == batch.y
        criterion = cls_criterion if "classification" in task_type else reg_criterion
        loss = criterion(
            pred.to(torch.float32)[is_labeled],
            batch.y.to(torch.float32)[is_labeled],
        )
        loss.backward()
        optimizer.step()


def eval(model, device, loader, evaluator):
    """Score ``model`` on every batch of ``loader`` with an OGB ``evaluator``.

    The model is put in eval mode and run under ``torch.no_grad()``. Batches
    with a single node are skipped. Targets and predictions are concatenated
    on the CPU, converted to NumPy and passed to ``evaluator.eval`` as
    ``{"y_true": ..., "y_pred": ...}``; its result is returned.

    NOTE: this name shadows the ``eval`` builtin inside this module.
    """
    model.eval()
    targets = []
    outputs = []

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)
        if batch.x.shape[0] == 1:
            continue

        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        # Reshape targets to match the prediction layout before collecting.
        targets.append(batch.y.view(pred.shape).detach().cpu())
        outputs.append(pred.detach().cpu())

    return evaluator.eval({
        "y_true": torch.cat(targets, dim=0).numpy(),
        "y_pred": torch.cat(outputs, dim=0).numpy(),
    })


def predictions(torch_model, jit_model):
    """Print test-set scores for the eager PyTorch model and the compiled one.

    Uses the module-level ``device``, ``test_loader`` and ``evaluator``.
    """
    # NOTE(review): ``jit_model`` goes through ``eval``, which calls
    # ``model.eval()`` and ``model(x, edge_index, edge_attr, batch)`` on it;
    # confirm the object returned by the torch-mlir backend's ``load``
    # supports both calls.
    labelled_models = (
        ("PyTorch prediction", torch_model),
        ("torch-mlir prediction", jit_model),
    )
    for title, candidate in labelled_models:
        score = eval(candidate, device, test_loader, evaluator)
        print(title)
        print(score)


cls_criterion = torch.nn.BCEWithLogitsLoss()
reg_criterion = torch.nn.MSELoss()

# Defined before the model so the model is placed on the same device.
device = torch.device("cpu")

### automatic data loading and splitting
dataset = PygGraphPropPredDataset(name='ogbg-molhiv')

split_idx = dataset.get_idx_split()

### automatic evaluator. takes dataset name as input
evaluator = Evaluator("ogbg-molhiv")

# NOTE(review): with batch_size=1 every training batch is a single graph, and
# train() skips single-graph batches, so the training epoch below performs no
# optimizer steps. Increase train_loader's batch_size to actually train.
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=1, shuffle=True,
                          num_workers=0)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=1, shuffle=False,
                          num_workers=0)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=1, shuffle=False,
                         num_workers=0)

gin = GNN(gnn_type='gin', num_tasks=dataset.num_tasks, num_layer=5, emb_dim=300,
          drop_ratio=0.5).to(device)

optimizer = optim.Adam(gin.parameters(), lr=0.001)

train(gin, device, train_loader, optimizer, dataset.task_type)
print("Validation")
print(eval(gin, device, valid_loader, evaluator))

# Switch BatchNorm/dropout to inference behaviour before compiling.
gin.eval()

# torch_mlir.compile needs example inputs; use the first test batch. Fail
# loudly instead of hitting a NameError on `module` if the loader is empty.
example_batch = next(iter(test_loader), None)
if example_batch is None:
    raise RuntimeError("test_loader is empty; no example inputs for torch_mlir.compile")
example_batch = example_batch.to(device)
module = torch_mlir.compile(
    gin,
    (example_batch.x, example_batch.edge_index, example_batch.edge_attr, example_batch.batch),
    output_type="linalg-on-tensors",
)

backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)

predictions(gin, jit_module)

gnn.py

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_geometric.nn.inits import uniform

from conv import GNN_node

from torch_scatter import scatter_mean
import time


class GNN(torch.nn.Module):
    """Graph-level predictor.

    Node embeddings from ``GNN_node`` are pooled per graph, and a linear
    head with ``num_tasks`` outputs is applied to the result.
    """

    def __init__(self, num_tasks=10, num_layer=5, emb_dim=300,
                 gnn_type='gin', residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"):
        """
            num_tasks (int): number of labels to be predicted (head output width)
            num_layer (int): number of message-passing layers (passed to GNN_node)
            emb_dim (int): node embedding dimensionality
            gnn_type (str): convolution type (passed to GNN_node)
            residual (bool): residual connections between layers (passed to GNN_node)
            drop_ratio (float): dropout probability (passed to GNN_node)
            JK (str): jumping-knowledge mode (passed to GNN_node)
            graph_pooling (str): one of "sum", "mean", "max", "attention", "set2set"

            Raises ValueError if graph_pooling is not one of the values above.
        """

        super(GNN, self).__init__()

        # Validate up front so an invalid value fails before any submodule is built.
        if graph_pooling not in ("sum", "mean", "max", "attention", "set2set"):
            raise ValueError("Invalid graph pooling type.")

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        ### GNN to generate node embeddings
        self.gnn_node = GNN_node(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual,
                                 gnn_type=gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(
                gate_nn=torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),
                                            torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, 1)))
        else:  # "set2set" -- the only value left after validation
            self.pool = Set2Set(emb_dim, processing_steps=2)

        # set2set pooling gets a head of input width 2 * emb_dim; every other
        # pooling feeds emb_dim features into the head.
        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2 * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor,
                batch: torch.Tensor) -> torch.Tensor:
        """Embed nodes, pool them with ``batch`` and apply the linear head.

        ``batch`` is the per-node graph assignment (presumably PyG's ``batch``
        vector); it is passed straight to the pooling function. The last
        dimension of the result is ``num_tasks``.
        """
        h_node = self.gnn_node(x, edge_index, edge_attr)

        h_graph = self.pool(h_node, batch)

        return self.graph_pred_linear(h_graph)


if __name__ == '__main__':
    # Smoke check: building a model with default hyper-parameters should not raise.
    _ = GNN(num_tasks=10)

conv.py

import torch
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, global_add_pool
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from torch_geometric.utils import degree

import math
import time


### GIN convolution along the graph structure
class GINConv(MessagePassing):
    """GIN convolution with bond (edge) features.

    Output is ``mlp((1 + eps) * x + agg)``, where ``agg`` sums (``aggr="add"``)
    the messages ``relu(x_j + bond_encoder(edge_attr))`` for each node.
    Documentation is deliberately kept outside the ``forward``/``message``
    bodies, because PyG's ``jittable()`` introspects those methods.
    """
    # Declares the TorchScript types of the keyword arguments given to
    # propagate(); presumably consumed by PyG's jittable() -- confirm against
    # the installed PyG version.
    propagate_type = {'x': torch.Tensor, 'edge_attr': torch.Tensor}

    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality

            Builds a 2-layer MLP (emb_dim -> 2*emb_dim -> emb_dim, with
            BatchNorm and ReLU in between), a learnable scalar eps that starts
            at 0, and a BondEncoder mapping edge_attr to emb_dim-sized vectors.
        '''

        super(GINConv, self).__init__(aggr="add")

        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),
                                       torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, emb_dim))
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    # Encodes the edge features, aggregates neighbour messages and applies the
    # MLP to (1 + eps) * x + aggregate. The isinstance assert presumably only
    # narrows the BondEncoder output type for TorchScript.
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        edge_embedding = self.bond_encoder(edge_attr)
        assert isinstance(edge_embedding, torch.Tensor)
        return self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding, size=None))

    # Per-edge message: ReLU of the neighbour feature plus the encoded edge
    # feature. PyG matches these parameter names (x_j, edge_attr) against the
    # propagate() keyword arguments, so they must not be renamed.
    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    # Identity update: the aggregated messages are returned unchanged.
    def update(self, aggr_out):
        return aggr_out


### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """
    Stack of GIN message-passing layers over atom-encoded node features.

    Output:
        node representations
    """

    def __init__(self, num_layer, emb_dim, drop_ratio=0.5, JK="last", residual=False, gnn_type='gin'):
        '''
            num_layer (int): number of GNN message passing layers (must be >= 2)
            emb_dim (int): node embedding dimensionality
            drop_ratio (float): dropout probability applied after every layer
            JK (str): how layer outputs are combined -- "last" (output of the
                final layer) or "sum" (input embedding plus every layer output)
            residual (bool): add each layer's input to its output
            gnn_type (str): only 'gin' is implemented in this file

            Raises ValueError for num_layer < 2, an unsupported JK, or a
            gnn_type other than 'gin'.
        '''

        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        if JK not in ("last", "sum"):
            raise ValueError(f"Unsupported JK mode: {JK}")
        if gnn_type != 'gin':
            # Only GINConv is kept in this module; fail instead of silently
            # building a GIN for a different requested type.
            raise ValueError(f"Undefined GNN type called {gnn_type}")

        # Resolved once to a bool so forward() only branches on a bool
        # attribute, like the existing `residual` flag, rather than comparing
        # strings.
        self._jk_sum = JK == "sum"

        self.atom_encoder = AtomEncoder(emb_dim)

        ###List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for _ in range(num_layer):
            self.convs.append(GINConv(emb_dim).jittable())

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        """Return node representations combined according to ``JK``."""
        ### computing input node embedding
        h = self.atom_encoder(x)
        # Running sum of the input embedding and every layer output; only
        # updated (and only returned) when JK == "sum". Kept as a tensor rather
        # than a list so TorchScript sees no list ops.
        h_sum = h

        for layer, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
            h_in = h
            h = conv(h_in, edge_index, edge_attr)
            h = norm(h)

            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                # Out-of-place add: avoids mutating dropout's output in place.
                h = h + h_in

            if self._jk_sum:
                h_sum = h_sum + h

        return h_sum if self._jk_sum else h



if __name__ == "__main__":
    # This module only defines layers; running it directly does nothing.
    ...

mol_encoder.py

import torch
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims 

# Category count of every atom / bond feature column; each entry is used by
# the encoders in this module as the size of one embedding table.
full_atom_feature_dims, full_bond_feature_dims = (
    get_atom_feature_dims(),
    get_bond_feature_dims(),
)

class AtomEncoder(torch.nn.Module):
    """Embed integer atom features as the sum of one embedding per feature column."""

    def __init__(self, emb_dim, feature_dims=None):
        """
        Args:
            emb_dim (int): output embedding dimensionality.
            feature_dims (sequence of int, optional): number of categories of
                each feature column; defaults to ``full_atom_feature_dims``.
        """
        super(AtomEncoder, self).__init__()

        if feature_dims is None:
            feature_dims = full_atom_feature_dims

        self.atom_embedding_list = torch.nn.ModuleList()

        for dim in feature_dims:
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        """Return the sum over columns ``i`` of ``atom_embedding_list[i](x[:, i])``.

        ``x`` is indexed as ``x[:, i]``, so it must be 2-D with one integer
        column per embedding table (the int 0 is returned if there are none).
        """
        x_embedding = 0
        # Iterating the ModuleList directly, instead of indexing it with a loop
        # variable, keeps this TorchScript-compatible (TorchScript unrolls
        # loops over a ModuleList) without hard-coding the column count.
        for i, embedding in enumerate(self.atom_embedding_list):
            x_embedding = x_embedding + embedding(x[:, i])

        return x_embedding


class BondEncoder(torch.nn.Module):
    """Embed integer bond features as the sum of one embedding per feature column."""

    def __init__(self, emb_dim, feature_dims=None):
        """
        Args:
            emb_dim (int): output embedding dimensionality.
            feature_dims (sequence of int, optional): number of categories of
                each feature column; defaults to ``full_bond_feature_dims``.
        """
        super(BondEncoder, self).__init__()

        if feature_dims is None:
            feature_dims = full_bond_feature_dims

        self.bond_embedding_list = torch.nn.ModuleList()

        for dim in feature_dims:
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        """Return the sum over columns ``i`` of ``bond_embedding_list[i](edge_attr[:, i])``.

        ``edge_attr`` is indexed as ``edge_attr[:, i]``, so it must be 2-D with
        one integer column per embedding table (the int 0 is returned if there
        are none).
        """
        bond_embedding = 0
        # Iterating the ModuleList directly keeps this TorchScript-compatible
        # (TorchScript unrolls loops over a ModuleList) without hard-coding
        # the column count.
        for i, embedding in enumerate(self.bond_embedding_list):
            bond_embedding = bond_embedding + embedding(edge_attr[:, i])

        return bond_embedding


if __name__ == '__main__':
    # Manual smoke check: encode the atoms and bonds of the first tox21 graph
    # with 100-dimensional encoders and print the embeddings.
    # NOTE(review): the `loader` module is not among the files shown; confirm
    # it is importable before relying on this block.
    from loader import GraphClassificationPygDataset

    tox21 = GraphClassificationPygDataset(name = 'tox21')
    atom_encoder = AtomEncoder(100)
    bond_encoder = BondEncoder(100)

    atom_embeddings = atom_encoder(tox21[0].x)
    print(atom_embeddings)
    bond_embeddings = bond_encoder(tox21[0].edge_attr)
    print(bond_embeddings)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant