
Intel/rect hgam homo and hetero - draft for HGAM based on rectangular adj matrix in homogeneous and heterogeneous case (CSR only) #8897

Draft
wants to merge 21 commits into base: master
Conversation

@andreazanetti (Contributor) commented Feb 12, 2024

Sharing this code just to discuss the idea.
Not to be merged.

As per the title, this is a draft of the rectangular-adjacency-matrix HGAM implementation (for homogeneous and heterogeneous data, CSR flow only) which computes only the representations needed to proceed with the next step in the computation of the GNN.

Overall idea for rect adj HGAM
The overall idea is to identify two sets of nodes at each layer:
S1: the set of nodes needed as input to compute the sought node representations at that layer (our "inputs")
S2: the set of nodes we need to compute a new representation for (our "outputs")

In general, S2 is a subset of S1 (at least assuming that each node's current representation is used to compute its next-layer representation).

Therefore S1 populates the columns of adj_t, whereas S2 populates the rows of adj_t (edge_index).
In the current HGAM implementation in PyG, at each layer we identify S1 and compute representations for all nodes in S1, since we use a square adjacency matrix that gets trimmed as we proceed through the layers, moving from one layer's S1 to the next layer's S1. That is why the current HGAM implementation does not act at layer 0: all the nodes in the batch graph are in S1 at the first layer.
Although that approach is not optimal, it is much better than computing representations for all nodes in the graph batch at all layers.

Optimality is reached when, at each layer, we compute representations only for the nodes in that layer's S2 (which becomes S1 for the next layer).
That is what we are doing here in the homogeneous (and heterogeneous) case, for sage_conv with CSR.
Hope it makes sense.
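The S1/S2 idea can be sketched with a toy example (plain NumPy with illustrative names, not PyG code): the rectangular adjacency has |S2| rows and |S1| columns, so a single propagation step produces outputs only for the S2 nodes.

```python
import numpy as np

# Toy layer: S1 = {0, 1, 2, 3} are the input nodes, S2 = {0, 1} the outputs.
num_s1, num_s2, dim = 4, 2, 8
rng = np.random.default_rng(0)
x_s1 = rng.random((num_s1, dim))       # current representations of all S1 nodes

# Rectangular adjacency: rows index S2 (outputs), columns index S1 (inputs).
adj_rect = np.zeros((num_s2, num_s1))
adj_rect[0, [1, 2]] = 1.0              # node 0 aggregates from nodes 1 and 2
adj_rect[1, [0, 3]] = 1.0              # node 1 aggregates from nodes 0 and 3

# One mean-aggregation step: only |S2| rows are ever computed.
deg = adj_rect.sum(axis=1, keepdims=True)
out = (adj_rect @ x_s1) / deg
assert out.shape == (num_s2, dim)
```

With a square |S1| × |S1| matrix, the same step would have produced (and discarded) rows for every node in S1; the rectangular shape skips that wasted work.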

Thanks in advance for your comments.
Contributors: @rBenke (robert.benke@intel.com)

@rusty1s (Member) commented Feb 20, 2024

The PR makes sense to me. I would prefer to make this behavior optional though (True by default), since not all layers in PyG support bipartite message passing.

@andreazanetti (Contributor, Author) commented Feb 20, 2024

Thanks for your feedback!
However, I am not sure I fully understand what you mean, apologies.

  1. Bipartite message passing
    I assumed that bipartite message passing already happens frequently in the hetero-data case (except perhaps for edges between nodes of the same type).

I initially thought that the issues with PyG operators would be connected to the need to "narrow" the feature matrix x in all those operators where the original input node representation is reused directly to compute the final (new) output node representation (I guess when root_weight is True and we add a projection of x to the result of the propagation, as we do in torch_geometric/nn/conv/sage_conv.py, line 142).

It is not clear to me how making the adjacency matrix rectangular would affect PyG operators in other ways.

  2. Making this an optional behaviour
    HGAM is already an optional behaviour.
    Would you like to allow users to choose from 3 options then?
    a) default PyG behaviour
    b) HGAM with square adjacency matrix (=> no modification to the operators)
    c) HGAM with rectangular adjacency matrix (=> modification to some operators to deal with the reduced dimensionality of x)
    or just a) and c)?

Hope it makes sense.
Thanks

@rusty1s (Member) commented Feb 20, 2024

Sorry, I was solely talking about homogeneous case (so please ignore my comment). For heterogeneous, we can just do (c).

@andreazanetti (Contributor, Author) commented:
Thanks Matthias.
Just to understand what you have in mind, and avoid misunderstandings:

a) for homogeneous data, you would like the user to be free to choose from the 3 options above;
b) for heterogeneous data, the user is not expected to make any choice: only c) is supported.

Did I get your thinking right?

@rusty1s (Member) commented Feb 21, 2024

Yes, that is what I had in mind, but please feel free to drop (b) completely if you feel it is unnecessary.

@andreazanetti (Contributor, Author) commented Feb 22, 2024

Hi, thanks for the feedback.

Currently, using HGAM requires the user to pass the HGAM metadata explicitly to the model.
This metadata is returned by the loader at each iteration.
Making HGAM the default mode would mean ensuring that this metadata is passed without any user intervention.
It looks to me like this would require some code restructuring, possibly in the way the loader packs the returned data.
Therefore I would like to propose dividing the work into two parts.

First Part:
Clean up and complete this PR, which evolves HGAM to use the rectangular adjacency matrix for CSR data only (following @rBenke's initial idea).
It does so for both the homogeneous and heterogeneous case, leaving HGAM as an option the user can set to True.
The default will remain False for both homo and hetero.
Please note that this PR still needs to be completed with the modification of all the conv operators that reuse the node features in a computation parallel to the message passing, as we do here for sage_conv.
In other words, the first part enables options a) and c) for both homo and hetero.
Option b) would be dropped.
HGAM is still not the default, neither for homo nor for hetero.

Second Part:
We could consider making HGAM the default, for hetero only for example, as per the discussion above (option c) for hetero).
However, for HGAM to work, the HGAM metadata the neighbor loader returns at each iteration must be passed to the model's forward method.
In the current code structure, it looks to me that the user level (the training script) has to pass this metadata explicitly to the model's forward method.
Conversely, to pass the metadata implicitly, one option would be to pack it together with batch.adj_t or batch.edge_index, making them tuples.
It is not yet clear whether this is a good way; it might require a number of changes elsewhere in the code.

Please share your thoughts, thanks!

@andreazanetti (Contributor, Author) commented:
I am proceeding with the first part as described above.
However, there are 64 conv operators in PyG, so I will need some time to review all of them.
I will leave this PR open, if that's ok.
Thanks

@rusty1s (Member) commented Feb 26, 2024

@andreazanetti I see. Do you think we need to update all the layers in order to support HGAM++? Would bipartite input not be an option as well?

@andreazanetti (Contributor, Author) commented:
Hi Matthias, thanks for your reply.
In the way _trim_to_layer.py operates in this PR, for both homo and hetero input graphs represented in CSR, the adjacency matrix (or matrices) is non-square. As a side effect, the dimensionality of the node feature matrix given as input to a conv operator no longer matches the dimension of the output of propagate, which now returns new representations only for the nodes needed at the next layer.
So each conv operator that reuses the original node representation to compute the new one should be updated.
That translates into "narrowing" x to the dimension of the out obtained from propagate.

Pretty much like we do here in sage_conv.py:

```python
if self.root_weight:
    x_r = torch.narrow(x_r, 0, 0, out.shape[0])
    out = out + self.lin_r(x_r)
```

And looking at the code in torch_geometric/nn/conv, it seems to me that most of them do.
There is also the need to change the tests, for _trim_to_layer and for, I guess, each operator.

Now, I am not sure I understand how making the input bipartite would help. Apologies, I do not see it :)
Could you share more about that idea? thanks

@rusty1s (Member) commented Feb 27, 2024

Got it. Initially, I assumed we can just do conv((x_src, x_dst), edge_index), which would be equivalent but already has built in narrowing. I am fine with adding the narrow logic to PyG operators, but this sounds like a lot of work :)
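As a hedged sketch of this equivalence (plain NumPy with illustrative names, not the PyG API): the `(x_src, x_dst)` pair has the narrowing built in, because the root-weight term uses `x_dst`, whose first dimension already matches the propagation output.

```python
import numpy as np

rng = np.random.default_rng(0)
num_src, num_dst, d_in, d_out = 6, 3, 4, 5   # dst nodes are a prefix of src nodes
x_src = rng.random((num_src, d_in))          # inputs to message passing (S1)
x_dst = x_src[:num_dst]                      # target-node representations (S2)

adj_rect = (rng.random((num_dst, num_src)) < 0.5).astype(float)  # [dst, src]
w_l = rng.random((d_in, d_out))              # neighbour projection
w_r = rng.random((d_in, d_out))              # root projection (root_weight=True)

# SAGE-like bipartite update: aggregate over src, add the root term from dst.
# No explicit torch.narrow is needed - x_dst already has num_dst rows.
out = (adj_rect @ x_src) @ w_l + x_dst @ w_r
assert out.shape == (num_dst, d_out)
```

The narrow-based variant and the bipartite variant compute the same thing; the pair form just moves the trimming out of each operator and into the caller.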

@andreazanetti (Contributor, Author) commented Feb 27, 2024

Oh, thanks @rusty1s, I did not realize that we could leverage the presence of x: Union[Tensor, OptPairTensor] in the forward method for this.

  1. So I assume that conv((x_src, x_dst), edge_index) is meant to replace the conv calls (e.g. x = conv(x, edge_index)) in the forward method of the BasicGNN class.

  2. If I understand the point, and assuming the src/dst naming follows the inbound/outbound direction of edges when building the neighborhood at each layer:

  • x_dst is a matrix of shape num_target_nodes_for_that_layer × num_node_features, containing the current representations of the target nodes only.
  • x_src is a matrix that may include x_dst (in the homo case, for example), of shape num_nodes_in_input_at_that_layer × num_node_features, containing the representations of all the nodes needed as input for the current layer.

  3. With each conv taking a pair of tensors as input, the trimming part will have to return (besides the rectangular adjacency matrix correctly sized at each layer) both the trimmed and the non-trimmed node feature matrices, instead of only the non-trimmed one as it does now.

  4. Looking at the sage_conv class implementation, the input node representation tuple should contain:
    x[0] = the non-trimmed node feature matrix (used as input for message passing)
    x[1] = the trimmed node feature matrix (used as input for the linear projection whose result is added to the message passing output)

So with my definitions above, I guess the conv call should look like conv((x_src, x_dst), edge_index), which matches your suggestion.
I hope all this makes sense.
I will go on working along these lines (so no modification of all conv operators! :) )
Unless you tell me I missed your point.
(thanks @rBenke (robert.benke@intel.com) for the help)
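A minimal sketch of what the trimming step could return, assuming a hypothetical trim_features helper (the name is illustrative, not part of PyG): the trimmed matrix is a zero-copy view of the full one, so building the pair adds no overhead.

```python
import numpy as np

def trim_features(x, num_dst):
    """Hypothetical helper: return the (x_src, x_dst) pair for one layer.

    x_dst is a view of the first num_dst rows (the target nodes are assumed
    to come first), so no data is copied.
    """
    return x, x[:num_dst]

x = np.arange(12, dtype=float).reshape(6, 2)  # 6 input nodes (S1)
x_src, x_dst = trim_features(x, num_dst=3)    # 3 target nodes (S2)

assert np.shares_memory(x_dst, x)             # a view, not a copy
assert x_dst.shape == (3, 2)
# conv((x_src, x_dst), edge_index) would then consume this pair.
```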

@andreazanetti (Contributor, Author) commented Feb 28, 2024

I went ahead and modified _trim_to_layer.py and other parts to support conv((x_src, x_dst), edge_index).
The trimming just returns a view of the node feature matrix, so there should be no performance change.

It works fine for homo; I have tested it with the hierarchical_sampling.py example.
The results show the expected behavior, from both a performance and a functional perspective.

One epoch training without Hierarchical Graph Sampling: (NO HGAM)
100%|████████████████████████████████████| 150/150 [00:29<00:00, 5.17it/s]
avg_loss is: 0.47911528728326375

One epoch training with Hierarchical Graph Sampling: (WITH HGAM rect_adj_mat)
100%|████████████████████████████████████| 150/150 [00:17<00:00, 8.81it/s]
avg_loss is: 0.48629561240452207

However, the hetero case is a bit more delicate.
I guess a change to the HeteroConv class will be needed, since it does not seem to support the tuple as input.
I am not clear what the overall impact of such a change would be.
Working on that. Any suggestions/comments would be of help, thanks!

@andreazanetti (Contributor, Author) commented:
The hetero case seems to work now.
Running hierarchical_sampling_hetero.py shows an improvement when the batch size is big and the neighborhood is not too small. With the values in the code, training on OGB-mag appears to be ~30% faster:

HGAM False ==> 20.09it/s
HGAM True ==> 26.96it/s

In this approach, for my convenience, I commented out the part connected to **kwargs in hetero_conv.py.
This cannot be the final version, but I guess it is a matter of deciding on a nice way to pass what I called xra_dict, to re-enable the support for **kwargs.
Happy to receive suggestions/comments.
Thanks.

@andreazanetti (Contributor, Author) commented Mar 11, 2024

While updating the tests for trim_to_layer, I realized that the support for HGAM with a rectangular adjacency matrix for CSR (hetero and homo) broke HGAM with a square matrix for COO, hetero and homo (which is the only mode we are planning to support for COO).
:(
I have fixed it in my local branch and will update it here on GitHub as soon as I have also updated the test_trim_to_layer.py file.

UPDATE: no, I have found other issues with the coexistence of HGAM with a square adjacency matrix for COO and HGAM with a rectangular adjacency matrix for CSR. I will get back to this as soon as possible.
