diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 5309abe882..670fec4211 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -31,11 +31,13 @@
     constant,
     extract_constant,
     get_underlying_scalar_constant_value,
+    moveaxis,
     ones_like,
     register_infer_shape,
     switch,
     zeros_like,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays
@@ -217,6 +219,57 @@ def local_lift_transpose_through_dot(fgraph, node):
     return ret
 
 
+@register_stabilize
+@register_specialize
+@node_rewriter(tracks=[Blockwise])
+def local_batched_matmul_to_core_matmul(fgraph, node):
+    """Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul.
+
+    For example, if x has batch dimensions but y does not:
+    x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
+
+    It also works when y has batch dimensions but x does not.
+    """
+
+    # Check whether we have a matmul operation in this node
+    if not (
+        isinstance(node.op.core_op, Dot)
+        and len(node.op.inputs_sig[0]) == 2
+        and len(node.op.inputs_sig[1]) == 2
+    ):
+        return None
+
+    x, y = node.inputs
+    batch_ndim = node.op.batch_ndim(node)
+
+    # Check if x has batch dimensions but y does not (or only broadcastable ones)
+    if any(not b_dim for b_dim in x.type.broadcastable[:-2]) and all(
+        y.type.broadcastable[:-2]
+    ):
+        x_stacked = x.reshape((-1, x.shape[-1]))
+        out_stacked = x_stacked @ y.squeeze(tuple(range(batch_ndim)))
+        out = out_stacked.reshape((*x.shape[:-1], y.shape[-1]))
+        return [out]
+
+    # Otherwise, check if y has batch dimensions but x does not
+    elif any(not b_dim for b_dim in y.type.broadcastable[:-2]) and all(
+        x.type.broadcastable[:-2]
+    ):
+        # For the batched-y case, first move the contracted axis of y to the front, then reshape
+        # y.shape == (*b, k, n)
+        y_tr = moveaxis(y, -2, 0)  # (k, *b, n)
+        y_stacked = y_tr.reshape((y.shape[-2], -1))  # (k, *b * n)
+        out_stacked = x.squeeze(tuple(range(batch_ndim))) @ y_stacked  # (m, *b * n)
+        out_stacked_tr = out_stacked.reshape(
+            (x.shape[-2], *y.shape[:-2], y.shape[-1])
+        )  # (m, *b, n)
+        out = moveaxis(out_stacked_tr, 0, -2)  # (*b, m, n)
+        return [out]
+
+    # Both x and y have batch dimensions, nothing to do here
+    return None
+
+
 def is_inverse_pair(node_op, prev_op, inv_pair):
     """
     Given two consecutive operations, check if they are the
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index dc7927db05..d23189050a 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -34,6 +34,7 @@
 from pytensor.tensor.basic import Alloc, constant, join, second, switch
 from pytensor.tensor.blas import Dot22, Gemv
 from pytensor.tensor.blas_c import CGemv
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
 from pytensor.tensor.math import abs as pt_abs
@@ -4427,3 +4428,51 @@ def test_polygamma_specialization():
     assert isinstance(fn_outs[0].owner.op.scalar_op, Psi)
     assert isinstance(fn_outs[1].owner.op.scalar_op, TriGamma)
     assert isinstance(fn_outs[2].owner.op.scalar_op, PolyGamma)
+
+
+@pytest.mark.skipif(
+    config.mode == "FAST_COMPILE",
+    reason="Rewrite is only relevant in FAST_RUN",
+)
+def test_local_batched_matmul_to_core_matmul():
+    rng = np.random.default_rng(seed=4433)
+
+    # x is batched but not y
+    x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
+    y = pt.tensor("y", shape=(2, 2), dtype="float64")
+    out = x @ y
+    assert isinstance(out.owner.op, Blockwise)
+
+    fn = pytensor.function([x, y], out)
+    assert not any(
+        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
+    )
+
+    x_test = rng.normal(size=(5, 3, 2))
+    y_test = rng.normal(size=(2, 2))
+    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
+
+    # y is batched but not x
+    x = pt.tensor("x", shape=(1, 3, 2), dtype="float64")
+    y = pt.tensor("y", shape=(5, 2, 2), dtype="float64")
+    out = x @ y
+    assert isinstance(out.owner.op, Blockwise)
+
+    fn = pytensor.function([x, y], out)
+    assert not any(
+        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
+    )
+
+    x_test = rng.normal(size=(1, 3, 2))
+    y_test = rng.normal(size=(5, 2, 2))
+    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
+
+    # Both x and y are batched, rewrite does not apply
+    x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
+    y = pt.tensor("y", shape=(5, 2, 2), dtype="float64")
+    out = x @ y
+
+    fn = pytensor.function([x, y], out)
+    x_test = rng.normal(size=(5, 3, 2))
+    y_test = rng.normal(size=(5, 2, 2))
+    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
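
For reference, a minimal NumPy sketch (not part of the patch) of the two reshape tricks the rewrite relies on. The array names and shapes are illustrative assumptions only; the point is that the stacked core matmul matches the batched matmul numerically.

import numpy as np

rng = np.random.default_rng(0)

# Case 1: x is batched, y is not -- fold the batch dims of x into its rows.
x = rng.normal(size=(5, 4, 3, 2))  # batch shape (5, 4), core shape (3, 2)
y = rng.normal(size=(2, 6))        # core shape (2, 6), no batch dims
stacked = (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
np.testing.assert_allclose(x @ y, stacked)

# Case 2: y is batched, x is not -- move the contracted axis of y to the front,
# fold the batch dims into its columns, then restore the batch layout.
x = rng.normal(size=(3, 2))    # core shape (m, k) = (3, 2)
y = rng.normal(size=(5, 2, 6))  # batch shape (5,), core shape (k, n) = (2, 6)
y_stacked = np.moveaxis(y, -2, 0).reshape(y.shape[-2], -1)               # (k, 5 * n)
out = (x @ y_stacked).reshape(x.shape[-2], *y.shape[:-2], y.shape[-1])   # (m, 5, n)
out = np.moveaxis(out, 0, -2)                                            # (5, m, n)
np.testing.assert_allclose(x @ y, out)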