Skip to content

Latest commit

 

History

History
259 lines (207 loc) · 13.5 KB

SUBSTRATES.md

File metadata and controls

259 lines (207 loc) · 13.5 KB

Substrates (JAX, NumPy)

Current as of 9/15/2020

TFP supports alternative numerical backends to TensorFlow, including both JAX and NumPy. The intent of this document is to explain some of the details of supporting disparate backends, and how we have handled them, with the objective of helping contributors read / write code, and understand how these alternative-substrate packages are assembled.

Alternative backends

In tensorflow_probability/python/internal/backend we find the implementations of both the NumPy and JAX backends. These imitate the portion of the TensorFlow API used by TFP, but are implemented in terms of the corresponding substrate's primitives.

Since JAX provides jax.numpy, we are in most cases able to write a single NumPy implementation under tensorflow_probability/python/internal/backend, then use a rewrite script (found at tensorflow_probability/python/internal/backend/jax/rewrite.py) to generate a JAX variant at bazel build time. See the genrules in tensorflow_probability/python/internal/backend/jax/BUILD for more details of how this occurs. In cases where JAX provides a different API (e.g. random) or a more performant API (e.g. batched matrix decompositions, vmap, etc.), we will special-case using an if JAX_MODE: block.

Wrinkles

  • Shapes

    In TensorFlow, the shape attribute of a Tensor is a tf.TensorShape object, whereas in JAX and NumPy, it is a simple tuple of ints. To handle cases for both systems nicely, we use from tensorflow_probability.python.internal import tensorshape_util and move special-casing into this library.

  • DTypes

    TensorFlow, and to some extent JAX, use the dtype of inputs to infer the dtype of outputs, and generally try to preserve the Tensor dtype in a binary op between Tensor and non-Tensor. NumPy, on the other hand, aggressively pushes dtypes toward 64-bit when left unspecified. In some cases debugging the JAX substrate, we have seen issues with dtypes changing from float32 to float64 or vice versa across the iterations of a while loop. Finding where this change happens can be tricky. Where possible, we aim to fix these in the internal/backend package, as opposed to implementation files.

  • Shapes, again (JAX "omnistaging" / prefer_static)

    Every JAX primitive observed within a JIT or JAX control flow context becomes an abstract Tracer. This is similar to @tf.function. The main challenge this introduces for TFP is that TF allows dynamic shapes as Tensors whereas JAX (being built atop XLA) needs shapes to be statically available (i.e. a tuple or numpy.ndarray, not a JAX ndarray).

    If you observe issues with shapes derived from Tracers in JAX, often a simple fix is from tensorflow_probability.python.internal import prefer_static as ps followed by replacing tf.shape with ps.shape, and similar for other ops such as tf.rank, tf.size, tf.concat (when dealing with shapes), (the args to) tf.range, etc. It's also useful to be aware of ps.convert_to_shape_tensor, which behaves like tf.convert_to_tensor for TF, but leaves things as np.ndarray for JAX. Similarly, in constructors, use the as_shape_tensor=True arg to tensor_util.convert_nonref_to_tensor for shape-related values.

  • tf.GradientTape

    TF uses a tape to record ops for later gradient evaluation, whereas JAX rewrites a function while tracing its execution. Since the function transform is more general, we aim to replace usage of GradientTape (in tests, TFP impl, etc), with tfp.math.value_and_gradient or similar. Then, we can special-case JAX_MODE inside the body of value_and_gradient.

  • tf.Variable, tf_keras.optimizers.Optimizer

    TF provides a Variable abstraction so that graph functions may modify state, including using the Keras Optimizer subclasses like Adam. JAX, in contrast, operates only on pure functions. In general, TFP is fairly functional (e.g. tfp.optimizer.lbfgs_minimize), but in some cases (e.g. tfp.vi.fit_surrogate_posterior, tfp.optimizer.StochasticGradientLangevinDynamics) we have felt the mismatch too strong to try to port code to JAX. Some approaches to hoisting state out of a stateful function can be seen in the TFP spinoff project oryx.

  • Custom derivatives

    JAX supports both forward and reverse mode autodifferentiation, and where possible TFP aims to support both in JAX. To do so, in places where we define a custom derivative, we use an internal wrapper which provides a function decorator that supports both TF and JAX's interfaces for custom derivatives, namely:

    from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient
    
    def _f(..): pass
    def _f_fwd(..): return _f(..), bwd_auxiliary_data
    def _f_bwd(bwd_auxiliary_data, dy): pass
    def _f_jvp(primals, tangents): return _f(*primals), df(primals, tangents)
    @tfp_custom_gradient.custom_gradient(vjp_fwd=_f_fwd, vjp_bwd=_f_bwd, jvp_fn=_f_jvp)
    def f(..): return _f(..)

    For more information, the JAX custom derivatives doc can be useful.

  • Randomness

    In TF, we support both "stateful" (i.e. some latent memory tracks the state of the sampler) and "stateless" sampling. JAX natively supports only stateless, for functional purity reasons. For internal use, we have from tensorflow_probability.python.internal import samplers, a library that provides methods to:

    • convert stateful seeds to stateless, add salts (sanitize_seed)
    • split stateless seeds to multiple descendant seeds (split_seed)
    • proxy through to a number of stateless samplers (normal, uniform, ...)

    When the rewrite script is dealing with a ..._test.py file, it will rewrite calls to tf.random.{uniform,normal,...} to tf.random.stateless_{uniform,normal,...} to ensure compatibility with the JAX backend, which only implements the stateless samplers.

Rewriting TF code

In a couple cases, we commit into the repository script-munged source from TensorFlow. These files can be found under tensorflow_probability/python/internal/backend/numpy/gen. They currently include:

  • an implementation of tf.TensorShape
  • several parts of tf.linalg, especially the tf.linalg.LinearOperator classes

The actual rewriting is accomplished by scripts found under tensorflow_probability/python/internal/backend/meta, namely gen_linear_operators.py and gen_tensor_shape.py.

The test tensorflow_probability/python/internal/backend/numpy/rewrite_equivalence_test.py verifies that the files in TensorFlow, when rewritten, match the files in the gen/ directory. The test uses BUILD dependencies on genrules that apply the rewrite scripts, and compares those genrule inputs to the source of the files under the gen/ directory.

Similar to the sources in internal/backend/numpy, the sources in internal/backend/numpy/gen are rewritten by jax/rewrite.py. Note that the files under gen/ do not have the numpy import rewritten. This is because we only want to rewrite TensorFlow usage of TensorFlow-ported code; typically when TF code is using numpy, it is munging shapes, and JAX does not like shapes to be munged using jax.numpy (must use plain numpy).

Rewriting TFP code

With internal/backend/{numpy,jax} now ready to provide a tf2jax or tf2numpy backend, we can proceed to the core packages of TFP.

The script tensorflow_probability/substrates/meta/rewrite.py runs on TFP sources to auto-generate JAX and NumPy python source corresponding to the given TF source.

The most important job of the rewrite script is to rewrite import tensorflow.compat.v2 as tf to from tensorflow_probability.python.internal.backend.jax import v2 as tf. Second to that, the script will rewrite dependencies on TFP subpackages to dependencies on the corresponding substrate-specific TFP subpackages. For example, the line from tensorflow_probability.python import math as tfp_math becomes from tensorflow_probability.substrates.jax import math as tfp_math. Beyond that, there are a number of peripheral replacements to work around other wrinkles we've accumulated over time.

In rare cases we will put an explicit if JAX_MODE: or if NUMPY_MODE: block into the implementation code of a TFP submodule. This should be very uncommon. Whenever possible, the intent is for such special-casing to live under python/internal. For example, today we see in bijectors/softplus.py:

# TODO(b/155501444): Remove this when tf.nn.softplus is fixed.
if JAX_MODE:
  _stable_grad_softplus = tf.nn.softplus
else:
  @tf.custom_gradient
  def _stable_grad_softplus(x):  # ...

Note that this rewrite currently adds exactly a 10-line header, so line numbers from stack traces will be +10 from the raw code.

BUILD rules

tensorflow_probability/python/build_defs.bzl defines a pair of bazel build rules: multi_substrate_py_library and multi_substrate_py_test.

These rules automatically invoke tensorflow_probability/substrates/meta/rewrite.py to emit JAX/NumPy source variants. The file bijectors/softplus.py gets rewritten into bijectors/_generated_jax_softplus.py (you can view the output under the corresponding bazel-genfiles directory).

These build rules are also responsible for rewriting TFP-internal deps to the some_dep.jax or some_dep.numpy substrate-specific replacement.

The multi_substrate_py_library will emit three targets: a TF py_library with the name given by the name argument, a JAX py_library with name name + '.jax', and a NumPy py_library with name name + '.numpy'.

The multi_substrate_py_test will emit three targets, each of name + '.tf', name + '.jax', and name + '.numpy'. Rules specified by the disabled_substrates arg will not have BUILD rules emitted at all; jax_tags and numpy_tags can be used to specify specific tags that drop CI coverage while keeping the target buildable and testable. The distinction is useful so that we can track cases where we think a test should be fixable, but we haven't yet, as opposed to cases like HMC where we know the test will never pass for NumPy so we prefer to not even have the test target. All emitted test targets are aggregated into a test_suite with name corresponding to the original name arg.

In cases where we know we will never be able to support a given feature, the substrates_omit_deps, jax_omit_deps, and numpy_omit_deps args to multi_substrate_py_library can be used to exclude things. Examples include non-pure code or code using tf.Variable (JAX wants pure functions), or HMC (no gradients in NumPy!). When rewriting an __init__.py file, the rewrite script is set up to comment out imports and __all__ lines corresponding to the omitted deps.

In order to test against the same directory hierarchy as we use for wheel packaging, the multi_substrate_py_library does some internal gymnastics with a custom bazel rule which is able to add symlinks into tensorflow_probability/substrates pointing to that point to implementation files generated under bazel-genfiles/tensorflow_probability/python (details in _substrate_runfiles_symlinks_impl of build_defs.bzl).

Wheel packaging

When it comes to building the wheel, we must first use cp -L to resolve the symlinks added as part of the bazel build. Otherwise the wheel does not follow them and fails to include tfp.substrates. This cp -L command sits in pip_pkg.sh (currently adjacent to this doc).

Integration testing

A couple of integration tests sit in tensorflow_probability/substrates/meta/jax_integration_test.py and tensorflow_probability/substrates/meta/numpy_integration_test.py.

We run these under CI after building and installing a wheel to verify that the tfp.substrates packages load correctly and do not require a tensorflow install.