Add rewrite for matmul when only one of the inputs has batched dimensions

This rewrite replaces a batched matmul with a core matmul by raveling the batch dimensions of the batched input, performing a 2D matmul, and then unraveling the batch dimensions of the output.

This sidesteps the Blockwise Dot operation and allows specialization into BLAS routines.

The idea was taken from these two discussions:
numpy/numpy#7569
numpy/numpy#8957
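
A minimal NumPy sketch of the trick (not part of the commit; names and shapes are illustrative) for the case where x is batched and y is not:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 7, 3, 2))  # batched input: (*batch, m, k)
y = rng.normal(size=(2, 4))        # core input: (k, n)

# Ravel the batch dimensions of x, do one core 2D matmul, then unravel the output
x2d = x.reshape(-1, x.shape[-1])                     # (5 * 7 * 3, 2)
out = (x2d @ y).reshape(*x.shape[:-1], y.shape[-1])  # (5, 7, 3, 4)

np.testing.assert_allclose(out, x @ y)  # same result as the batched matmul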
ricardoV94 committed Dec 16, 2023
1 parent 60a9566 commit d28b35f
Showing 2 changed files with 102 additions and 0 deletions.
53 changes: 53 additions & 0 deletions pytensor/tensor/rewriting/math.py
@@ -31,11 +31,13 @@
    constant,
    extract_constant,
    get_underlying_scalar_constant_value,
    moveaxis,
    ones_like,
    register_infer_shape,
    switch,
    zeros_like,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
@@ -217,6 +219,57 @@ def local_lift_transpose_through_dot(fgraph, node):
    return ret


@register_stabilize
@register_specialize
@node_rewriter(tracks=[Blockwise])
def local_batched_matmul_to_core_matmul(fgraph, node):
"""Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul.
Example, if x has batch dimensions, but y not:
x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1] + y.shape[-1])
It also works when y has batch dimensions, but x not.
"""

    # Check whether we have a matmul operation in this node
    if not (
        isinstance(node.op.core_op, Dot)
        and len(node.op.inputs_sig[0]) == 2
        and len(node.op.inputs_sig[1]) == 2
    ):
        return None

    x, y = node.inputs
    batch_ndim = node.op.batch_ndim(node)

    # Check if x has batch dimensions, but y not (or only broadcastable dimensions)
    if any(not b_dim for b_dim in x.type.broadcastable[:-2]) and all(
        y.type.broadcastable[:-2]
    ):
        x_stacked = x.reshape((-1, x.shape[-1]))
        out_stacked = x_stacked @ y.squeeze(tuple(range(batch_ndim)))
        out = out_stacked.reshape((*x.shape[:-1], y.shape[-1]))
        return [out]

    # Otherwise, check if y has batch dimensions, but x not
    elif any(not b_dim for b_dim in y.type.broadcastable[:-2]) and all(
        x.type.broadcastable[:-2]
    ):
        # For the y batch case we need to first move the batch axes and then reshape
        # y.shape == (*b, k, n)
        y_tr = moveaxis(y, -2, 0)  # (k, *b, n)
        y_stacked = y_tr.reshape((y.shape[-2], -1))  # (k, *b * n)
        out_stacked = x.squeeze(tuple(range(batch_ndim))) @ y_stacked  # (m, *b * n)
        out_stacked_tr = out_stacked.reshape(
            (x.shape[-2], *y.shape[:-2], y.shape[-1])
        )  # (m, *b, n)
        out = moveaxis(out_stacked_tr, 0, -2)  # (*b, m, n)
        return [out]

    # Both x and y have batch dimensions, nothing to do here
    return None
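
For reference, a standalone NumPy sketch (not part of the diff; shapes are illustrative) that checks the moveaxis/reshape equivalence used in the y-batched branch above:

import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(3, 2))        # core input: (m, k)
y = rng.normal(size=(5, 7, 2, 4))  # batched input: (*b, k, n)

y_tr = np.moveaxis(y, -2, 0)               # (k, *b, n) == (2, 5, 7, 4)
y_stacked = y_tr.reshape(y.shape[-2], -1)  # (k, b*n)   == (2, 140)
out_stacked = x @ y_stacked                # (m, b*n)   == (3, 140)
out_tr = out_stacked.reshape(x.shape[-2], *y.shape[:-2], y.shape[-1])  # (m, *b, n)
out = np.moveaxis(out_tr, 0, -2)           # (*b, m, n) == (5, 7, 3, 4)

np.testing.assert_allclose(out, x @ y)  # same result as the batched matmul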


def is_inverse_pair(node_op, prev_op, inv_pair):
"""
Given two consecutive operations, check if they are the
49 changes: 49 additions & 0 deletions tests/tensor/rewriting/test_math.py
@@ -34,6 +34,7 @@
from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
from pytensor.tensor.math import abs as pt_abs
@@ -4427,3 +4428,51 @@ def test_polygamma_specialization():
    assert isinstance(fn_outs[0].owner.op.scalar_op, Psi)
    assert isinstance(fn_outs[1].owner.op.scalar_op, TriGamma)
    assert isinstance(fn_outs[2].owner.op.scalar_op, PolyGamma)


@pytest.mark.skipif(
    config.mode == "FAST_COMPILE",
    reason="Rewrite is only relevant in FAST_RUN",
)
def test_local_batched_matmul_to_core_matmul():
    rng = np.random.default_rng(seed=4433)

    # x is batched but not y
    x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(2, 2), dtype="float64")
    out = x @ y
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out)
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )

    x_test = rng.normal(size=(5, 3, 2))
    y_test = rng.normal(size=(2, 2))
    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)

    # y is batched but not x
    x = pt.tensor("x", shape=(1, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(5, 2, 2), dtype="float64")
    out = x @ y
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out)
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )

    x_test = rng.normal(size=(1, 3, 2))
    y_test = rng.normal(size=(5, 2, 2))
    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)

    # Both x and y are batched, rewrite does not apply
    x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(5, 2, 2), dtype="float64")
    out = x @ y

    fn = pytensor.function([x, y], out)
    x_test = rng.normal(size=(5, 3, 2))
    y_test = rng.normal(size=(5, 2, 2))
    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
