ENH: Improve performance of np.broadcast_arrays and np.broadcast_shapes #26160

Open · wants to merge 6 commits into base: main
13 changes: 6 additions & 7 deletions numpy/lib/_stride_tricks_impl.py
@@ -478,7 +478,7 @@ def broadcast_shapes(*args):
     >>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7))
     (5, 6, 7)
     """
-    arrays = [np.empty(x, dtype=[]) for x in args]
+    arrays = [np.empty(x, dtype=bool) for x in args]
     return _broadcast_shape(*arrays)
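
For readers who want to gauge the effect of this change locally, a rough micro-benchmark along the following lines can be used. This snippet is illustrative only and not part of the PR; the shapes are arbitrary and absolute timings depend on the machine and NumPy build.

    # Rough micro-benchmark for np.broadcast_shapes (illustrative only; the
    # shapes are arbitrary and absolute timings vary by machine and build).
    import timeit
    import numpy as np

    shapes = [(6, 7), (5, 6, 1), (7,), (5, 1, 7)]
    t = timeit.timeit(lambda: np.broadcast_shapes(*shapes), number=100_000)
    print(f"{t / 100_000 * 1e6:.2f} µs per call")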


@@ -546,13 +546,12 @@ def broadcast_arrays(*args, subok=False):
     # return np.nditer(args, flags=['multi_index', 'zerosize_ok'],
     #                  order='C').itviews
 
-    args = tuple(np.array(_m, copy=None, subok=subok) for _m in args)
+    args = [np.array(_m, copy=None, subok=subok) for _m in args]
Contributor: Does this really help? I thought that these days list comprehension and creating a tuple via an iterator made very little speed difference, and *args should be very slightly faster for a tuple.

Contributor Author: It does help, although you are right that the cost of list comprehensions has gone down. On Python 3.12:

In [16]: import dis
    ...: args=[1, 2, 3]
    ...: %timeit tuple(2*j for j in args)
    ...: %timeit tuple([2*j for j  in args])
298 ns ± 22.1 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
108 ns ± 0.781 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

So about 200 ns for a small input, which does matter in the benchmarks.

Contributor: Yes, I confirm this. It is funny because I really thought Python had solved the speed difference, but clearly I was wrong.
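
For anyone reproducing the comparison outside IPython, here is a minimal sketch using the standard-library timeit module (illustrative only; absolute numbers vary by machine and Python version):

    # Rough reproduction of the comparison above without IPython magics;
    # absolute numbers depend on the machine and Python version.
    import timeit

    args = [1, 2, 3]
    gen = timeit.timeit(lambda: tuple(2 * j for j in args), number=1_000_000)
    lst = timeit.timeit(lambda: tuple([2 * j for j in args]), number=1_000_000)
    print(f"generator -> tuple:          {gen * 1000:.0f} ns per call")
    print(f"list comprehension -> tuple: {lst * 1000:.0f} ns per call")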


     shape = _broadcast_shape(*args)
 
-    if all(array.shape == shape for array in args):
-        # Common case where nothing needs to be broadcasted.
-        return args

Contributor: This is a nice catch.

+    result = [array if array.shape == shape
+              else _broadcast_to(array, shape, subok=subok, readonly=False)
+              for array in args]
+    return tuple(result)

Contributor: Similar to the above, I wonder if it is actually slower to directly return tuple(array if array.shape ...)?

-    return tuple(_broadcast_to(array, shape, subok=subok, readonly=False)
-                 for array in args)
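
As an illustration of the behaviour this rewrite aims for (a sketch based on the branch as shown above, not part of the PR itself): inputs that already have the broadcast shape are passed through unchanged, and only the remaining inputs are expanded.

    # Minimal illustration of the behaviour proposed in this branch: arrays that
    # already have the broadcast shape are returned as-is, only the others are
    # expanded. (On released NumPy the identity check may print False instead.)
    import numpy as np

    a = np.arange(12).reshape(3, 4)   # already has the broadcast shape (3, 4)
    b = np.arange(4)                  # shape (4,), needs broadcasting

    out_a, out_b = np.broadcast_arrays(a, b)
    print(out_a is a)                 # True with the fast path in this branch
    print(out_a.shape, out_b.shape)   # (3, 4) (3, 4)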
20 changes: 11 additions & 9 deletions numpy/lib/tests/test_stride_tricks.py
@@ -341,7 +341,7 @@ def test_broadcast_shapes_raises():
         [(2, 3), (2,)],
         [(3,), (3,), (4,)],
         [(1, 3, 4), (2, 3, 3)],
-        [(1, 2), (3,1), (3,2), (10, 5)],
+        [(1, 2), (3, 1), (3, 2), (10, 5)],
         [2, (2, 3)],
     ]
     for input_shapes in data:
@@ -578,11 +578,12 @@ def test_writeable():

     # but the result of broadcast_arrays needs to be writeable, to
     # preserve backwards compatibility
-    for is_broadcast, results in [(False, broadcast_arrays(original,)),
-                                  (True, broadcast_arrays(0, original))]:
-        for result in results:
+    test_cases = [((False,), broadcast_arrays(original,)),
+                  ((True, False), broadcast_arrays(0, original))]
+    for is_broadcast, results in test_cases:
+        for array_is_broadcast, result in zip(is_broadcast, results):
             # This will change to False in a future version
-            if is_broadcast:
+            if array_is_broadcast:
                 with assert_warns(FutureWarning):
                     assert_equal(result.flags.writeable, True)
                 with assert_warns(DeprecationWarning):
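
A minimal sketch of the pairing introduced above, with made-up placeholder values instead of the real test arrays (illustrative only): each returned array is now matched with its own expected "was broadcast" flag rather than one flag per call.

    # Made-up placeholder values purely to illustrate the zip-based pairing
    # used above; the real test zips flags with broadcast_arrays results.
    test_cases = [((False,), ("original",)),
                  ((True, False), ("broadcast view", "original"))]
    for is_broadcast, results in test_cases:
        for array_is_broadcast, result in zip(is_broadcast, results):
            print(array_is_broadcast, result)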
@@ -623,11 +624,12 @@ def test_writeable_memoryview():
     # See gh-13929.
     original = np.array([1, 2, 3])
 
-    for is_broadcast, results in [(False, broadcast_arrays(original,)),
-                                  (True, broadcast_arrays(0, original))]:
-        for result in results:
+    test_cases = [((False, ), broadcast_arrays(original,)),
+                  ((True, False), broadcast_arrays(0, original))]
+    for is_broadcast, results in test_cases:
+        for array_is_broadcast, result in zip(is_broadcast, results):
             # This will change to False in a future version
-            if is_broadcast:
+            if array_is_broadcast:
                 # memoryview(result, writable=True) will give warning but cannot
                 # be tested using the python API.
                 assert memoryview(result).readonly