Skip to content

Commit

Permalink
Merge pull request #349 from keflavich/high_level_wcs
Browse files Browse the repository at this point in the history
  • Loading branch information
astrofrog committed Mar 13, 2023
2 parents 2145448 + fcd811e commit 3bb79cd
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 5 deletions.
15 changes: 11 additions & 4 deletions reproject/interpolation/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,8 +672,9 @@ def test_broadcast_reprojection(input_extra_dims, output_shape, input_as_wcs, ou
@pytest.mark.parametrize("input_extra_dims", (1, 2))
@pytest.mark.parametrize("output_shape", (None, "single", "full"))
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("header_or_wcs", (lambda x: x, WCS))
@pytest.mark.filterwarnings("ignore::astropy.wcs.wcs.FITSFixedWarning")
def test_blocked_broadcast_reprojection(input_extra_dims, output_shape, parallel):
def test_blocked_broadcast_reprojection(input_extra_dims, output_shape, parallel, header_or_wcs):
image_stack, array_ref, footprint_ref, header_in, header_out = _setup_for_broadcast_test()
# Test both single and multiple dimensions being broadcast
if input_extra_dims == 2:
Expand All @@ -689,6 +690,9 @@ def test_blocked_broadcast_reprojection(input_extra_dims, output_shape, parallel
# Provide the broadcast dimensions as part of the output shape
output_shape = image_stack.shape

# test different behavior when the output projection is a WCS
header_out = header_or_wcs(header_out)

array_broadcast, footprint_broadcast = reproject_interp(
(image_stack, header_in), header_out, output_shape, parallel=parallel, block_size=[5, 5]
)
Expand All @@ -701,9 +705,12 @@ def test_blocked_broadcast_reprojection(input_extra_dims, output_shape, parallel
@pytest.mark.parametrize("block_size", [[500, 500], [500, 100], None])
@pytest.mark.parametrize("return_footprint", [False, True])
@pytest.mark.parametrize("existing_outputs", [False, True])
@pytest.mark.parametrize("header_or_wcs", (lambda x: x, WCS))
@pytest.mark.remote_data
@pytest.mark.filterwarnings("ignore::astropy.wcs.wcs.FITSFixedWarning")
def test_blocked_against_single(parallel, block_size, return_footprint, existing_outputs):
def test_blocked_against_single(
parallel, block_size, return_footprint, existing_outputs, header_or_wcs
):
# Ensure when we break a reprojection down into multiple discrete blocks
# it has the same result as if all pixels where reprejcted at once

Expand All @@ -727,7 +734,7 @@ def test_blocked_against_single(parallel, block_size, return_footprint, existing

result_test = reproject_interp(
hdu2,
hdu1.header,
header_or_wcs(hdu1.header),
parallel=parallel,
block_size=block_size,
return_footprint=return_footprint,
Expand All @@ -737,7 +744,7 @@ def test_blocked_against_single(parallel, block_size, return_footprint, existing

result_reference = reproject_interp(
hdu2,
hdu1.header,
header_or_wcs(hdu1.header),
parallel=False,
block_size=None,
return_footprint=return_footprint,
Expand Down
20 changes: 20 additions & 0 deletions reproject/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from reproject.tests.helpers import assert_wcs_allclose
from reproject.utils import parse_input_data, parse_input_shape, parse_output_projection
from reproject.wcs_utils import has_celestial


@pytest.mark.filterwarnings("ignore:unclosed file:ResourceWarning")
Expand Down Expand Up @@ -89,3 +90,22 @@ def test_parse_output_projection_invalid_header(simple_celestial_fits_wcs):
def test_parse_output_projection_invalid_wcs(simple_celestial_fits_wcs):
with pytest.raises(ValueError, match="Need to specify shape"):
parse_output_projection(simple_celestial_fits_wcs)


@pytest.mark.filterwarnings("ignore::astropy.utils.exceptions.AstropyUserWarning")
@pytest.mark.filterwarnings("ignore::astropy.wcs.wcs.FITSFixedWarning")
def test_has_celestial():
from .test_high_level import INPUT_HDR

hdr = fits.Header.fromstring(INPUT_HDR)
ww = WCS(hdr)
assert ww.has_celestial
assert has_celestial(ww)

from astropy.wcs.wcsapi import HighLevelWCSWrapper, SlicedLowLevelWCS

wwh = HighLevelWCSWrapper(SlicedLowLevelWCS(ww, Ellipsis))
assert has_celestial(wwh)

wwh2 = HighLevelWCSWrapper(SlicedLowLevelWCS(ww, [slice(0, 1), slice(0, 1)]))
assert has_celestial(wwh2)
7 changes: 6 additions & 1 deletion reproject/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,12 @@ def reproject_single_block(a, block_info=None):
if a.ndim == 0 or block_info is None or block_info == []:
return np.array([a, a])
slices = [slice(*x) for x in block_info[None]["array-location"][-wcs_out.pixel_n_dim :]]
wcs_out_sub = HighLevelWCSWrapper(SlicedLowLevelWCS(wcs_out, slices=slices))

if isinstance(wcs_out, BaseHighLevelWCS):
low_level_wcs = SlicedLowLevelWCS(wcs_out.low_level_wcs, slices=slices)
else:
low_level_wcs = SlicedLowLevelWCS(wcs_out, slices=slices)
wcs_out_sub = HighLevelWCSWrapper(low_level_wcs)
if isinstance(array_in_or_path, str):
array_in = np.memmap(array_in_or_path, dtype=float, shape=shape_in)
else:
Expand Down
1 change: 1 addition & 0 deletions reproject/wcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS
from astropy.wcs.utils import pixel_to_pixel
from astropy.wcs.wcsapi.high_level_api import BaseHighLevelWCS

__all__ = ["has_celestial", "pixel_to_pixel_with_roundtrip"]

Expand Down

0 comments on commit 3bb79cd

Please sign in to comment.