Input shape for the GraphSAGE layer and sparse tensor construction #423

Open
zzzy0828 opened this issue Mar 13, 2023 · 0 comments

Hi!
I want to use the GraphSAGE layer to extract spatial information across different sites, where each site has its own features. I ran into the following problems.
Here is the main code:
```python
from tensorflow.keras import backend
from tensorflow.keras.layers import Input, Lambda
from spektral.layers import GraphSageConv

time_step, n_site, fea = 12, 20, 8  # placeholder sizes, for illustration only


def x_reshape(u):
    # Custom reshape 1: fold time_step into the batch axis (node features x)
    return backend.reshape(u, (-1, n_site, fea))


def a_reshape(v):
    # Custom reshape 2: fold time_step into the batch axis (adjacency a)
    return backend.reshape(v, (-1, n_site, n_site))


# model construction
x_in = Input(shape=(time_step, n_site, fea))
a_in = Input(shape=(time_step, n_site, n_site))
x_in_v = Lambda(x_reshape, output_shape=(n_site, fea))(x_in)
a_in_v = Lambda(a_reshape, output_shape=(n_site, n_site))(a_in)

# spatial information passing: two layers for spatial feature extraction
output = GraphSageConv(64, activation='relu')([x_in_v, a_in_v])
output = GraphSageConv(64, activation='relu')([output, a_in_v])
```
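
For reference, each `GraphSageConv` call above applies a GraphSAGE update to a single graph. The NumPy sketch below follows the mean-aggregator formulation from the GraphSAGE paper; it is not Spektral's implementation, which may differ in details such as self-loops or the order of activation and normalisation. The function `graphsage_mean` and its weight arguments are hypothetical. Its main point is the input contract: node features of shape `(n_nodes, n_features)` and a rank-2 adjacency of shape `(n_nodes, n_nodes)`.

```python
import numpy as np


def graphsage_mean(x, a, w, b):
    """One GraphSAGE-style layer with a mean aggregator (illustrative sketch).

    x: node features, shape (n_nodes, n_features)
    a: binary adjacency, shape (n_nodes, n_nodes); must be rank 2
    w: weights, shape (2 * n_features, channels)
    b: bias, shape (channels,)
    """
    # This sketch only handles one graph, i.e. a rank-2 adjacency matrix
    assert a.ndim == 2, "A must have rank 2"
    # Mean of each node's neighbour features (isolated nodes get zeros)
    deg = a.sum(axis=1, keepdims=True)
    neigh_mean = (a @ x) / np.maximum(deg, 1)
    # Concatenate self features with aggregated neighbour features, then project
    h = np.concatenate([x, neigh_mean], axis=1) @ w + b
    h = np.maximum(h, 0)  # ReLU, matching activation='relu'
    # L2-normalise each node's output vector, as in the GraphSAGE paper
    norm = np.linalg.norm(h, axis=1, keepdims=True)
    return h / np.maximum(norm, 1e-12)
```

Passing a stacked `(n_graphs, n_nodes, n_nodes)` array to this sketch fails the rank check, which mirrors the error reported below.
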
1. How can I convert `a_in_v` into a sparse tensor? I tried many approaches, but none of them worked.
2. The input of the GraphSAGE layer is defined as:
   - Node features of shape `(n_nodes, n_node_features)`;
   - Binary adjacency matrix of shape `(n_nodes, n_nodes)`.

   But when I try `Input(shape=(n_site, n_site), sparse=True)`, I get `AssertionError: A must have rank 2`. Is this a contradiction? How can I handle it?
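
A likely cause of the assertion: Keras `Input` always prepends a batch dimension, so `Input(shape=(n_site, n_site), sparse=True)` produces a rank-3 tensor of shape `(None, n_site, n_site)`, and the reshaped `a_in_v` is rank 3 as well. One way to obtain the rank-2 adjacency the layer expects is what Spektral calls disjoint mode: treat every (sample, time step) pair as a separate graph and place all adjacency matrices on the diagonal of one large block-diagonal sparse matrix. The NumPy sketch below builds that matrix in COO form; `batch_to_disjoint_coo` is a hypothetical helper, and the returned triplet is intended for `tf.sparse.SparseTensor(indices, values, dense_shape)`.

```python
import numpy as np


def batch_to_disjoint_coo(a_batch):
    """Stack dense adjacency matrices into one block-diagonal COO matrix.

    a_batch: shape (n_graphs, n_nodes, n_nodes), e.g. the
             (batch * time_step, n_site, n_site) array after reshaping.
    Returns (indices, values, dense_shape).
    """
    n_graphs, n_nodes, _ = a_batch.shape
    # Graph index g, row i and column j of every non-zero entry (row-major order)
    g, i, j = np.nonzero(a_batch)
    # Offset node ids by g * n_nodes so each graph occupies its own diagonal block
    rows = g * n_nodes + i
    cols = g * n_nodes + j
    indices = np.stack([rows, cols], axis=1).astype(np.int64)
    values = a_batch[g, i, j].astype(np.float32)
    dense_shape = (n_graphs * n_nodes, n_graphs * n_nodes)
    return indices, values, dense_shape
```

Because `np.nonzero` returns entries in row-major order, the indices come out already sorted, which is the ordering `tf.SparseTensor` expects. On the Keras side, Spektral's disjoint-mode examples (as far as I can tell) declare the adjacency as `Input(shape=(None,), sparse=True)` and the node features as `Input(shape=(fea,))`, both rank 2 once the batch axis is added.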

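The reshape functions fold `time_step` into the batch dimension, so that each time step of each sample becomes one graph with `(n_site, fea)` node features and an `(n_site, n_site)` adjacency matrix for the graph convolution.
I don't run into the problems above when I use a GCN or GAT layer for feature extraction.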
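
Note that the `Lambda` reshape still keeps a batch axis, so `x_in_v` is rank 3 from the layer's point of view. In disjoint mode (stacking several graphs into one large graph with a block-diagonal adjacency), the node features must also be rank 2: every graph's nodes are stacked into one `(n_total_nodes, fea)` matrix, in the same order as the rows of that block-diagonal adjacency. The hypothetical helper `to_disjoint_features` below is a NumPy sketch assuming the `(batch, time_step, n_site, fea)` layout of `x_in`; it repeats the folding done by `x_reshape`, then flattens the graphs and returns a graph index per node.

```python
import numpy as np


def to_disjoint_features(x):
    """Flatten (batch, time_step, n_site, fea) node features for disjoint mode.

    Each (sample, time step) pair is treated as one graph of n_site nodes.
    Returns node features of shape (batch * time_step * n_site, fea) and a
    graph index giving the graph each node row belongs to.
    """
    batch, time_step, n_site, fea = x.shape
    n_graphs = batch * time_step
    # Same folding as x_reshape: (batch * time_step, n_site, fea)
    x_graphs = x.reshape(n_graphs, n_site, fea)
    # Stack every graph's nodes into one (n_total_nodes, fea) matrix
    x_nodes = x_graphs.reshape(n_graphs * n_site, fea)
    # Node k belongs to graph k // n_site
    graph_index = np.repeat(np.arange(n_graphs), n_site)
    return x_nodes, graph_index
```

If the adjacency array is reshaped the same way, `a.reshape(batch * time_step, n_site, n_site)`, then row k of `x_nodes` corresponds to row k of the block-diagonal adjacency built from it. The graph index is useful if node features later need to be pooled back per graph.
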
I would really appreciate your help, and I look forward to your reply.
Thank you very much!
