
zombie-einstein/jaxpr-viz


Jaxpr-Viz

JAX Computation Graph Visualisation Tool

JAX has built-in functionality to visualise the HLO graph generated by JAX, but I've found this rather low-level for some use-cases.

The intention of this package is to visualise how sub-functions are connected in JAX programs. It does this by converting the Jaxpr representation into a pydot graph. See here for examples.
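The Jaxpr that jpviz starts from can be inspected with JAX's public `jax.make_jaxpr`. The snippet below only illustrates that input and is not jpviz's own code. `foo` and `bar` mirror the Usage example further down. A nested jitted call shows up as an equation whose `params` hold its own sub-jaxpr and function name. That nesting is presumably what becomes a sub-graph.

```python
import jax
import jax.numpy as jnp


@jax.jit
def foo(x):
    return 2 * x


@jax.jit
def bar(x):
    return foo(x) - 1


# Trace bar with a concrete example argument to get its ClosedJaxpr.
closed = jax.make_jaxpr(bar)(jnp.arange(10))
print(closed)

# The top level holds one call equation for the jitted bar; its params
# include the nested jaxpr and the function name.
outer_eqn = closed.jaxpr.eqns[0]
print(outer_eqn.params["name"])

# Inside bar: one nested call (to foo) and one primitive (sub).
inner = outer_eqn.params["jaxpr"].jaxpr
for eqn in inner.eqns:
    print(eqn.primitive.name, eqn.params.get("name", ""))
```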

NOTE: This project is still at an early stage and may not support all JAX functionality (or combinations thereof). If you spot any strange behaviour, please create a GitHub issue.

Installation

Install with pip:

pip install jpviz

Depending on your system, you may also need to install Graphviz.
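pydot renders images (for example with `write_png`) by calling the Graphviz `dot` executable. The sketch below uses only the standard library to check whether that executable is on your `PATH`. The helper name `graphviz_available` is hypothetical and not part of jpviz.

```python
import shutil


def graphviz_available() -> bool:
    """Hypothetical helper: True if Graphviz's `dot` binary is on PATH."""
    return shutil.which("dot") is not None


if __name__ == "__main__":
    if graphviz_available():
        print("Graphviz found at", shutil.which("dot"))
    else:
        print("Graphviz 'dot' not found; install it with your system package manager.")
```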

Usage

Jaxpr-viz can be used to visualise jit-compiled (and nested) functions. jpviz.draw wraps a jit-compiled function; calling the wrapped function with concrete values returns a pydot graph.
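The wrap-then-call pattern can be sketched with plain JAX. This is a hypothetical toy, not jpviz's implementation. `draw_outline` traces the function with `jax.make_jaxpr` when it is called with concrete arguments, then walks nested jaxprs recursively. It returns an indented text outline rather than a pydot graph. As a simplification, it only follows sub-jaxprs stored under a `jaxpr` parameter, which is where nested jit calls keep them.

```python
import jax
import jax.numpy as jnp


def draw_outline(fun):
    """Hypothetical stand-in for jpviz.draw: returns a wrapper that traces
    `fun` on concrete arguments and returns a text outline of its jaxpr."""

    def wrapped(*args):
        closed = jax.make_jaxpr(fun)(*args)
        lines = []

        def walk(jaxpr, depth):
            for eqn in jaxpr.eqns:
                sub = eqn.params.get("jaxpr")
                if sub is not None:
                    # Nested call (e.g. a jitted function): name it, then recurse.
                    lines.append("  " * depth + eqn.params.get("name", "?"))
                    walk(getattr(sub, "jaxpr", sub), depth + 1)
                else:
                    # Primitive operation: record the primitive's name.
                    lines.append("  " * depth + eqn.primitive.name)

        walk(closed.jaxpr, 0)
        return "\n".join(lines)

    return wrapped


@jax.jit
def foo(x):
    return 2 * x


@jax.jit
def bar(x):
    return foo(x) - 1


outline = draw_outline(bar)(jnp.arange(10))
print(outline)
```

Deferring the trace until the wrapper is called is what makes concrete arguments necessary: shapes and dtypes are only known once an example input is supplied.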

For example, this simple computation graph

import jax
import jax.numpy as jnp

import jpviz

@jax.jit
def foo(x):
    return 2 * x

@jax.jit
def bar(x):
    x = foo(x)
    return x - 1

# Wrap function and call with concrete arguments
#  here dot_graph is a pydot object
dot_graph = jpviz.draw(bar)(jnp.arange(10))
# This renders the graph to a png file
dot_graph.write_png("computation_graph.png")

produces this image

[Image: bar computation graph]

Pydot has a number of options for rendering graphs; see here.

NOTE: For sub-functions to show as nodes/sub-graphs, they need to be marked with @jax.jit; otherwise they will just be merged into their parent graph.
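The effect behind this note can be observed directly in the jaxpr. A plain Python helper is traced inline, so its operations land in the caller's jaxpr. A `@jax.jit` helper stays a separate call equation carrying its own sub-jaxpr. `count_nested_calls` and the `double_*`/`with_*` functions are hypothetical names used only for this illustration.

```python
import jax
import jax.numpy as jnp


def count_nested_calls(fun, *args):
    """Hypothetical helper: count top-level equations in fun's jaxpr that
    carry their own sub-jaxpr (such as calls to jitted functions)."""
    jaxpr = jax.make_jaxpr(fun)(*args).jaxpr
    return sum("jaxpr" in eqn.params for eqn in jaxpr.eqns)


def double_plain(x):
    return 2 * x


@jax.jit
def double_jit(x):
    return 2 * x


def with_plain(x):
    return double_plain(x) - 1  # helper's ops are inlined into this trace


def with_jit(x):
    return double_jit(x) - 1  # helper remains a separate nested call


x = jnp.arange(10)
plain_calls = count_nested_calls(with_plain, x)
jit_calls = count_nested_calls(with_jit, x)
print(plain_calls, jit_calls)
```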

Jupyter Notebook

To show the rendered graph in a Jupyter notebook, you can use the helper function view_pydot:

...
dot_graph = jpviz.draw(bar)(jnp.arange(10))
jpviz.view_pydot(dot_graph)

Visualisation Options

Collapse Nodes

By default, functions composed only of primitive functions are collapsed into a single node (like foo in the example above). The full computation graph can be rendered with the collapse_primitives flag; setting it to False in the example above

...
dot_graph = jpviz.draw(bar, collapse_primitives=False)(jnp.arange(10))
...

produces

[Image: bar computation graph, primitives not collapsed]
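Whether a function is "composed of only primitive functions" can be approximated by checking that no equation in its traced body carries a nested jaxpr. The sketch below is a hypothetical check, not jpviz's actual rule. It looks only at a `jaxpr` parameter. Control-flow primitives such as `cond` keep their sub-jaxprs under other parameter names, so a fuller check would need to handle those too.

```python
import jax
import jax.numpy as jnp


def is_primitive_only(closed_jaxpr):
    """Hypothetical check: True if no equation carries a nested jaxpr,
    i.e. the body is made up of primitive operations only."""
    return all("jaxpr" not in eqn.params for eqn in closed_jaxpr.jaxpr.eqns)


@jax.jit
def foo(x):
    return 2 * x


@jax.jit
def bar(x):
    return foo(x) - 1


x = jnp.arange(10)
# make_jaxpr on a jitted function wraps it in one call equation,
# so take the traced body from that equation's params["jaxpr"].
foo_body = jax.make_jaxpr(foo)(x).jaxpr.eqns[0].params["jaxpr"]
bar_body = jax.make_jaxpr(bar)(x).jaxpr.eqns[0].params["jaxpr"]
print(is_primitive_only(foo_body))  # foo: candidate for collapsing
print(is_primitive_only(bar_body))  # bar: contains a call to foo
```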

Show Types

By default, type information is included in the node labels. It can be hidden with the show_avals flag; setting it to False

...
dot_graph = jpviz.draw(bar, show_avals=False)(jnp.arange(10))
...

produces

[Image: bar computation graph, type information hidden]
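The "avals" in show_avals are JAX's abstract values: the shape and dtype recorded for each traced variable, with no concrete data. The sketch below prints them using `jax.make_jaxpr`. It assumes the hidden type information in the node labels is derived from these avals. The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return x * y


closed = jax.make_jaxpr(f)(jnp.arange(10), 2.0)

# Each input variable carries an aval describing shape and dtype only.
for var in closed.jaxpr.invars:
    print(var.aval.shape, var.aval.dtype)

# Output variables have avals too.
print(closed.jaxpr.outvars[0].aval.shape)
```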

NOTE: Node labels do not currently correspond to argument/variable names in the original Python code. Since JAX unpacks arguments and outputs into tuples, the labels instead correspond to the positions of arguments and outputs.
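Why names are lost can be seen with JAX's pytree flattening. Structured arguments such as dicts are flattened into a flat sequence of leaves, so the jaxpr only sees positional inputs. Dict leaves come out in sorted key order. The `loss` function and `params` dict below are hypothetical examples.

```python
import jax
import jax.numpy as jnp


def loss(params, x):
    return params["w"] * x + params["b"]


params = {"w": jnp.float32(2.0), "b": jnp.float32(1.0)}
x = jnp.ones(3)

# Flatten the arguments the way JAX does: dict leaves in sorted key
# order ("b" before "w"), followed by x.
leaves, treedef = jax.tree_util.tree_flatten((params, x))
print(len(leaves))

# The jaxpr has three positional inputs; the dict keys are gone.
closed = jax.make_jaxpr(loss)(params, x)
print([v.aval.shape for v in closed.jaxpr.invars])
```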

Examples

See here for more examples of rendered computation graphs.

Developers

Developer notes can be found here.