Skip to content

Commit

Permalink
Merge pull request #16346 from neutrinoceros/depr/deprecate_check_bro…
Browse files Browse the repository at this point in the history
…adcast

DEPR: deprecate `astropy.utils.check_broadcast`
  • Loading branch information
mhvk committed Apr 29, 2024
2 parents ef98d46 + 796c9c2 commit 0c06ac3
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 72 deletions.
16 changes: 9 additions & 7 deletions astropy/coordinates/baseframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import numpy as np

from astropy import units as u
from astropy.utils import ShapedLikeNDArray, check_broadcast
from astropy.utils import ShapedLikeNDArray
from astropy.utils.decorators import deprecated, format_doc, lazyproperty
from astropy.utils.exceptions import AstropyWarning
from astropy.utils.exceptions import AstropyWarning, _add_note_to_exception

from . import representation as r
from .angles import Angle, position_angle
Expand Down Expand Up @@ -344,11 +344,13 @@ def __init__(

# Determine the overall shape of the frame.
try:
self._shape = check_broadcast(*shapes)
except ValueError as err:
raise ValueError(
f"non-scalar data and/or attributes with inconsistent shapes: {shapes}"
) from err
self._shape = np.broadcast_shapes(*shapes)
except ValueError as exc:
_add_note_to_exception(
exc,
f"non-scalar data and/or attributes with inconsistent shapes: {shapes}",
)
raise exc

# Broadcast the data if necessary and set it
if data is not None and data.shape != self._shape:
Expand Down
16 changes: 14 additions & 2 deletions astropy/coordinates/tests/test_frames.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst

import re
import sys
from copy import deepcopy

import numpy as np
Expand Down Expand Up @@ -38,6 +39,7 @@
REPRESENTATION_CLASSES,
CartesianDifferential,
)
from astropy.tests.helper import PYTEST_LT_8_0
from astropy.tests.helper import assert_quantity_allclose as assert_allclose
from astropy.time import Time
from astropy.units import allclose
Expand Down Expand Up @@ -254,12 +256,22 @@ def test_no_data_nonscalar_frames():
assert a1.obstime.shape == (3, 10)
assert a1.temperature.shape == (3, 10)
assert a1.shape == (3, 10)
with pytest.raises(ValueError) as exc:

if sys.version_info >= (3, 11) and PYTEST_LT_8_0:
# Exception.__notes__ are available (and used here) but ignored in matching,
# so we'll match manually and post-mortem instead
match = None
else:
match = r".*inconsistent shapes.*"

with pytest.raises(ValueError, match=match) as exc:
AltAz(
obstime=Time("2012-01-01") + np.arange(10.0) * u.day,
temperature=np.ones((3,)) * u.deg_C,
)
assert "inconsistent shapes" in str(exc.value)

if match is None:
assert "inconsistent shapes" in "\n".join(exc.value.__notes__)


def test_frame_repr():
Expand Down
101 changes: 73 additions & 28 deletions astropy/modeling/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import functools
import inspect
import operator
import re
import warnings
from collections import defaultdict, deque
from inspect import signature
from textwrap import indent
Expand All @@ -30,15 +32,14 @@
from astropy.units import Quantity, UnitsError, dimensionless_unscaled
from astropy.units.utils import quantity_asanyarray
from astropy.utils import (
IncompatibleShapeError,
check_broadcast,
find_current_module,
isiterable,
metadata,
sharedmethod,
)
from astropy.utils.codegen import make_function_with_signature
from astropy.utils.compat import COPY_IF_NEEDED
from astropy.utils.exceptions import _add_note_to_exception

from .bounding_box import CompoundBoundingBox, ModelBoundingBox
from .parameters import InputParameterError, Parameter, _tofloat, param_repr_oneline
Expand Down Expand Up @@ -1063,11 +1064,12 @@ def _validate_input_shapes(self, inputs, argnames, model_set_axis):
)

try:
input_shape = check_broadcast(*all_shapes)
except IncompatibleShapeError as e:
raise ValueError(
"All inputs must have identical shapes or must be scalars."
) from e
input_shape = np.broadcast_shapes(*all_shapes)
except ValueError as exc:
_add_note_to_exception(
exc, "All inputs must have identical shapes or must be scalars."
)
raise exc

return input_shape

Expand Down Expand Up @@ -1949,15 +1951,17 @@ def _prepare_inputs_single_model(self, params, inputs, **kwargs):
for param in params:
try:
if self.standard_broadcasting:
broadcast = check_broadcast(input_shape, param.shape)
broadcast = np.broadcast_shapes(input_shape, param.shape)
else:
broadcast = input_shape
except IncompatibleShapeError:
raise ValueError(
except ValueError as exc:
_add_note_to_exception(
exc,
f"self input argument {self.inputs[idx]!r} of shape"
f" {input_shape!r} cannot be broadcast with parameter"
f" {param.name!r} of shape {param.shape!r}."
f" {param.name!r} of shape {param.shape!r}.",
)
raise exc

if len(broadcast) > len(max_broadcast):
max_broadcast = broadcast
Expand Down Expand Up @@ -2012,17 +2016,19 @@ def _prepare_inputs_model_set(self, params, inputs, model_set_axis_input, **kwar

for param in params:
try:
check_broadcast(
np.broadcast_shapes(
input_shape,
self._remove_axes_from_shape(param.shape, model_set_axis_param),
)
except IncompatibleShapeError:
raise ValueError(
except ValueError as exc:
_add_note_to_exception(
exc,
f"Model input argument {self.inputs[idx]!r} of shape"
f" {input_shape!r} "
f"cannot be broadcast with parameter {param.name!r} of shape "
f"{self._remove_axes_from_shape(param.shape, model_set_axis_param)!r}."
f"{self._remove_axes_from_shape(param.shape, model_set_axis_param)!r}.",
)
raise exc

if len(param.shape) - 1 > len(max_param_shape):
max_param_shape = self._remove_axes_from_shape(
Expand Down Expand Up @@ -2227,11 +2233,23 @@ def _prepare_output_single_model(output, broadcast_shape):

def _prepare_outputs_single_model(self, outputs, broadcasted_shapes):
outputs = list(outputs)
shapes = broadcasted_shapes[0]
for idx, output in enumerate(outputs):
try:
broadcast_shape = check_broadcast(*broadcasted_shapes[0])
except (IndexError, TypeError):
broadcast_shape = broadcasted_shapes[0][idx]
if None in shapes:
# Previously, we used our own function (check_broadcast) instead
# of np.broadcast_shapes in the following try block
# - check_broadcast raised an exception when passed a None.
# - as of numpy 1.26, np.broadcast raises a deprecation warning
# when passed a `None` value, but returns an empty tuple.
#
# Since () and None have different effects downstream of this function,
# and to preserve backward-compatibility, we handle this special here
broadcast_shape = shapes[idx]
else:
try:
broadcast_shape = np.broadcast_shapes(*shapes)
except Exception:
broadcast_shape = shapes[idx]

outputs[idx] = self._prepare_output_single_model(output, broadcast_shape)

Expand Down Expand Up @@ -2742,18 +2760,45 @@ def _check_param_broadcast(self, max_ndim):

# Now check mutual broadcastability of all shapes
try:
check_broadcast(*all_shapes)
except IncompatibleShapeError as exc:
shape_a, shape_a_idx, shape_b, shape_b_idx = exc.args
param_a = self.param_names[shape_a_idx]
param_b = self.param_names[shape_b_idx]

raise InputParameterError(
f"Parameter {param_a!r} of shape {shape_a!r} cannot be broadcast with "
f"parameter {param_b!r} of shape {shape_b!r}. All parameter arrays "
np.broadcast_shapes(*all_shapes)
except ValueError as exc:
# In a previous version, we used to have our own version of
# np.broadcast_shapes (check_broadcast). In order to preserve
# backward compatibility, we now have to go the extra mile and
# parse an error message controlled by numpy.
base_message = (
"All parameter arrays "
"must have shapes that are mutually compatible according "
"to the broadcasting rules."
)
broadcast_shapes_error_re = re.compile(
r"shape mismatch: objects cannot be broadcast to a single shape\. "
r"Mismatch is between "
r"arg (?P<argno_a>\d+) with shape (?P<shape_a>\((\d+(, ?)?)+\)) and "
r"arg (?P<argno_b>\d+) with shape (?P<shape_b>\((\d+(, ?)?)+\))\."
)
if (match := broadcast_shapes_error_re.fullmatch(str(exc))) is not None:
shape_a = match.group("shape_a")
shape_b = match.group("shape_b")
shape_a_idx = int(match.group("argno_a"))
shape_b_idx = int(match.group("argno_b"))
param_a = self.param_names[shape_a_idx]
param_b = self.param_names[shape_b_idx]
message = (
f"Parameter {param_a!r} of shape {shape_a} cannot be broadcast with "
f"parameter {param_b!r} of shape {shape_b}."
)
else:
warnings.warn(
"Failed to parse error message from np.broadcast_shapes. "
"Please report this at "
"https://github.com/astropy/astropy/issues/new/choose",
category=RuntimeWarning,
stacklevel=1,
)
message = "Some parameters failed to broadcast with each other."

raise InputParameterError(f"{message} {base_message}") from None

def _param_sets(self, raw=False, units=False):
"""
Expand Down
3 changes: 1 addition & 2 deletions astropy/modeling/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import numpy as np

from astropy.utils import check_broadcast
from astropy.utils.compat import COPY_IF_NEEDED

from .core import FittableModel, Model
Expand Down Expand Up @@ -1189,7 +1188,7 @@ def evaluate(self, x, y, *coeffs):
# still as expected by the broadcasting rules, even though the x and y
# inputs are not used in the evaluation
if self.degree == 0:
output_shape = check_broadcast(np.shape(coeffs[0]), x.shape)
output_shape = np.broadcast_shapes(np.shape(coeffs[0]), x.shape)
if output_shape:
new_result = np.empty(output_shape)
new_result[:] = result
Expand Down
36 changes: 6 additions & 30 deletions astropy/modeling/tests/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,11 +436,7 @@ def test_1d_array_parameters_1d_array_input(self):
assert np.shape(y2) == (2, 2)
assert np.all(y2 == [[111, 122], [211, 222]])

MESSAGE = (
r"self input argument 'x' of shape .* cannot be broadcast with parameter"
r" 'p1' of shape .*"
)
with pytest.raises(ValueError, match=MESSAGE):
with pytest.raises(ValueError, match="broadcast"):
# Doesn't broadcast
t([100, 200, 300])

Expand All @@ -466,11 +462,7 @@ def test_2d_array_parameters_2d_array_input(self):
]
)

MESSAGE = (
r"self input argument .* of shape .* cannot be broadcast with parameter .*"
r" of shape .*"
)
with pytest.raises(ValueError, match=MESSAGE):
with pytest.raises(ValueError, match="broadcast"):
# Doesn't broadcast
t([[100, 200, 300], [400, 500, 600]])

Expand Down Expand Up @@ -654,11 +646,7 @@ def test_1d_array_parameters_1d_array_input(self):
assert np.shape(y1) == (2, 3)
assert np.all(y1 == [[111, 122, 133], [244, 255, 266]])

MESSAGE = (
r"Model input argument .* of shape .* cannot be broadcast with parameter .*"
r" of shape .*"
)
with pytest.raises(ValueError, match=MESSAGE):
with pytest.raises(ValueError, match="broadcast"):
# Doesn't broadcast with the shape of the parameters, (3,)
y2 = t([100, 200], model_set_axis=False)

Expand All @@ -682,11 +670,7 @@ def test_2d_array_parameters_2d_array_input(self):
]
)

MESSAGE = (
r"Model input argument .* of shape .* cannot be broadcast with parameter .*"
r" of shape .*"
)
with pytest.raises(ValueError, match=MESSAGE):
with pytest.raises(ValueError, match="broadcast"):
y2 = t([[100, 200, 300], [400, 500, 600]])

y2 = t([[[100, 200], [300, 400]], [[500, 600], [700, 800]]])
Expand Down Expand Up @@ -845,11 +829,7 @@ def test_1d_array_parameters_1d_array_input(self):
assert np.all(y2 == [[111, 122], [211, 222]])
assert np.all(z2 == [[1111, 2122], [1211, 2222]])

MESSAGE = (
r"self input argument .* of shape .* cannot be broadcast with parameter .*"
r" of shape .*"
)
with pytest.raises(ValueError, match=MESSAGE):
with pytest.raises(ValueError, match="broadcast"):
# Doesn't broadcast
y3, z3 = t([100, 200, 300])

Expand Down Expand Up @@ -885,11 +865,7 @@ def test_2d_array_parameters_2d_array_input(self):
]
)

MESSAGE = (
r"self input argument .* of shape .* cannot be broadcast with parameter .*"
r" of shape .*"
)
with pytest.raises(ValueError, match=MESSAGE):
with pytest.raises(ValueError, match="broadcast"):
# Doesn't broadcast
y3, z3 = t([[100, 200, 300], [400, 500, 600]])

Expand Down
4 changes: 1 addition & 3 deletions astropy/modeling/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,7 @@ def test_inconsistent_input_shapes():
g = Gaussian2D()
x = np.arange(-1.0, 1, 0.2)
y = np.arange(-1.0, 1, 0.1)
with pytest.raises(
ValueError, match="All inputs must have identical shapes or must be scalars"
):
with pytest.raises(ValueError, match="broadcast"):
g(x, y)


Expand Down
11 changes: 11 additions & 0 deletions astropy/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,17 @@ def __repr__(self):
NoValue = _NoValue()


def _add_note_to_exception(exc: Exception, note: str) -> None:
import sys

if sys.version_info >= (3, 11):
exc.add_note(note)
else:
# mimic Python 3.11 behavior:
# preserve error message and traceback
exc.args += ("\n", note)


def __getattr__(name: str):
if name in ("ErfaError", "ErfaWarning"):
import warnings
Expand Down
2 changes: 2 additions & 0 deletions astropy/utils/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np

from astropy.utils.compat import NUMPY_LT_2_0
from astropy.utils.decorators import deprecated

if NUMPY_LT_2_0:
import numpy.core as np_core
Expand Down Expand Up @@ -356,6 +357,7 @@ def __init__(
super().__init__(shape_a, shape_a_idx, shape_b, shape_b_idx)


@deprecated("7.0", alternative="np.broadcast_shapes")
def check_broadcast(*shapes: tuple[int, ...]) -> tuple[int, ...]:
"""
Determines whether two or more Numpy arrays can be broadcast with each
Expand Down

0 comments on commit 0c06ac3

Please sign in to comment.