Plans for vectorisation / multi-region objects #497

Open
AlecThomson opened this issue Jan 18, 2023 · 2 comments

@AlecThomson

Hi all!

Thanks for the work on this neat module! I've run into a couple of situations where I want to do operations with many regions, e.g. checking whether a SkyCoord falls inside a 'list' of regions. The 'air quotes' around 'list' are there because there don't seem to be vectorised versions of the regions (yet). So I was wondering: what are the plans for supporting vectors of regions?
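As a sketch of the underlying operation, testing N coordinates against M regions can be written as one broadcast array expression that yields an (N, M) boolean mask, after which `mask.any(axis=1)` answers "is each coordinate inside any of the regions?". The helper below (`rect_contains_sky` is a hypothetical name, numpy only) assumes 1-D inputs in degrees and a small-angle, flat-sky approximation around each centre, which is why it wraps the RA difference and scales it by cos(dec). It is not how regions does it: the existing `contains` goes through a WCS.

```python
import numpy as np


def rect_contains_sky(ra, dec, ra0, dec0, width, height, angle):
    """Hypothetical helper: test N points against M rotated rectangles.

    All inputs are 1-D and in degrees. Uses a small-angle (flat-sky)
    approximation around each rectangle centre, so it is only reasonable for
    regions a few degrees across and away from the poles.
    Returns an (N, M) boolean array.
    """
    ra = np.asarray(ra, dtype=float)[:, None]    # (N, 1): one row per point
    dec = np.asarray(dec, dtype=float)[:, None]
    ra0, dec0, width, height, angle = (
        np.asarray(v, dtype=float)[None, :]      # (1, M): one column per region
        for v in (ra0, dec0, width, height, angle)
    )
    # Wrap the RA difference into [-180, 180) and scale it by cos(dec) so
    # that dx is an angular offset on the sky rather than a coordinate difference
    dra = (ra - ra0 + 180.0) % 360.0 - 180.0
    dx = dra * np.cos(np.deg2rad(dec0))
    dy = dec - dec0
    # Rotate the offsets into each rectangle's own frame
    theta = np.deg2rad(angle)
    dx_rot = np.cos(theta) * dx + np.sin(theta) * dy
    dy_rot = -np.sin(theta) * dx + np.cos(theta) * dy
    return (np.abs(dx_rot) < 0.5 * width) & (np.abs(dy_rot) < 0.5 * height)


# Three points against three rectangles (same numbers as the benchmark in this thread)
mask = rect_contains_sky([1, 2, 4], [1, 2, 4], [1, 2, 3], [1, 2, 3],
                         [1, 2, 3], [1, 2, 3], [1, 2, 3])
inside_any = mask.any(axis=1)  # one boolean per point
```

Broadcasting `(N, 1)` against `(1, M)` gives every point/region pair without a Python loop; reducing over `axis=0` instead would tell you which regions contain at least one point.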

I've got a little sketch of what support for vectorised operations could look like, and I'd be interested to get the maintainers' thoughts. To my mind, there are two ways of going about this: a) reworking the current region objects to support vectorisation (making the current scalar form just the 0-dimensional case), or b) producing separate vectorised versions of the objects with the appropriate functionality. In either case, a lot of the core functionality is already in place from np.ndarrays, u.Quantitys, SkyCoords, etc. Cases like plotting will need some careful consideration; after only a brief look, I think loops would still be needed there if the vectorised case were supported.
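To make option a) concrete, here is a minimal numpy-only sketch in which a single region is just the 0-dimensional case: every parameter is stored as an array (scalars become 0-d arrays), and one `contains` implementation serves both cases through broadcasting. `PixelRectangles` is a hypothetical class working in pixel coordinates only, with no astropy units or WCS.

```python
import numpy as np


class PixelRectangles:
    """Hypothetical pixel-space rectangles for option a): one class where a
    single region is simply the 0-dimensional case of the vectorised one."""

    def __init__(self, cx, cy, width, height, angle_deg=0.0):
        # Store every parameter as a float array broadcast to a common shape;
        # a single region ends up with shape ()
        arrays = [np.asarray(v, dtype=float) for v in (cx, cy, width, height, angle_deg)]
        self.cx, self.cy, self.width, self.height, self.angle_deg = np.broadcast_arrays(*arrays)
        if np.any(self.width <= 0) or np.any(self.height <= 0):
            raise ValueError("width and height must be strictly positive")

    @property
    def isscalar(self):
        return self.cx.ndim == 0

    @property
    def area(self):
        return self.width * self.height

    def contains(self, x, y):
        # Rotate the offsets into each rectangle's frame; the scalar and
        # vector cases share this code path thanks to broadcasting
        theta = np.deg2rad(self.angle_deg)
        dx = np.asarray(x, dtype=float) - self.cx
        dy = np.asarray(y, dtype=float) - self.cy
        dx_rot = np.cos(theta) * dx + np.sin(theta) * dy
        dy_rot = -np.sin(theta) * dx + np.cos(theta) * dy
        return (np.abs(dx_rot) < 0.5 * self.width) & (np.abs(dy_rot) < 0.5 * self.height)


single = PixelRectangles(5, 5, 4, 2)
many = PixelRectangles([5, 10], [5, 10], [4, 4], [2, 2], [0, 90])
```

The trade-off compared with option b) is that every method has to be written shape-agnostically, but users only ever see one class per shape.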

I've got a quick sketch for option b) which I've been playing with in my fork. Here's an example of one possible class in regions/shapes/rectangles.py:

class RectangleSkyRegions(SkyRegion):

    _params = ("centers", "widths", "heights", "angles")
    centers = VectorSkyCoord("The center positions as a |SkyCoord|. ")
    widths = PositiveVectorAngle(
        "The widths of the rectangles (before rotation) " "as a |Quantity| angle."
    )
    heights = PositiveVectorAngle(
        "The heights of the rectangles (before " "rotation) as a |Quantity| angle."
    )
    angles = VectorAngle(
        "The rotation angles measured anti-clockwise as a " "|Quantity| angle."
    )
    meta = RegionMetaDescr("The meta attributes as a |RegionMeta|")
    visual = RegionVisualDescr("The visual attributes as a |RegionVisual|.")

    def __init__(self, centers, widths, heights, angles, meta=None, visual=None):
        self.centers = centers
        self.widths = widths
        self.heights = heights
        self.angles = angles
        self.meta = meta or RegionMeta()
        self.visual = visual or RegionVisual()

    def __len__(self):
        return len(self.centers)

    def __getitem__(self, item):
        # Integer indices return a single (scalar) region; slices, index
        # arrays and boolean masks return a new vectorised region
        if isinstance(item, (int, np.integer)):
            cls = RectangleSkyRegion
        else:
            cls = RectangleSkyRegions
        return cls(
            self.centers[item],
            self.widths[item],
            self.heights[item],
            self.angles[item],
            meta=self.meta,
            visual=self.visual,
        )

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    @property
    def area(self):
        return self.widths * self.heights

    def to_pixel(self, wcs):
        centers, pixscales, north_angles = pixel_scale_angle_at_skycoord(
            self.centers, wcs
        )
        widths = (self.widths / pixscales).to(u.pix).value
        heights = (self.heights / pixscales).to(u.pix).value
        # Region sky angles are defined relative to the WCS longitude axis;
        # photutils aperture sky angles are defined as the PA of the
        # semimajor axis (i.e., relative to the WCS latitude axis)
        angles = self.angles + (north_angles - 90 * u.deg)
        return RectanglePixelRegions(
            centers,
            widths,
            heights,
            angles=angles,
            meta=self.meta.copy(),
            visual=self.visual.copy(),
        )

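    def contains(self, coord):
        # Flat-sky (small-angle) approximation around each center: wrap the
        # RA difference into [-180, 180) deg and scale it by cos(dec) so that
        # dx is an angular offset. Coordinate i is tested against region i.
        cos_angle = np.cos(self.angles)
        sin_angle = np.sin(self.angles)
        dx = (coord.ra - self.centers.ra).wrap_at(180 * u.deg) * np.cos(self.centers.dec)
        dy = coord.dec - self.centers.dec
        dx_rot = cos_angle * dx + sin_angle * dy
        dy_rot = -sin_angle * dx + cos_angle * dy
        in_rect = (np.abs(dx_rot) < self.widths * 0.5) & (
            np.abs(dy_rot) < self.heights * 0.5
        )
        if self.meta.get("include", True):
            return in_rect
        return np.logical_not(in_rect)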

and the extra attribute descriptors for regions/core/attributes.py:

class VectorSkyCoord(RegionAttribute):
    """
    Descriptor class to check that value is a vector
    `~astropy.coordinates.SkyCoord`.
    """

    def _validate(self, value):
        if not (isinstance(value, SkyCoord) and not value.isscalar):
            raise ValueError(f'{self.name!r} must be a vector SkyCoord')


class VectorAngle(RegionAttribute):
    """
    Descriptor class to check that value is a vector angle, either an
    `~astropy.coordinates.Angle` or `~astropy.units.Quantity` with
    angular units.
    """

    def _validate(self, value):
        if isinstance(value, Quantity):
            if value.isscalar:
                raise ValueError(f'{self.name!r} must be a vector')

            if value.unit.physical_type != 'angle':
                raise ValueError(f'{self.name!r} must have angular units')
        else:
            raise ValueError(f'{self.name!r} must be a vector angle')


class PositiveVectorAngle(RegionAttribute):
    """
    Descriptor class to check that value is a strictly positive
    vector angle, either an `~astropy.coordinates.Angle` or
    `~astropy.units.Quantity` with angular units.
    """

    def _validate(self, value):
        if isinstance(value, Quantity):
            if value.isscalar:
                raise ValueError(f'{self.name!r} must be a vector')

            if value.unit.physical_type != 'angle':
                raise ValueError(f'{self.name!r} must have angular units')

            if not np.all(value > 0):
                raise ValueError(f'{self.name!r} must be strictly positive')
        else:
            raise ValueError(f'{self.name!r} must be a strictly positive '
                             'vector angle')

Let me know what you think of this, and I'm happy to start a PR if this would be useful :)
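As far as I can tell, the validators above plug into regions' `RegionAttribute` descriptor base class, which calls `_validate` whenever the attribute is assigned and then stores the value on the instance. To illustrate that pattern in isolation, here is a simplified stand-in using plain numpy arrays instead of `Quantity`; `ValidatedAttribute`, `PositiveVector` and `Boxes` are hypothetical names, and this is not regions' actual implementation.

```python
import numpy as np


class ValidatedAttribute:
    """Hypothetical, simplified validating descriptor base class."""

    def __init__(self, doc=""):
        self.__doc__ = doc

    def __set_name__(self, owner, name):
        # Remember the attribute name so error messages and storage use it
        self.name = name

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        return instance.__dict__[self.name]

    def __set__(self, instance, value):
        # Validate first, then store on the instance
        self._validate(value)
        instance.__dict__[self.name] = value

    def _validate(self, value):
        raise NotImplementedError


class PositiveVector(ValidatedAttribute):
    """Accept only non-scalar arrays whose elements are all > 0."""

    def _validate(self, value):
        arr = np.asarray(value)
        if arr.ndim == 0:
            raise ValueError(f"{self.name!r} must be a vector")
        if not np.all(arr > 0):
            raise ValueError(f"{self.name!r} must be strictly positive")


class Boxes:
    widths = PositiveVector("Box widths.")

    def __init__(self, widths):
        self.widths = widths  # goes through PositiveVector.__set__
```

Because the descriptor defines `__set__`, it takes precedence over the instance `__dict__`, so every assignment (including the one in `__init__`) is validated.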

@keflavich
Contributor

I think this is a great idea, as this is quite often my main use case for regions (i.e., making many and combining them).

@larrybradley have you given this any further consideration?

@AlecThomson
Author

Thanks @keflavich! Just to show my initial motivation for this, here's a dummy example of checking whether a set of coordinates lies within a set of regions. With my implementation above, going vectorised gives roughly an 8x speedup (109 ms vs 13.5 ms per loop):

import numpy as np
from astropy.io import fits
from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.wcs import WCS
# RectangleSkyRegions only exists in the fork sketched above, not in released regions
from regions import RectangleSkyRegions, RectangleSkyRegion

# Define a few rectangular regions and some coordinates
centers = SkyCoord([1, 2, 3], [1, 2, 3], unit="deg")
widths = np.array([1, 2, 3]) * u.deg
heights = np.array([1, 2, 3]) * u.deg
angles = np.array([1, 2, 3]) * u.deg
coords = SkyCoord([1, 2, 4], [1, 2, 4], unit="deg")
# Create a dummy header with a WCS
header = fits.Header(
    {
        "NAXIS": 2,
        "NAXIS1": 10,
        "NAXIS2": 10,
        "CTYPE1": "RA---TAN",
        "CRVAL1": 0,
        "CRPIX1": 5,
        "CDELT1": -0.1,
        "CUNIT1": "deg",
        "CTYPE2": "DEC--TAN",
        "CRVAL2": 0,
        "CRPIX2": 5,
        "CDELT2": 0.1,
        "CUNIT2": "deg",
    }
)
wcs = WCS(header)
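For reference, `to_pixel` in the sketch converts sky sizes to pixels using a local pixel scale and north angle from `pixel_scale_angle_at_skycoord`. As far as I know, that helper offsets each position slightly north, converts both points with the WCS and measures the pixel displacement, which vectorises naturally. Below is a numpy-only sketch of that idea; `toy_world_to_pixel` is a hypothetical linear stand-in loosely mirroring the dummy header above (it ignores the TAN projection), and `pixel_scale_and_north_angle` is likewise a made-up name.

```python
import numpy as np

# Hypothetical linear stand-in for wcs.world_to_pixel, loosely mirroring the
# dummy header (CRVAL = 0, CRPIX = 5, CDELT1 = -0.1, CDELT2 = 0.1 deg).
# It ignores the TAN projection, so it is only a toy.
CRPIX1, CRPIX2 = 5.0, 5.0
CDELT1, CDELT2 = -0.1, 0.1


def toy_world_to_pixel(ra, dec):
    # Wrap RA into [-180, 180) deg about CRVAL1 = 0 before scaling
    x = CRPIX1 + ((ra + 180.0) % 360.0 - 180.0) / CDELT1
    y = CRPIX2 + dec / CDELT2
    return x, y


def pixel_scale_and_north_angle(ra, dec, offset=1.0 / 3600.0):
    """Estimate the local pixel scale (deg/pix) and the direction of north
    (degrees from the +x pixel axis) at many positions at once, by
    offsetting each position `offset` degrees north."""
    ra = np.asarray(ra, dtype=float)
    dec = np.asarray(dec, dtype=float)
    x0, y0 = toy_world_to_pixel(ra, dec)
    x1, y1 = toy_world_to_pixel(ra, dec + offset)
    dx, dy = x1 - x0, y1 - y0
    scale = offset / np.hypot(dx, dy)
    north = np.rad2deg(np.arctan2(dy, dx))
    return scale, north


# Same conversion rule as the to_pixel sketch, applied to the example regions
scale, north = pixel_scale_and_north_angle([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
widths_pix = np.array([1.0, 2.0, 3.0]) / scale           # deg / (deg/pix) -> pix
angles_pix = np.array([1.0, 2.0, 3.0]) + (north - 90.0)  # deg
```

With this toy WCS, north points along +y everywhere, so the angle correction is zero and a 1 deg width maps to 10 pixels at 0.1 deg/pix.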

Benchmarking with loops (currently required):

%%timeit
# Current
rectangles = [
    RectangleSkyRegion(
        center=center, width=width, height=height, angle=angle
    ) for center, width, height, angle in zip(centers, widths, heights, angles)
]
rec_pix = [r.to_pixel(wcs) for r in rectangles]
coord_check = [r.contains(c, wcs) for r, c in zip(rectangles, coords)]
109 ms ± 33.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

And with vectors:

%%timeit
# Vectorised
rectangles = RectangleSkyRegions(
    centers=centers, widths=widths, heights=heights, angles=angles
)
rec_pix = rectangles.to_pixel(wcs)
coord_check = rectangles.contains(coords)
13.5 ms ± 743 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
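Most of the gain in timings like these comes from replacing per-region Python calls with a handful of whole-array operations. (The two benchmarks are also not quite like-for-like: the loop version's `contains` goes through the WCS, while the vectorised sketch compares sky coordinates directly.) A numpy-only sketch with made-up random pixel rectangles (all names hypothetical) shows the same pattern in isolation and checks that the loop and vectorised versions agree; absolute timings depend on the machine.

```python
import timeit

import numpy as np

# Made-up pixel-space rectangles and one test point near each rectangle
rng = np.random.default_rng(0)
n = 1000
cx, cy = rng.uniform(0, 100, n), rng.uniform(0, 100, n)
w, h = rng.uniform(1, 5, n), rng.uniform(1, 5, n)
theta = np.deg2rad(rng.uniform(0, 180, n))
px, py = cx + rng.normal(0, 2, n), cy + rng.normal(0, 2, n)


def inside_one(x, y, x0, y0, width, height, t):
    # Scalar test: one point against one rotated rectangle
    dx, dy = x - x0, y - y0
    along = np.cos(t) * dx + np.sin(t) * dy
    across = -np.sin(t) * dx + np.cos(t) * dy
    return abs(along) < width / 2 and abs(across) < height / 2


def inside_loop():
    # One Python-level call per (point, rectangle) pair
    return np.array([inside_one(*args) for args in zip(px, py, cx, cy, w, h, theta)])


def inside_vectorised():
    # The same arithmetic applied to whole arrays at once
    dx, dy = px - cx, py - cy
    along = np.cos(theta) * dx + np.sin(theta) * dy
    across = -np.sin(theta) * dx + np.cos(theta) * dy
    return (np.abs(along) < w / 2) & (np.abs(across) < h / 2)


t_loop = timeit.timeit(inside_loop, number=3)
t_vec = timeit.timeit(inside_vectorised, number=3)
print(f"loop: {t_loop:.4f} s, vectorised: {t_vec:.4f} s")
```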
