Generic reduce window jvp
The goal is to generically JVP and transpose over any reduction_fn. JAX already handles some of the hard parts for us, namely ensuring that the user-provided fn is capturable as a jaxpr. All that is left, then, is to write JVP and transpose fns that use the JAX utilities correctly.

However, this is not so straightforward: to transpose a reduction window, we need access to both the tangents and the primals. The current reduce_fn operates on `(x, y)`, but under JVP what we actually need is to operate on `(x_primal, y_primal, x_tangent, y_tangent)`. In turn, this means we need to push down the notion of a JVP-specific reduction_fn, captured via the usual machinery, i.e. `as_fun(jvp_fn(closed(user_reduction_jaxpr)))`.
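
As a rough sketch of that capture pipeline, using the internal helpers this change imports (internal APIs, subject to change; `jnp.add` is a stand-in reducer):

```python
# Sketch: turning a captured binary reducer into a reducer over
# (x_primal, y_primal, x_tangent, y_tangent).
import jax
import jax.numpy as jnp
from jax._src.core import jaxpr_as_fun
from jax._src.interpreters.ad import jvp_jaxpr

closed = jax.make_jaxpr(jnp.add)(0.0, 0.0)                # captured binary reducer (a ClosedJaxpr)
jvp_closed, _ = jvp_jaxpr(closed, (True, True), [False])  # nonzero tangents for both inputs
jvp_fn = jaxpr_as_fun(jvp_closed)

# Arguments are primals first, then tangents; outputs are (out_primal, out_tangent).
print(jvp_fn(1.0, 2.0, 10.0, 20.0))  # [3.0, 30.0]
```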

For the JVP fn, we stack the primal operand and the tangent operand together, and likewise stack their respective initial values. This entails a good deal of changes to safety checks and assurances downstream (as well as unpacking), since the operand shape changes from `[K, ..., Kt]` to `[K, ..., Kt, 2]`, where the last dim holds the stacked primal and tangent values.
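
Conceptually the stacked layout looks like this (an illustration of the shape contract only, not the literal implementation, which threads flat lists of operands):

```python
import jax.numpy as jnp

x  = jnp.ones((4, 5))    # primal operand, shape [K, ..., Kt] = [4, 5]
xt = jnp.zeros((4, 5))   # its tangent, same shape
stacked = jnp.stack([x, xt], axis=-1)
print(stacked.shape)     # (4, 5, 2): the trailing dim holds (primal, tangent)
```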

In follow-up CLs, we will add (1) re-entrant/recursive is_jvp handling and (2) transposition.
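
For context, this is the kind of program the new rule enables (a sketch; the reducer is arbitrary, chosen so it hits the generic path rather than a monoid special case):

```python
import jax
import jax.numpy as jnp
from jax import lax

def windowed_product(x):
  # A user-provided reduction with no special-cased monoid: sliding-window product.
  return lax.reduce_window(
      x, 1.0, lambda a, b: a * b,
      window_dimensions=(2,), window_strides=(1,), padding='VALID')

x = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.]
out, tangent = jax.jvp(windowed_product, (x,), (jnp.ones_like(x),))
print(out)      # [ 2.,  6., 12.]
print(tangent)  # product rule with unit tangents: [3., 5., 7.]
```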

PiperOrigin-RevId: 627085500
jax authors committed May 8, 2024
1 parent 335f27b commit c24ccd1
Showing 3 changed files with 502 additions and 88 deletions.
248 changes: 197 additions & 51 deletions jax/_src/lax/windowed_reductions.py
@@ -19,50 +19,61 @@
from typing import Callable
import warnings

import numpy as np

from jax import tree_util

from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import util
from jax._src.core import ShapedArray, ConcreteArray
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.lax import convolution
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp
from jax._src.typing import Array
import numpy as np
from jax._src.core import ClosedJaxpr
from jax._src.core import jaxpr_as_fun
from jax._src.interpreters.ad import jvp_jaxpr
from jax._src import ad_util

map = util.safe_map
zip = util.safe_zip


def reduce_window(operand, init_value, computation: Callable,
window_dimensions: core.Shape, window_strides: Sequence[int],
padding: str | Sequence[tuple[int, int]],
base_dilation: Sequence[int] | None = None,
window_dilation: Sequence[int] | None = None) -> Array:
def _reduce_window(
operand,
init_value,
computation,
window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: str | Sequence[tuple[int, int]],
base_dilation: Sequence[int] | None = None,
window_dilation: Sequence[int] | None = None,
):
"""Wraps XLA's `ReduceWindowWithGeneralPadding
<https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
operator.
"""
flat_operands, operand_tree = tree_util.tree_flatten(operand)
flat_init_values, init_value_tree = tree_util.tree_flatten(init_value)
if operand_tree != init_value_tree:
raise ValueError('Operands must have the same tree structure as '
f'init_values: {operand_tree} vs. {init_value_tree}')
if len(flat_operands) == 0:
raise ValueError('reduce_window must have at least one operand.')
raise ValueError(
"Operands must have the same tree structure as "
f"init_values: {operand_tree} vs. {init_value_tree}"
)
if len(flat_operands) != len(flat_init_values):
raise ValueError('Must have same total number of operands as init_values: '
f' {len(flat_operands)} vs. {len(flat_init_values)}')
raise ValueError(
"Must have same total number of operands as init_values: "
f" {len(flat_operands)} vs. {len(flat_init_values)}"
)

if len(flat_operands) == 0:
raise ValueError("reduce_window must have at least one operand.")
if isinstance(padding, str):
dilated_window_dims = (
window_dimensions if window_dilation is None else
@@ -82,21 +93,52 @@ def reduce_window(operand, init_value, computation: Callable,
else:
flat_init_avals = map(lax._abstractify, flat_init_values)
jaxpr, out_tree = lax._variadic_reduction_jaxpr(
computation, tuple(flat_init_avals), init_value_tree)
computation, tuple(flat_init_avals), init_value_tree
)
if operand_tree != out_tree:
raise ValueError(
'reduce_window output must have the same tree structure as the operands'
f' {operand_tree} vs. {out_tree}')
out_flat = reduce_window_p.bind(
*flat_operands, *flat_init_values, jaxpr=jaxpr.jaxpr,
consts=tuple(jaxpr.consts), window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=padding,
*flat_operands,
*flat_init_values,
jaxpr=jaxpr.jaxpr,
consts=tuple(jaxpr.consts),
window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides),
padding=padding,
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))
window_dilation=tuple(window_dilation),
)
return tree_util.tree_unflatten(out_tree, out_flat)

def _get_monoid_window_reducer(monoid_op: Callable,
xs: Sequence[Array]) -> Callable | None:


def reduce_window(
operand,
init_value,
computation: Callable,
window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: str | Sequence[tuple[int, int]],
base_dilation: Sequence[int] | None = None,
window_dilation: Sequence[int] | None = None,
) -> Array:
return _reduce_window(
operand,
init_value,
computation,
window_dimensions,
window_strides,
padding,
base_dilation,
window_dilation,
)


def _get_monoid_window_reducer(
monoid_op, xs: Sequence[Array]
) -> Callable | None:
if len(xs) != 1:
return None
x, = xs
@@ -112,6 +154,7 @@ def _get_monoid_window_reducer(monoid_op: Callable,
and _reduce_window_min)
return None


def _reduce_window_sum(operand: Array, window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[tuple[int, int]],
@@ -260,10 +303,19 @@ def _select_and_gather_add(tangents: Array, operand: Array,


def _reduce_window_abstract_eval_rule(
*avals, jaxpr, consts, window_dimensions, window_strides, padding,
base_dilation, window_dilation):
*avals,
jaxpr,
consts,
window_dimensions,
window_strides,
padding,
base_dilation,
window_dilation,
):
operand_avals, init_val_avals = util.split_list(avals, [len(avals) // 2])
if any(o.dtype != iv.dtype for o, iv in zip(operand_avals, init_val_avals)):
if any(
o.dtype != iv.dtype for o, iv in zip(operand_avals, init_val_avals)
):
msg = ("reduce_window got inconsistent dtypes for operands and init_values:"
" got operand dtypes {} and init_value dtypes {}.")
raise TypeError(msg.format([o.dtype for o in operand_avals],
@@ -273,13 +325,28 @@ def _reduce_window_abstract_eval_rule(
"have shapes {}.")
raise TypeError(msg.format([v.shape for v in init_val_avals]))
out_shape = _common_reduce_window_shape_rule(
operand_avals[0], window_dimensions, window_strides, padding,
base_dilation, window_dilation)
operand_avals[0],
window_dimensions,
window_strides,
padding,
base_dilation,
window_dilation,
)
return tuple(ShapedArray(out_shape, op.dtype) for op in operand_avals)


def _generic_reduce_window_batch_rule(
batched_args, batch_dims, *, jaxpr, consts, window_dimensions,
window_strides, padding, base_dilation, window_dilation):
batched_args,
batch_dims,
*,
jaxpr,
consts,
window_dimensions,
window_strides,
padding,
base_dilation,
window_dilation,
):
num_operands = len(batched_args) // 2
operands, init_values = util.split_list(batched_args, [num_operands])
operand_bdims, init_value_bdims = util.split_list(batch_dims, [num_operands])
@@ -306,14 +373,68 @@ def _generic_reduce_window_batch_rule(


reduce_window_p = core.Primitive('reduce_window')


def reduce_window_jvp(
primals,
tangents,
window_dimensions,
window_strides,
padding,
base_dilation,
window_dilation,
jaxpr,
consts,
):

reduction_jaxpr = jaxpr

n = len(primals) // 2 # number of primal operands
operand, init_value = util.split_list(primals, [n])
operand_tangent, init_value_tangent = util.split_list(tangents, [n])
if not all(isinstance(t, ad.Zero) for t in init_value_tangent):
raise TypeError("reduce_window jvp does not support non-zero init_value_tangent.")

init_value_tangent = map(ad_util.instantiate, init_value_tangent)
c_reduction_jaxpr = ClosedJaxpr(reduction_jaxpr, consts)
jvp_reduction = jvp_jaxpr(c_reduction_jaxpr, (True,) * len(tangents), [False] * len(init_value_tangent))[0]

def wrapper(left, right):
pl, tl = util.split_list(left, [n])
pr, tr = util.split_list(right, [n])
return jaxpr_as_fun(jvp_reduction)(*pl, *pr, *tl, *tr)

jvp_primals_tangents = _reduce_window(
operand=[*operand, *operand_tangent],
init_value=[*init_value, *init_value_tangent],
computation=wrapper,
window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding,
base_dilation=base_dilation,
window_dilation=window_dilation,
)
primals, tangents = util.split_list(jvp_primals_tangents, [len(jvp_primals_tangents) // 2])
return [*primals], [*tangents]

ad.primitive_jvps[reduce_window_p] = reduce_window_jvp
reduce_window_p.multiple_results = True
reduce_window_p.def_impl(partial(dispatch.apply_primitive, reduce_window_p))
reduce_window_p.def_abstract_eval(_reduce_window_abstract_eval_rule)
batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule

def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
window_dimensions, window_strides, padding,
base_dilation, window_dilation):

def _generic_reduce_window_lower(
ctx,
*args,
jaxpr,
consts,
window_dimensions,
window_strides,
padding,
base_dilation,
window_dilation,
):
operands, init_values = util.split_list(args, [len(args) // 2])
_, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])

@@ -330,11 +451,15 @@ def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]:
reducer_name="generic_reduce_window_reducer",
reducer_body=reducer_body,
operands=operands,
init_values=init_values, init_values_avals=init_value_avals,
init_values=init_values,
init_values_avals=init_value_avals,
out_avals=ctx.avals_out,
window_dimensions=window_dimensions, window_strides=window_strides,
base_dilation=base_dilation, window_dilation=window_dilation,
padding=padding)
window_dimensions=window_dimensions,
window_strides=window_strides,
base_dilation=base_dilation,
window_dilation=window_dilation,
padding=padding,
)


mlir.register_lowering(reduce_window_p, _generic_reduce_window_lower)
@@ -402,18 +527,25 @@ def _reduce_window_chooser_jvp_rule(prim, g, operand, *, window_dimensions,
window_dilation)


def _common_reduce_window_shape_rule(operand, window_dimensions,
window_strides, padding, base_dilation,
window_dilation):
def _common_reduce_window_shape_rule(
operand,
window_dimensions,
window_strides,
padding,
base_dilation,
window_dilation,
):
lax._check_shapelike("reduce_window", "window_dimensions", window_dimensions,
non_zero_shape=True)
lax._check_shapelike("reduce_window", "window_strides", window_strides,
non_zero_shape=True)
lax._check_shapelike("reduce_window", "base_dilation", base_dilation)
lax._check_shapelike("reduce_window", "window_dilation", window_dilation)
if operand.ndim != len(window_dimensions):
msg = ("reduce_window got the wrong number of window_dimensions for "
"operand: got operand shape {} with window_dimensions {}.")
msg = (
"reduce_window got the wrong number of window_dimensions for "
"operand: got operand shape {} with window_dimensions {}."
)
raise TypeError(msg.format(operand.shape, window_dimensions))
if len(window_strides) != len(window_dimensions):
msg = ("reduce_window got inconsistent window_strides and "
@@ -443,6 +575,7 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
operand_padded = tuple(d + pl + ph for d, (pl, ph) in zip(operand_shape, padding))
return tuple(map(core.stride_dim, operand_padded, window_dimensions, window_strides))


reduce_window_max_p = lax.standard_primitive(
_common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_max')
ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule,
@@ -463,24 +596,36 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,

def _reduce_window_lower(
reduce_op,
init_value, ctx, operand, *,
window_dimensions, window_strides, padding, base_dilation,
window_dilation):
init_value,
ctx,
operand,
*,
window_dimensions,
window_strides,
padding,
base_dilation,
window_dilation,
):

operand_aval, = ctx.avals_in
scalar_aval = operand_aval.update(shape=())

return mlir.reduce_window(ctx,
return mlir.reduce_window(
ctx,
reducer_name=f"reduce_window_{scalar_aval.dtype}_reducer",
reducer_body=lambda reducer: [reduce_op(*reducer.arguments)],
operands=[operand],
init_values=[mlir.full_like_aval(ctx, init_value(scalar_aval.dtype),
scalar_aval)],
init_values=[
mlir.full_like_aval(ctx, init_value(scalar_aval.dtype), scalar_aval)
],
init_values_avals=[scalar_aval],
out_avals=ctx.avals_out,
window_dimensions=window_dimensions,
window_strides=window_strides, base_dilation=base_dilation,
window_dilation=window_dilation, padding=padding)
window_strides=window_strides,
base_dilation=base_dilation,
window_dilation=window_dilation,
padding=padding,
)


mlir.register_lowering(reduce_window_sum_p, partial(
@@ -871,6 +1016,7 @@ def _select_and_gather_add_batching_rule(
_select_and_gather_add_using_variadic_reducewindow,
multiple_results=False))


# TODO(b/183233858): use variadic reducewindow on GPU, when implemented.
mlir.register_lowering(
select_and_gather_add_p,
2 changes: 1 addition & 1 deletion jax/_src/test_util.py
@@ -57,7 +57,7 @@
from jax._src.util import unzip2
from jax._src.public_test_util import ( # noqa: F401
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, tolerance)
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, tolerance, rand_like)
from jax._src import xla_bridge


