Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade visualize_graph for explain module #8743

Open
wants to merge 12 commits into
base: master
Choose a base branch
from

Conversation

Sutongtong233
Copy link
Contributor

For current visualize_graph, the node colors are default 'white', which is not as intuitive as in GNNExplainer: target node is red, different node class with different node color.
We add three optional params:

  • node_label: node class label
  • color_dict: node color for each node class
  • target_node: target node to explain for node-level explanation
  • draw_node_idx: if or not draw node index. For some cases, it is not necessary, and when nodes number is large, the visualization will be messy

We add lines at two examples examples/explain/gnn_explainer.py and examples/explain/gnn_explainer_ba_shapes.py

I am not sure about how to add optional params. Currently we add them at **kwargs.

Copy link
Member

@wsad1 wsad1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thanks for improving the visualization.
Could you please post some images of how things looked before and after your change.
I've left some comments.

@@ -233,7 +233,7 @@ def visualize_feature_importance(
return _visualize_score(score, feat_labels, path, top_k)

def visualize_graph(self, path: Optional[str] = None,
backend: Optional[str] = None):
backend: Optional[str] = None, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
backend: Optional[str] = None, **kwargs):
backend: Optional[str] = None,
nodel_label: Optional[Tensor] = None,
colors_dict: Optional[Dict[int, str] = None,
target_idx: Optional[int]=None):

Lets add the new arguments as optional arguemnts and add documentation for them. That way the end user is aware of the options available to them.

if target_node != None:
for i, node_id in enumerate(list(g.nodes)):
if node_id == target_node:
print("kylin")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print("kylin")

@@ -127,10 +134,26 @@ def _visualize_graph_via_networkx(
),
)

node_color = ['white'] * len(g.nodes)
if node_label != None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if node_label is None won't all nodes be white?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants