Implement vectorize_node dispatch for AdvancedSubtensor
ricardoV94 committed May 7, 2024
1 parent 3b54507 commit 33897ac
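
In brief: the commit registers a _vectorize_node dispatch for AdvancedSubtensor. When only the indexed tensor gains batch dimensions, the dispatch reuses AdvancedSubtensor directly by prepending empty slices; when the index variables themselves are batched it falls back to Blockwise (and raises NotImplementedError if slices or newaxis are mixed in). A minimal usage sketch of the non-batched-index path, assuming vectorize and tensor are importable from pytensor.tensor as in the updated test (core_fn, x_test and idx_test are illustrative names):

import numpy as np
from pytensor import function
from pytensor.tensor import tensor, vectorize

def core_fn(x, idx):
    # Core case: integer-array indexing into the middle dimension of a 3D input
    return x[:, idx, :]

# x has one leading batch dimension (11) on top of the (t1, t2, t3) core shape
x = tensor("x", shape=(11, 7, 5, 3))
idx = tensor("idx", shape=(2,), dtype="int64")
out = vectorize(core_fn, signature="(t1,t2,t3),(idx)->(t1,tx,t3)")(x, idx)
fn = function([x, idx], out)

x_test = np.random.normal(size=(11, 7, 5, 3)).astype(x.dtype)
idx_test = np.array([0, 2], dtype="int64")
print(fn(x_test, idx_test).shape)  # expected: (11, 7, 2, 3)

With this change the compiled graph should contain an AdvancedSubtensor node rather than a Blockwise wrapper, which is what the updated test asserts.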
Showing 2 changed files with 57 additions and 28 deletions.
27 changes: 27 additions & 0 deletions pytensor/tensor/subtensor.py
@@ -47,6 +47,7 @@
    zscalar,
)
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice
from pytensor.tensor.variable import TensorVariable


_logger = logging.getLogger("pytensor.tensor.subtensor")
@@ -2686,6 +2687,32 @@ def grad(self, inputs, grads):
advanced_subtensor = AdvancedSubtensor()


@_vectorize_node.register(AdvancedSubtensor)
def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
    x, *idxs = node.inputs
    batch_x, *batch_idxs = batch_inputs

    # If indexes are batched, fall back to Blockwise
    if any(
        batch_idx.type.ndim > idx.type.ndim
        for batch_idx, idx in zip(batch_idxs, idxs)
        if isinstance(batch_idx, TensorVariable)
    ):
        # Blockwise doesn't accept None or Slice types, so we raise an informative error here
        # TODO: Implement these internally, so Blockwise is always a safe fallback
        if any(not isinstance(idx, TensorVariable) for idx in idxs):
            raise NotImplementedError(
                "Vectorized AdvancedSubtensor combining batched indexes and slices or newaxis is not supported."
            )
        else:
            return vectorize_node_fallback(op, node, batch_x, *batch_idxs)

    # Otherwise we just need to add None slices for every new batch dim
    x_batch_ndim = batch_x.type.ndim - x.type.ndim
    empty_slices = (slice(None),) * x_batch_ndim
    return op.make_node(batch_x, *empty_slices, *batch_idxs)


class AdvancedIncSubtensor(Op):
"""Increments a subtensor using advanced indexing."""

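The core of the new dispatch is the observation that, when only the indexed tensor is batched, prepending one empty slice per new batch dimension reproduces the core advanced indexing on every batch entry. A plain-NumPy sketch of that equivalence (not part of the commit; shapes and names are illustrative):

import numpy as np

core_x = np.arange(7 * 5 * 3).reshape(7, 5, 3)
idx = np.array([0, 2])
core_out = core_x[:, idx, :]  # core op, shape (7, 2, 3)

# Add one leading batch dimension, mirroring batch_x in the dispatch
batch_x = np.stack([core_x + i for i in range(11)])  # shape (11, 7, 5, 3)
batch_ndim = batch_x.ndim - core_x.ndim
empty_slices = (slice(None),) * batch_ndim

# Same indices as the core op, with empty slices prepended for the batch dims
batch_out = batch_x[(*empty_slices, slice(None), idx, slice(None))]  # shape (11, 7, 2, 3)

assert np.array_equal(batch_out[0], core_out)
assert all(np.array_equal(batch_out[i], (core_x + i)[:, idx, :]) for i in range(11))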
58 changes: 30 additions & 28 deletions tests/tensor/test_subtensor.py
@@ -2713,41 +2713,43 @@ def test_static_shapes(x_shape, indices, expected):
     assert y.type.shape == expected


-def test_vectorize_subtensor_without_batch_indices():
-    signature = "(t1,t2,t3),()->(t1,t3)"
+@pytest.mark.parametrize(
+    "adv_idx, batch_idx",
+    [
+        (False, False),
+        (False, True),
+        (True, False),
+        # Advanced indexing with batched indexes fails when there are Slices / None
+        pytest.param(True, True, marks=pytest.mark.xfail(raises=NotImplementedError)),
+    ],
+)
+@config.change_flags(cxx="")  # C code not needed
+def test_vectorize_subtensor(adv_idx, batch_idx):
+    if adv_idx:
+        signature = "(t1,t2,t3),(idx)->(t1,tx,t3)"
+    else:
+        signature = "(t1,t2,t3),()->(t1,t3)"

-    def core_fn(x, start):
-        return x[:, start, :]
+    def core_fn(x, idx):
+        return x[:, idx, :]

     x = tensor(shape=(11, 7, 5, 3))
-    start = tensor(shape=(), dtype="int")
-    vectorize_pt = function(
-        [x, start], vectorize(core_fn, signature=signature)(x, start)
-    )
-    assert not any(
-        isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
-    )
-    x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
-    start_test = np.random.randint(0, x.type.shape[-2])
-    vectorize_np = np.vectorize(core_fn, signature=signature)
-    np.testing.assert_allclose(
-        vectorize_pt(x_test, start_test),
-        vectorize_np(x_test, start_test),
-    )
+    if batch_idx:
+        idx = tensor(shape=(11, 2) if adv_idx else (11,), dtype="int")
+    else:
+        idx = tensor(shape=(2,) if adv_idx else (), dtype="int")
+    vectorize_pt = function([x, idx], vectorize(core_fn, signature=signature)(x, idx))

-    # If we vectorize start, we should get a Blockwise that still works
-    x = tensor(shape=(11, 7, 5, 3))
-    start = tensor(shape=(11,), dtype="int")
-    vectorize_pt = function(
-        [x, start], vectorize(core_fn, signature=signature)(x, start)
-    )
-    assert any(
+    needs_blockwise = batch_idx
+    has_blockwise = any(
         isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
     )
+    assert has_blockwise == needs_blockwise

     x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
-    start_test = np.random.randint(0, x.type.shape[-2], size=start.type.shape[0])
+    idx_test = np.random.randint(0, x.type.shape[-2], size=idx.type.shape)
     vectorize_np = np.vectorize(core_fn, signature=signature)
     np.testing.assert_allclose(
-        vectorize_pt(x_test, start_test),
-        vectorize_np(x_test, start_test),
+        vectorize_pt(x_test, idx_test),
+        vectorize_np(x_test, idx_test),
     )
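When the index itself carries a batch dimension (the xfail / Blockwise cases above), the empty-slice trick no longer applies: NumPy-style advanced indexing broadcasts a batched index against the sliced dimensions instead of mapping the core operation over the batch, which is why the dispatch falls back to Blockwise there. An illustrative NumPy-only sketch (illustrative shapes; not part of the test suite):

import numpy as np

core_x = np.arange(7 * 5 * 3).reshape(7, 5, 3)
batch_idx = np.stack([np.array([0, 2]) + i for i in range(3)])  # shape (3, 2): one index row per batch entry

# Desired vectorized semantics: apply the core op once per batch row of the index
expected = np.stack([core_x[:, row, :] for row in batch_idx])  # shape (3, 7, 2, 3)

# Naive advanced indexing with the batched index broadcasts instead
naive = core_x[:, batch_idx, :]  # shape (7, 3, 2, 3)

assert naive.shape != expected.shape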
