
Excessive memory consumption for deep networks #168

Open
jglaser opened this issue Nov 6, 2022 · 2 comments
Labels: enhancement (New feature or request)

Comments

@jglaser (Contributor) commented Nov 6, 2022

The LLVM compiler pass uses an excessive amount of memory when compiling deep networks constructed like this:

stax.serial(*[my_layer] * depth)

In fact, the compilation may eventually OOM.

The reason is that the serial combinator internally relies on a Python for loop (with carry) to support mixed input sequences, so the loop is fully unrolled at trace time.
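
For intuition, a serial combinator composes the per-layer kernel functions roughly as in the simplified sketch below (not the actual neural_tangents implementation); because the Python loop is unrolled during tracing, the compiled program grows linearly with depth:

def serial_kernel_fn(kernel_fns, k, **kwargs):
  # Each iteration is traced separately, so the compiled program contains
  # one copy of the layer's kernel computation per repetition.
  for kernel_fn in kernel_fns:
    k = kernel_fn(k, **kwargs)
  return k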

It would be nice to have a specialization for the case in which the same layer is repeated n times, which could then use jax.lax.scan() to save compilation time by avoiding loop unrolling.

Suggestion:

import jax.example_libraries.stax as ostax
from neural_tangents._src.utils.typing import Layer, InternalLayer, NTTree
from neural_tangents._src.stax.requirements import get_req, requires, layer
from neural_tangents._src.utils.kernel import Kernel
from jax.lax import scan
import jax.numpy as np

@layer
def repeat(layer: Layer, n: int) -> InternalLayer:
  """Combinator for repeating the same layers `n` times.

  Based on :obj:`jax.example_libraries.stax.serial`.

  Args:
    layer:
      a single layer, given as an `(init_fn, apply_fn, kernel_fn)` triple.

    n:
      the number of times to repeat `layer`.

  Returns:
    A new layer, meaning an `(init_fn, apply_fn, kernel_fn)` triple,
    representing the composition of `n` layers.
  """
  init_fn, apply_fn, kernel_fn = layer

  # Parameters: compose `n` copies of `(init_fn, apply_fn)` with stax.serial.
  init_fn, apply_fn = ostax.serial(*zip([init_fn] * n, [apply_fn] * n))
  @requires(**get_req(kernel_fn))
  def kernel_fn_scan(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]:
    # TODO(xlc): if we drop `x1_is_x2` and use `rng` instead, need split key
    # inside kernel functions here and parallel below.
    # lax.scan traces the loop body once and applies it `n` times, avoiding unrolling.
    k, _ = scan(lambda carry, x: (kernel_fn(carry, **kwargs), None), k, np.arange(n))
    return k

  return init_fn, apply_fn, kernel_fn_scan

Use it like this:

repeat(my_layer, depth)
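
For a concrete picture, here is a hypothetical constant-width block (my_layer, depth, and the layer widths are illustrative placeholders, not part of the library):

import neural_tangents.stax as nt_stax

depth = 128
my_layer = nt_stax.serial(nt_stax.Dense(512), nt_stax.Relu())

# kernel_fn of the repeated block is traced once and iterated `depth` times via
# lax.scan, instead of being unrolled as in nt_stax.serial(*[my_layer] * depth).
init_fn, apply_fn, kernel_fn = nt_stax.serial(
    repeat(my_layer, depth),
    nt_stax.Dense(1),
)
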
romanngg added a commit that referenced this issue Dec 6, 2022
…168). This allows faster and less memory-hungry compilation of very deep networks.

Note that compiled loops require the layer not to change shapes or other static metadata. This necessitates some warnings (see the docstring), and makes it less flexible than `stax.serial`.

Co-authored-by: Jens Glaser <jens.glaser@gmail.com>
PiperOrigin-RevId: 493193125
@romanngg added the enhancement label Dec 6, 2022
@romanngg (Contributor) commented Dec 6, 2022

Thanks for the suggestion; please check out the added https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.repeat.html#neural_tangents.stax.repeat

One caveat that makes this less elegant than we'd like is that kernel_fn sometimes makes non-jittable changes to the metadata of the Kernel object, and when this happens, lax.scan fails (see especially the second warning in the docstring). So, for now, it is unfortunately less flexible than stax.serial.
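
For reference, a minimal sketch of using the released combinator (the layer widths and depth are placeholders; exact constraints are described in the linked docstring):

import neural_tangents.stax as stax

# The repeated block must keep shapes and other static Kernel metadata fixed
# across iterations; otherwise lax.scan cannot carry the Kernel through the loop.
block = stax.serial(stax.Dense(512), stax.Relu())
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.repeat(block, 128),
    stax.Dense(1),
)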

@jglaser (Contributor, Author) commented Dec 6, 2022

Awesome, thanks!
