Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't return Quantities from high_level_objects_to_values #16287

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 24 additions & 1 deletion astropy/wcs/wcsapi/high_level_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import numbers
from collections import OrderedDict, defaultdict

import numpy as np
Expand Down Expand Up @@ -124,7 +125,8 @@

This function uses the information in ``wcs.world_axis_object_classes`` and
``wcs.world_axis_object_components`` to convert the high level objects
(such as `~.SkyCoord`) to low level "values" `~.Quantity` objects.
(such as `~.SkyCoord`) to low level "values" which should be scalars or
Numpy arrays.

This is used in `.HighLevelWCSMixin.world_to_pixel`, but provided as a
separate function for use in other places where needed.
Expand Down Expand Up @@ -240,6 +242,17 @@
else:
world.append(rec_getattr(objects[key], attr))

# Check the type of the return values - should be scalars or plain Numpy
# arrays, not e.g. Quantity. Note that we deliberately use type(w) because
# we don't want to match Numpy subclasses.
for w in world:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this have performance impact? 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I very much doubt it, it's only looping over a couple (or maybe up to 5) dimensions - not over each individual value of a large array

if not isinstance(w, numbers.Number) and not type(w) == np.ndarray:
raise TypeError(

Check warning on line 250 in astropy/wcs/wcsapi/high_level_api.py

View check run for this annotation

Codecov / codecov/patch

astropy/wcs/wcsapi/high_level_api.py#L248-L250

Added lines #L248 - L250 were not covered by tests
f"WCS world_axis_object_components results in "
f"values which are not scalars or plain Numpy "
f"arrays (got {type(w)})"
)

return world


Expand All @@ -262,6 +275,16 @@
low_level_wcs: `.BaseLowLevelWCS`
The WCS object to use to interpret the coordinates.
"""
# Check the type of the input values - should be scalars or plain Numpy
# arrays, not e.g. Quantity. Note that we deliberately use type(w) because
# we don't want to match Numpy subclasses.
for w in world_values:
if not isinstance(w, numbers.Number) and not type(w) == np.ndarray:
raise TypeError(

Check warning on line 283 in astropy/wcs/wcsapi/high_level_api.py

View check run for this annotation

Codecov / codecov/patch

astropy/wcs/wcsapi/high_level_api.py#L281-L283

Added lines #L281 - L283 were not covered by tests
f"Expected world coordinates as scalars or plain Numpy "
f"arrays (got {type(w)})"
)

# Cache the classes and components since this may be expensive
components = low_level_wcs.world_axis_object_components
classes = low_level_wcs.world_axis_object_classes
Expand Down
53 changes: 53 additions & 0 deletions astropy/wcs/wcsapi/tests/test_high_level_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import re

import numpy as np
import pytest
from numpy.testing import assert_allclose

from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.units import Quantity
from astropy.wcs import WCS
Expand Down Expand Up @@ -205,6 +209,55 @@ def test_values_to_objects():
assert c2.b == c2_out.b


class InvalidWCSQuantity(SkyCoordDuplicateWCS):
"""
WCS which defines ``world_axis_object_components`` which returns Quantity
instead of bare Numpy arrays, which can cause issues. This is for a
regression test to make sure that we don't return Quantities from
``world_axis_object_components``.
"""

@property
def world_axis_object_components(self):
return [
("test1", "ra", "spherical.lon"),
("test1", "dec", "spherical.lat"),
("test2", 0, "spherical.lon"),
("test2", 1, "spherical.lat"),
]


def test_objects_to_values_invalid_type():
wcs = InvalidWCSQuantity()
c1, c2 = wcs.pixel_to_world(1, 2, 3, 4)
with pytest.raises(
TypeError,
match=(
re.escape(
"WCS world_axis_object_components results in values which are not "
"scalars or plain Numpy arrays (got <class "
"'astropy.coordinates.angles.core.Longitude'>)"
)
),
):
high_level_objects_to_values(c1, c2, low_level_wcs=wcs)


def test_values_to_objects_invalid_type():
wcs = SkyCoordDuplicateWCS()
c1, c2 = wcs.pixel_to_world(1, 2, 3, 4)
with pytest.raises(
TypeError,
match=(
re.escape(
"Expected world coordinates as scalars or plain Numpy arrays (got "
"<class 'astropy.units.quantity.Quantity'>)"
)
),
):
values_to_high_level_objects(2 * u.m, 4, 6, 8, low_level_wcs=wcs)


class MinimalHighLevelWCS(HighLevelWCSMixin):
def __init__(self, low_level_wcs):
self._low_level_wcs = low_level_wcs
Expand Down
3 changes: 3 additions & 0 deletions docs/changes/wcs/16287.api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Errors may now occur if a ``BaseLowLevelWCS`` class defines
``world_axis_object_components`` which returns values that are not scalars or
plain Numpy arrays as per APE 14.