Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Implement model Unimp #83

Open
wants to merge 25 commits into
base: main
Choose a base branch
from

Conversation

yangcd-bupt
Copy link

support ogbn dataset

Description

Checklist

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [$CATEGORY] (such as [NN], [Model], [Doc], [Feature]])
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented
  • To the best of my knowledge, examples are either not affected by this change,
    or have been fixed to be compatible with this change
  • Related issue is referred in this PR

Changes

@yangcd-bupt
Copy link
Author

get ogb node dataset via OgbNodeDataset class

Copy link
Contributor

@Zhanghyi Zhanghyi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

增加一些测试代码,对于graph,link,node 各自选择一个最小的数据集写一下单元测试的代码 放在 https://github.com/BUPT-GAMMA/GammaGL/tree/main/tests/datasets 下面

import os.path as osp
import numpy as np
from gammagl.data import InMemoryDataset
from gammgl.utils.ogb_url import decide_download, download_url, extract_zip
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gammgl -> gammagl

@@ -0,0 +1,91 @@
import urllib.request as ur
Copy link
Contributor

@Zhanghyi Zhanghyi Sep 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


# check if previously-downloaded folder exists.
# If so, use that one.
if osp.exists(osp.join(root, self.dir_name + '_pyg')):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pyg相关的代码改成gammagl

print(data)
print(data[0])

test()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace the function name "test" with "test_ogbgraphdataset", which makes it easier to distinguish the functions from other files.

data=OgbLinkDataset('ogbl-ppa')
print(data[0])

test()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

data=OgbNodeDataset('ogbn-arxiv')
print(data[0])

test()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

from gammagl.data import InMemoryDataset
from gammagl.data.download import download_url
from gammagl.data.extract import extract_zip
from gammagl.io.read_ogb import read_node_label_hetero, read_graph, read_heterograph, read_nodesplitidx_split_hetero
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

中文逗号

print(data)
print(data[0])

test()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Replace the function name "test" with "test_ogbgraphdataset", which makes it easier to distinguish the functions from other files.
  2. Using "assert" to check the correctness of some variable. e.g. the "feature.shape[0]" and "num_nodes" should be equal.

data=OgbLinkDataset('ogbl-ppa')
print(data[0])

test()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

data=OgbNodeDataset('ogbn-arxiv')
print(data[0])

test()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

return loss


class MultiHead(MessagePassing):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put this section into gammagl.layers.conv.unimp_conv.py. If you put this into this file, users will not be able to use this function.

return x


class Unimp(tlx.nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put this section into gammagl.models.unimp.py. If you put this into this file, users will not be able to use this function.

super(Unimp, self).__init__()

out_layer1=int(dataset.num_node_features/2)
self.layer1=MultiHead(dataset.num_node_features+1, out_layer1, 4,dataset[0].num_nodes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think users can choose their n_heads, so change this to let users choose their own n_heads instead of a fixed one.

self.norm1=nn.LayerNorm(out_layer1)
self.relu1=nn.ReLU()

self.layer2=MultiHead(out_layer1, dataset.num_classes, 4,dataset[0].num_nodes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

import tlx.nn as nn
from gammagl.layers import MultiHead

class Unimp(tlx.nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add rst doc here refer to this link: https://docs.qq.com/pdf/DUXRTTU9tUnB1WnFB.

@dddg617 dddg617 changed the title Create ogb_node.py [Model] Implement model Unimp Oct 28, 2022
return loss


def forward(self, x, edge_index):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why there are two forward functions?


alpha = self.dropout(segment_softmax(weight, node_dst, self.num_nodes))
x = tlx.gather(x, node_src) * tlx.expand_dims(alpha, -1)
return x
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this function means Unimp? So what is the difference between this model and model GAT?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants