From 0f6e91165072b393bccd92a5fb0f598271bf9ac1 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Sat, 16 Dec 2023 13:51:20 +0100
Subject: [PATCH] Add rewrite for matmul when only one of the inputs has batched dimensions

This rewrite replaces a batched matmul by a core matmul by raveling the
batched dimensions of the batched input, doing a 2D matmul, and then
unraveling the batched dimensions of the output. This sidesteps the
Blockwise Dot operation and allows specialization into BLAS routines.

The idea was taken from these two discussions:
https://github.com/numpy/numpy/issues/7569
https://github.com/numpy/numpy/issues/8957
---
 pytensor/tensor/rewriting/math.py   | 53 +++++++++++++++++++++++++++++
 tests/tensor/rewriting/test_math.py | 45 ++++++++++++++++++++++++
 2 files changed, 98 insertions(+)

diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 5309abe882..196830e1ae 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -31,11 +31,13 @@
     constant,
     extract_constant,
     get_underlying_scalar_constant_value,
+    moveaxis,
     ones_like,
     register_infer_shape,
     switch,
     zeros_like,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays
@@ -217,6 +219,57 @@ def local_lift_transpose_through_dot(fgraph, node):
     return ret
 
 
+@register_stabilize
+@register_specialize
+@node_rewriter(tracks=[Blockwise])
+def local_batched_matmul_to_core_matmul(fgraph, node):
+    """Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul.
+
+    For example, if x has batch dimensions but y does not:
+        x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
+
+    It also works when y has batch dimensions but x does not.
+ """ + + # Check whether we have a matmul operation in this node + if not ( + isinstance(node.op.core_op, Dot) + and len(node.op.inputs_sig[0]) == 2 + and len(node.op.inputs_sig[1]) == 2 + ): + return None + + x, y = node.inputs + batch_ndim = node.op.batch_ndim(node) + + # Check if x has batch dimensions, but y not (or only broadcastable dimensions) + if any(not b_dim for b_dim in x.type.broadcastable[:-2]) and all( + y.type.broadcastable[:-2] + ): + x_stacked = x.reshape((-1, x.shape[-1])) + out_stacked = x_stacked @ y.squeeze(tuple(range(batch_ndim))) + out = out_stacked.reshape((*x.shape[:-1], y.shape[-1])) + return [out] + + # Otherwise, check if y has batch dimension, but x not + elif any(not b_dim for b_dim in y.type.broadcastable[:-2]) and all( + x.type.broadcastable[:-2] + ): + # For the y batch case we need to first move the batch axes and then reshape + # y.shape == (*b, k, n) + y_tr = moveaxis(y, -2, 0) # (k, *b, n) + y_stacked = y_tr.reshape((y.shape[-2], -1)) # (k, *b * n) + out_stacked = x.squeeze(tuple(range(batch_ndim))) @ y_stacked # (m, *b * n) + out_stacked_tr = out_stacked.reshape( + (x.shape[-2], *y.shape[:-2], y.shape[-1]) + ) # (m, *b, n) + out = moveaxis(out_stacked_tr, 0, -2) # (*b, m, n) + return [out] + + # Both x and y have batch dimensions, nothing to do here + return None + + def is_inverse_pair(node_op, prev_op, inv_pair): """ Given two consecutive operations, check if they are the diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index dc7927db05..df32711bec 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -34,6 +34,7 @@ from pytensor.tensor.basic import Alloc, constant, join, second, switch from pytensor.tensor.blas import Dot22, Gemv from pytensor.tensor.blas_c import CGemv +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj from pytensor.tensor.math import abs as pt_abs @@ -4427,3 +4428,47 @@ def test_polygamma_specialization(): assert isinstance(fn_outs[0].owner.op.scalar_op, Psi) assert isinstance(fn_outs[1].owner.op.scalar_op, TriGamma) assert isinstance(fn_outs[2].owner.op.scalar_op, PolyGamma) + + +def test_local_batched_matmul_to_core_matmul(): + rng = np.random.default_rng(seed=4433) + + # x is batched but not y + x = pt.tensor("x", shape=(None, 3, 2), dtype="float64") + y = pt.tensor("y", shape=(2, 2), dtype="float64") + out = x @ y + assert isinstance(out.owner.op, Blockwise) + + fn = pytensor.function([x, y], out) + assert not any( + isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes + ) + + x_test = rng.normal(size=(5, 3, 2)) + y_test = rng.normal(size=(2, 2)) + np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test) + + # y is batched but not x + x = pt.tensor("x", shape=(1, 3, 2), dtype="float64") + y = pt.tensor("y", shape=(5, 2, 2), dtype="float64") + out = x @ y + assert isinstance(out.owner.op, Blockwise) + + fn = pytensor.function([x, y], out) + assert not any( + isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes + ) + + x_test = rng.normal(size=(1, 3, 2)) + y_test = rng.normal(size=(5, 2, 2)) + np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test) + + # Both x and y are batched, rewrite does not apply + x = pt.tensor("x", shape=(None, 3, 2), dtype="float64") + y = pt.tensor("y", shape=(5, 2, 2), dtype="float64") + out = x @ y + + fn = pytensor.function([x, y], out) 
+    x_test = rng.normal(size=(5, 3, 2))
+    y_test = rng.normal(size=(5, 2, 2))
+    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
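
Note for reviewers (not part of the patch): a minimal standalone NumPy sketch of the reshape trick the rewrite relies on, for both branches. The array names and sizes are made up for illustration.

import numpy as np

rng = np.random.default_rng(0)

# Case 1: x has batch dimensions, y does not.
# (*b, m, k) @ (k, n) == ((prod(b) * m, k) @ (k, n)).reshape(*b, m, n)
x = rng.normal(size=(5, 7, 3, 2))  # batch dims (5, 7), core dims (3, 2)
y = rng.normal(size=(2, 4))
expected = x @ y
out = (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
np.testing.assert_allclose(out, expected)

# Case 2: y has batch dimensions, x does not.
# Move y's contraction axis (k) to the front, ravel the remaining (*b, n)
# axes, do a single 2D matmul, then unravel and move the axes back.
x = rng.normal(size=(3, 2))
y = rng.normal(size=(5, 7, 2, 4))  # batch dims (5, 7), core dims (2, 4)
expected = x @ y
y_stacked = np.moveaxis(y, -2, 0).reshape(y.shape[-2], -1)  # (k, prod(b) * n)
out_stacked = x @ y_stacked                                 # (m, prod(b) * n)
out = np.moveaxis(
    out_stacked.reshape(x.shape[-2], *y.shape[:-2], y.shape[-1]),  # (m, *b, n)
    0,
    -2,
)  # (*b, m, n)
np.testing.assert_allclose(out, expected)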