Implementing XGNN: Towards Model-Level Explanations of Graph Neural Networks #8618

Open · wants to merge 139 commits into master

SimonBele

The XGNN approach, described in the research paper "XGNN: Towards Model-Level Explanations of Graph Neural Networks" (Yuan et al., 2020), offers a novel method for interpreting GNNs at the model level. In this project, we implemented the XGNN approach using PyTorch Geometric and integrated it into its explainability module.

The key methods outlined in the source paper revolve around training a graph generator to explain the behaviour of GNNs at the model level. This generative model is designed to produce graph patterns that maximise a specific prediction of the GNN model we are trying to explain. It accomplishes this through reinforcement learning, with the generator deciding how to add edges and nodes to the current graph at each step. The graph generator is trained using policy gradient methods informed by the predictions of the trained GNN being explained. Additionally, the paper advocates incorporating specific graph rules to ensure that the generated graphs adhere to predefined validity criteria.
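The sketch below illustrates this training loop in Python. All helper names (`initial_graph`, `sample_action`, `apply_action`) and the GNN's forward signature are illustrative assumptions, not the actual API of this PR or of the paper's code:

```python
import torch

def train_generator(generator, optimizer, gnn_to_explain, target_class,
                    num_epochs=100, max_steps=10):
    """REINFORCE-style sketch: the generator adds one node/edge per step and is
    rewarded by the pretrained GNN's predicted probability for the target class."""
    for _ in range(num_epochs):
        graph = generator.initial_graph()            # assumed: returns a small seed graph
        log_probs, rewards = [], []
        for _ in range(max_steps):
            action, log_prob = generator.sample_action(graph)  # assumed helper
            graph = generator.apply_action(graph, action)      # add a node or an edge
            with torch.no_grad():                              # assumed graph-classification forward
                score = gnn_to_explain(graph.x, graph.edge_index).softmax(-1)[0, target_class]
            log_probs.append(log_prob)
            rewards.append(score)
        # Policy gradient: maximise expected reward => minimise -sum(log_prob * reward).
        loss = -torch.stack([lp * r for lp, r in zip(log_probs, rewards)]).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```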

Our implementation brings a model-level explanation algorithm for GNNs to PyG and, to the best of our knowledge, is general enough to be extended to different types of graph generative models, not limited to the reinforcement learning approach employed in the paper.

(Figure: MLG_new_diagram_border)

The explanations returned by PyG's existing explanation algorithms are based on node/edge/feature masks. While these are useful, they are not suitable as a base class for explanations that do not fall into that category.

In our case, the explanations produced by the XGNN algorithm are sets of graphs that maximise the prediction of the model being explained, in essence revealing which graph patterns the model associates with a certain class.

We chose not to refactor the Explanation class into a more general base class with a node/edge/feature-mask subclass, as this is outside the scope of our contribution. However, we hope this contribution motivates future work to restructure the base classes so that they are suitable for any type of explanation.

We build on the existing Data class, extending it for our GenerativeExplanation, similarly to what is done in the current Explanation class.
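As a rough illustration, such an extension could look like the following minimal sketch (the `explanation_set_sampler` attribute and the `get_explanation_set` method are assumptions for illustration, not necessarily the exact code in this PR):

```python
from torch_geometric.data import Data

class GenerativeExplanation(Data):
    """Explanation that defers to a sampler instead of storing node/edge/feature masks."""
    def get_explanation_set(self, num_samples: int = 10, **kwargs):
        # The sampler is stored as a regular Data attribute, analogous to how
        # Explanation stores its masks as attributes.
        sampler = self['explanation_set_sampler']
        return sampler.sample(num_samples=num_samples, **kwargs)
```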

We modify the existing Explainer class so that it supports a generative explanation type and handles it accordingly, while maintaining existing functionality. We believe the Explainer class would benefit from further refactoring, as its formulation is also too constrained towards node/edge/feature masks, but it was general enough for us to use for this contribution.
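Usage could then look roughly like the snippet below; note that `'generative'` as an explanation_type and the exact call signature are assumptions made for illustration:

```python
from torch_geometric.explain import Explainer

explainer = Explainer(
    model=pretrained_gnn,               # the GNN being explained
    algorithm=my_xgnn_algorithm,        # a user-defined XGNNExplainer subclass
    explanation_type='generative',      # assumed new explanation type added by this PR
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='log_probs',
    ),
)
# Model-level explanations need no particular input graph; the keyword argument
# below is illustrative only.
explanation = explainer(x=None, edge_index=None, for_class=0)
graphs = explanation.get_explanation_set(num_samples=5)
```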

To aid clarity, we define the notion of an "Explanation Set" as a collection of data points that are key to maximising a network's activation. This set effectively represents specific inputs that lead to the highest response from the neural network, offering insights into its behaviour.
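Informally, for a model $f$ with class scores $f_c$, the explanation set for class $c$ can be written as the set of inputs whose score for $c$ is (near-)maximal:

$$S_c = \{\, G \;:\; f_c(G) \approx \max_{G'} f_c(G') \,\}.$$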

For a given pretrained model, we offer an easily extendable interface through the XGNNExplainer class, which extends the ExplainerAlgorithm base class (which was general enough to serve as a suitable base). This allows users to define how to train a generative model of their choosing (extending our ExplanationSetSampler class) that can sample from the set of graphs maximising the pretrained model's predictions (termed the model's explanation set). This sampler is passed to our GenerativeExplanation, which uses it to retrieve the explanation set.
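A minimal sketch of these two extension points follows; apart from `sample()` and `forward()`, the method names are assumptions used for illustration rather than the PR's exact interface:

```python
from torch_geometric.explain.algorithm import ExplainerAlgorithm

class ExplanationSetSampler:
    """Anything that can sample graphs from a model's explanation set."""
    def sample(self, num_samples: int, **kwargs):
        raise NotImplementedError

class XGNNExplainerSketch(ExplainerAlgorithm):
    """Rough outline of how an XGNN-style algorithm plugs into ExplainerAlgorithm."""
    def supports(self) -> bool:
        return True

    def forward(self, model, x, edge_index, *, for_class: int = 0, **kwargs):
        # Train a user-defined generative model and wrap it in a GenerativeExplanation,
        # so the explanation set can be sampled later.
        sampler = self.train_generative_model(model, for_class)
        return GenerativeExplanation(explanation_set_sampler=sampler)

    def train_generative_model(self, model_to_explain, for_class: int) -> ExplanationSetSampler:
        raise NotImplementedError  # users plug in their own generator training here
```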

The GenerativeExplanation is designed to be general; it obtains the explanation set solely from the sample() method of the ExplanationSetSampler. This design permits, for instance, an extension in which the sampler contains a static dataset. In such a scenario, the sample method could simply return the entire dataset, select samples from this static dataset, or implement any other approach envisioned by the user. Alternatively, this framework accommodates any type of generative process for retrieving the explanation set, provided it includes a mechanism for sampling from it.
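For example, the static-dataset scenario mentioned above could be realised with a sampler as simple as this (illustrative only):

```python
import random

class StaticExplanationSetSampler(ExplanationSetSampler):
    """'Generative' process that just samples from a fixed list of graphs."""
    def __init__(self, graphs):
        self.graphs = graphs  # e.g. a list of torch_geometric.data.Data objects

    def sample(self, num_samples: int, **kwargs):
        if num_samples >= len(self.graphs):
            return list(self.graphs)          # return the whole dataset
        return random.sample(self.graphs, num_samples)
```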

We provide an example of using these classes (under the explainer algorithm examples), which mimics the experiment from the "XGNN: Towards Model-Level Explanations of Graph Neural Networks" paper by training the graph generator with their reinforcement learning approach.

We start with a specific seed graph, which represents the state of the reinforcement learning environment. From there we take a number of steps, during which we train the generator to learn which actions are favourable given the current graph and our specified reward function. The possible actions are either to connect two existing nodes in the graph or to create a new node from the candidate node types and connect it to one of the existing nodes. The policy in this environment is the neural network in our graph generator, which, given the current graph state, outputs probabilities over the possible actions so as to maximise the reward function. The reward at each step of the generation process determines whether the chosen action was favourable under those conditions.

The reward has two parts. First, we pass the current graph through the model we are explaining and retrieve the predicted probability that the generated graph belongs to the class we are generating for. The paper refers to this as the intermediate reward, to which we add the final reward, obtained by performing rollouts on the current graph: for each rollout step we apply one action and evaluate the updated graph's predicted probability, effectively trying out several one-step continuations. Second, the paper advocates incorporating validity rules that check whether the generated graph is valid in the domain of the problem being explained. For the MUTAG dataset used in our example, each node's degree is checked against its atom type, and a penalty is applied if the degree exceeds the maximal chemical valency of that atom.
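The sketch below shows how such a reward could be computed for a MUTAG-style setting; the valency table, the `one_step_rollout` helper, and the `atom_type` attribute are illustrative assumptions rather than the exact code in the example:

```python
import torch

# Maximal chemical valency per atom type (MUTAG node types), used as a validity rule.
MAX_VALENCY = {'C': 4, 'N': 3, 'O': 2, 'F': 1, 'I': 1, 'Cl': 1, 'Br': 1}

def compute_reward(graph, gnn_to_explain, target_class, num_rollouts=5, penalty=-1.0):
    # Intermediate reward: predicted probability of the target class for the current graph.
    with torch.no_grad():
        intermediate = gnn_to_explain(graph.x, graph.edge_index).softmax(-1)[0, target_class]

    # Final reward: average probability after several one-step rollouts from this graph.
    rollout_scores = []
    for _ in range(num_rollouts):
        rolled = one_step_rollout(graph)      # assumed helper: applies one sampled action
        with torch.no_grad():
            rollout_scores.append(
                gnn_to_explain(rolled.x, rolled.edge_index).softmax(-1)[0, target_class])
    final = torch.stack(rollout_scores).mean()

    # Validity rule: penalise nodes whose degree exceeds the maximal valency of their atom type.
    degrees = torch.bincount(graph.edge_index[0], minlength=graph.num_nodes)
    valid = all(degrees[i] <= MAX_VALENCY[graph.atom_type[i]] for i in range(graph.num_nodes))

    return intermediate + final + (0.0 if valid else penalty)
```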

GraphGenerator extends our ExplanationSetSampler class and implements the network that is trained to generate graphs that maximise the prediction of a given class.

The example also includes code to render the graphs generated with this method.
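A rendering helper along these lines (illustrative, not the exact example code) could be:

```python
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx

def render_generated_graph(data, atom_labels=None):
    """Convert a generated PyG graph to networkx and draw it with optional atom labels."""
    g = to_networkx(data, to_undirected=True)
    labels = {i: atom_labels[i] for i in g.nodes} if atom_labels is not None else None
    nx.draw(g, with_labels=True, labels=labels, node_color='lightblue')
    plt.show()
```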


codecov bot commented Dec 14, 2023

Codecov Report

Attention: 15 lines in your changes are missing coverage. Please review.

Comparison is base (74f7cba) 90.12% compared to head (5e31373) 89.46%.

| Files | Patch % | Lines |
| --- | --- | --- |
| torch_geometric/explain/explanation.py | 53.57% | 13 Missing ⚠️ |
| ...orch_geometric/explain/algorithm/xgnn_explainer.py | 92.30% | 2 Missing ⚠️ |
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##           master    #8618      +/-   ##
==========================================
- Coverage   90.12%   89.46%   -0.66%     
==========================================
  Files         481      482       +1     
  Lines       31055    31110      +55     
==========================================
- Hits        27988    27833     -155     
- Misses       3067     3277     +210     
```

