The LLVM compiler pass uses excessive amounts of memory for deep networks that are constructed like this:
stax.serial([my_layer]*depth)
In fact, the compilation may eventually OOM.
The reason is that the `serial` combinator internally relies on a Python for loop (with carry) to support arbitrary (mixed) layer sequences.
It would be nice to have a specialization for the case in which the same layer is repeated `n` times, which could then use `jax.lax.scan()` to save compilation time by avoiding loop unrolling.
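To make the effect concrete, here is a minimal, library-agnostic sketch (plain JAX, with a hypothetical `body` function standing in for a layer, not from the original post) of why unrolling is expensive: a Python loop traces the body once per iteration, while `lax.scan` traces it only once.

import jax
import jax.numpy as jnp
from jax import lax

def body(x):
  # Stand-in for a single layer's computation.
  return jnp.tanh(x) * 0.5 + 1.0

def unrolled(x, depth=1000):
  # Python loop: `body` is traced `depth` times, so the XLA program
  # (and the LLVM passes behind it) grows linearly with depth.
  for _ in range(depth):
    x = body(x)
  return x

def scanned(x, depth=1000):
  # `lax.scan`: `body` is traced once; XLA sees a single loop of `depth` steps.
  x, _ = lax.scan(lambda carry, _: (body(carry), None), x, None, length=depth)
  return x

x = jnp.ones((8,))
print(len(jax.make_jaxpr(unrolled)(x).jaxpr.eqns))  # grows ~linearly with depth
print(len(jax.make_jaxpr(scanned)(x).jaxpr.eqns))   # stays small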
Suggestion:
import jax.example_libraries.stax as ostax
from neural_tangents._src.utils.typing import Layer, InternalLayer, NTTree
from neural_tangents._src.stax.requirements import get_req, requires, layer
from neural_tangents._src.utils.kernel import Kernel
from jax.lax import scan
import jax.numpy as np
@layer
def repeat(layer: Layer, n: int) -> InternalLayer:
  """Combinator for repeating the same layer `n` times.

  Based on :obj:`jax.example_libraries.stax.serial`.

  Args:
    layer:
      a single layer, an `(init_fn, apply_fn, kernel_fn)` triple.
    n:
      the number of iterations.

  Returns:
    A new layer, meaning an `(init_fn, apply_fn, kernel_fn)` triple,
    representing the composition of `n` layers.
  """
  init_fn, apply_fn, kernel_fn = layer
  init_fn, apply_fn = ostax.serial(*zip([init_fn] * n, [apply_fn] * n))

  @requires(**get_req(kernel_fn))
  def kernel_fn_scan(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]:
    # TODO(xlc): if we drop `x1_is_x2` and use `rng` instead, need to split the
    # key inside kernel functions here and in `parallel` below.
    # A single traced copy of `kernel_fn` is iterated `n` times by `scan`,
    # instead of being unrolled `n` times as in `stax.serial`.
    k, _ = scan(lambda carry, _: (kernel_fn(carry, **kwargs), None), k,
                np.arange(n))
    return k

  return init_fn, apply_fn, kernel_fn_scan
Use it like this:
repeat(my_layer, depth)
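For instance, here is a sketch of how this could compose with the public `neural_tangents.stax` API (the `Dense`/`Relu` block, widths, and depth below are hypothetical choices, not from the original post):

from neural_tangents import stax
import jax.numpy as np

# A hypothetical repeated block. Its output width must stay constant across
# iterations so the scanned kernel carry keeps the same static metadata.
my_layer = stax.serial(stax.Dense(512), stax.Relu())
depth = 100

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512),          # bring the kernel to width 512 before scanning
    repeat(my_layer, depth),  # one traced copy of `my_layer`, iterated `depth` times
    stax.Dense(1),
)

x = np.ones((4, 32))
k = kernel_fn(x, x, 'nngp')   # NNGP kernel of the (depth + 2)-layer network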
…168). This allows faster / less memory-hungry compilation of very deep networks.
Note that compiled loops require the layer to not change shapes or other static metadata. This necessitates some warnings (see the docstring), and makes it less flexible than `stax.serial`.
Co-authored-by: Jens Glaser <jens.glaser@gmail.com>
PiperOrigin-RevId: 493193125
One caveat that makes this less elegant than we'd like is that `kernel_fn` sometimes makes non-jittable changes to the metadata of the `Kernel` object; when this happens, `lax.scan` fails (see especially the second warning), so unfortunately for now it's less flexible than `stax.serial`.
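As a hedged illustration of that caveat (hypothetical widths, reusing the `repeat` combinator sketched above): if the repeated block changes the kernel's static metadata on its first application, the scan carry no longer matches its input, whereas unrolling with `stax.serial` still works.

from neural_tangents import stax
import jax.numpy as np

block = stax.serial(stax.Dense(512), stax.Relu())
x = np.ones((4, 32))

# Fails inside `lax.scan`: the first pass through `block` changes the kernel's
# shape metadata from feature width 32 to 512, so the carry no longer matches
# the initial kernel. `stax.serial(*([block] * 10))` would still work by
# unrolling.
# _, _, kernel_fn = repeat(block, 10)
# kernel_fn(x, x, 'nngp')

# Works: make the metadata stationary before entering the scanned region.
_, _, kernel_fn = stax.serial(stax.Dense(512), repeat(block, 10))
k = kernel_fn(x, x, 'nngp')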