Skip to content

Commit

Permalink
Merge pull request #911 from rosteen/free-the-regions
Browse files Browse the repository at this point in the history
Allow descending spectral axes and more flexible SpectralRegion bounds
  • Loading branch information
eteq committed Jan 21, 2022
2 parents 7efabf8 + 84a46b5 commit 54d01a3
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 89 deletions.
6 changes: 5 additions & 1 deletion docs/spectrum1d.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@ create it explicitly from arrays or `~astropy.units.Quantity` objects:
>>> ax.set_ylabel("Flux") # doctest: +SKIP

.. note::
Note that the ``spectral_axis`` can also be provided as a :class:`~specutils.SpectralAxis` object,
The ``spectral_axis`` can also be provided as a :class:`~specutils.SpectralAxis` object,
and in fact will internally convert the spectral_axis to :class:`~specutils.SpectralAxis` if it
is provided as an array or `~astropy.units.Quantity`.

.. note::
The ``spectral_axis`` can be either ascending or descending, but must be monotonic
in either case.

Reading from a File
-------------------

Expand Down
117 changes: 50 additions & 67 deletions specutils/manipulation/extract_spectral_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,38 @@
__all__ = ['extract_region', 'extract_bounding_spectral_region', 'spectral_slab']


def _to_edge_pixel(subregion, spectrum):
def _edge_value_to_pixel(edge_value, spectrum, order, side):
spectral_axis = spectrum.spectral_axis
if order == "ascending":
if edge_value > spectral_axis[-1]:
return len(spectral_axis)
if edge_value < spectral_axis[0]:
return 0

elif order == "descending":
if edge_value < spectral_axis[-1]:
return len(spectral_axis)
if edge_value > spectral_axis[0]:
return 0

try:
if hasattr(spectrum.wcs, "spectral"):
index = spectrum.wcs.spectral.world_to_pixel(edge_value)
else:
index = spectrum.wcs.world_to_pixel(edge_value)

if side == "left":
index = int(np.ceil(index))
elif side == "right":
index = int(np.floor(index)) + 1

return index

except Exception as e:
raise ValueError(f"Bound {edge_value}, could not be converted to pixel index"
f" using spectrum's WCS. Exception: {e}")

def _subregion_to_edge_pixels(subregion, spectrum):
"""
Calculate and return the left and right indices defined
by the lower and upper bounds and based on the input
Expand All @@ -28,70 +59,34 @@ def _to_edge_pixel(subregion, spectrum):
Left and right indices defined by the lower and upper bounds.
"""
# TODO: spectral regions cannot handle strictly ascending spectral axis
# values. Instead, convert to length space if axis given in a desceninding
# unit space (e.g. frequency). We should find a more elegant solution.
spectral_axis = spectrum.spectral_axis
if spectral_axis[-1] > spectral_axis[0]:
order = "ascending"
left_func = min
right_func = max
else:
order = "descending"
left_func = max
right_func = min

if spectrum.spectral_axis.unit.physical_type != 'length' and \
spectrum.spectral_axis.unit.is_equivalent(
u.AA, equivalencies=u.spectral()):
spectral_axis = spectrum.spectral_axis.to(
u.AA, equivalencies=u.spectral())

#
# Left/lower side of sub region
#
if subregion[0].unit.is_equivalent(u.pix):
if subregion[0].unit.is_equivalent(u.pix) and not spectral_axis.unit.is_equivalent(u.pix):
left_index = floor(subregion[0].value)
else:
# Convert lower value to spectrum spectral_axis units
left_reg_in_spec_unit = subregion[0].to(spectral_axis.unit,
u.spectral())
left_reg_in_spec_unit = left_func(subregion).to(spectral_axis.unit,
u.spectral())
left_index = _edge_value_to_pixel(left_reg_in_spec_unit, spectrum, order, "left")

if left_reg_in_spec_unit < spectral_axis[0]:
left_index = 0
elif left_reg_in_spec_unit > spectral_axis[-1]:
left_index = len(spectrum.spectral_axis)
else:
try:
if hasattr(spectrum.wcs, "spectral"):
left_index = spectrum.wcs.spectral.world_to_pixel(left_reg_in_spec_unit)
else:
left_index = spectrum.wcs.world_to_pixel(left_reg_in_spec_unit)
left_index = int(np.ceil(left_index))
except Exception as e:
raise ValueError(
"Lower value, {}, could not convert using spectrum's WCS "
"{}. Exception: {}".format(
left_reg_in_spec_unit, spectrum.wcs, e))

#
# Right/upper side of sub region
#
if subregion[1].unit.is_equivalent(u.pix):
if subregion[1].unit.is_equivalent(u.pix) and not spectral_axis.unit.is_equivalent(u.pix):
right_index = ceil(subregion[1].value)
else:
# Convert upper value to spectrum spectral_axis units
right_reg_in_spec_unit = subregion[1].to(spectral_axis.unit,
right_reg_in_spec_unit = right_func(subregion).to(spectral_axis.unit,
u.spectral())

if right_reg_in_spec_unit > spectral_axis[-1]:
right_index = len(spectrum.spectral_axis)
elif right_reg_in_spec_unit < spectral_axis[0]:
right_index = 0
else:
try:
if hasattr(spectrum.wcs, "spectral"):
right_index = spectrum.wcs.spectral.world_to_pixel(right_reg_in_spec_unit)
else:
right_index = spectrum.wcs.world_to_pixel(right_reg_in_spec_unit)
right_index = int(np.floor(right_index)) + 1
except Exception as e:
raise ValueError(
"Upper value, {}, could not convert using spectrum's WCS "
"{}. Exception: {}".format(
right_reg_in_spec_unit, spectrum.wcs, e))
right_index = _edge_value_to_pixel(right_reg_in_spec_unit, spectrum, order, "right")

return left_index, right_index

Expand Down Expand Up @@ -144,26 +139,14 @@ def extract_region(spectrum, region, return_single_spectrum=False):
"""
extracted_spectrum = []
for subregion in region._subregions:
left_index, right_index = _to_edge_pixel(subregion, spectrum)
left_index, right_index = _subregion_to_edge_pixels(subregion, spectrum)

# If both indices are out of bounds then return an empty spectrum
if left_index is None and right_index is None:
if left_index == right_index:
empty_spectrum = Spectrum1D(spectral_axis=[]*spectrum.spectral_axis.unit,
flux=[]*spectrum.flux.unit)
extracted_spectrum.append(empty_spectrum)
else:

# If only one index is out of bounds then set it to
# the lower or upper extent
if left_index is None:
left_index = 0

if right_index is None:
right_index = len(spectrum.spectral_axis)

if left_index > right_index:
left_index, right_index = right_index, left_index

extracted_spectrum.append(spectrum[..., left_index:right_index])

# If there is only one subregion in the region then we will
Expand Down Expand Up @@ -276,7 +259,7 @@ def extract_bounding_spectral_region(spectrum, region):
max_right = -sys.maxsize - 1

# Look for indices that bound the entire set of sub-regions.
index_list = [_to_edge_pixel(sr, spectrum) for sr in region._subregions]
index_list = [_subregion_to_edge_pixels(sr, spectrum) for sr in region._subregions]

for left_index, right_index in index_list:
if left_index is not None:
Expand Down
18 changes: 7 additions & 11 deletions specutils/spectra/spectral_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def __init__(self, *args):
f'positional arguments but {len(args)} were given')

# Check validity of the input sub regions.
if not self._valid():
raise ValueError("SpectralRegion 2-tuple lower extent must be "
"less than upper extent.")
self._valid()

# The sub-regions are to always be ordered based on the lower bound.
self._reorder()
Expand Down Expand Up @@ -184,14 +182,12 @@ def __delitem__(self, item):

def _valid(self):

# Lower bound < Upper bound for all sub regions in length physical type
with u.set_enabled_equivalencies(u.spectral()):
sub_regions = [(x[0].to('m'), x[1].to('m'))
if x[0].unit.is_equivalent(u.m) else x
for x in self._subregions]

if any(x[0] >= x[1] for x in sub_regions):
raise ValueError('Lower bound must be strictly less than the upper bound')
bound_unit = self._subregions[0][0].unit
for x in self._subregions:
if x[0].unit != bound_unit or x[1].unit != bound_unit:
raise ValueError("All SpectralRegion bounds must have the same unit.")
if x[0] == x[1]:
raise ValueError("Upper and lower bound must be different values.")

return True

Expand Down
11 changes: 10 additions & 1 deletion specutils/tests/spectral_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class SpectraExamples:
5. s1_um_mJy_e1_masked - same as 1, but with a random set of pixels
masked.
6. s1_um_mJy_e1_desc - same as 1, but with the spectral axis in
descending rather than ascending order.
"""

def __init__(self):
Expand Down Expand Up @@ -91,13 +94,16 @@ def __init__(self):
self._s1_AA_nJy_e4 = Spectrum1D(spectral_axis=self.wavelengths_AA * u.AA,
flux=self._flux_e4 * u.nJy)


#
# Create one spectrum like 1 but with a mask
#
self._s1_um_mJy_e1_masked = copy(self._s1_um_mJy_e1) # SHALLOW copy - the data are shared with the above non-masked case
self._s1_um_mJy_e1_masked.mask = (np.random.randn(*self.base_flux.shape) + 1) > 0

# Create a spectrum like 1, but with descending spectral axis
self._s1_um_mJy_e1_desc = Spectrum1D(spectral_axis=self.wavelengths_um[::-1] * u.um,
flux=self._flux_e1[::-1] * u.mJy)


@property
def s1_um_mJy_e1(self):
Expand Down Expand Up @@ -135,6 +141,9 @@ def s1_AA_nJy_e4_flux(self):
def s1_um_mJy_e1_masked(self):
return self._s1_um_mJy_e1_masked

@property
def s1_um_mJy_e1_desc(self):
return self._s1_um_mJy_e1_desc


@pytest.fixture
Expand Down
32 changes: 32 additions & 0 deletions specutils/tests/test_region_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,38 @@ def test_region_empty(simulated_spectra):
region = SpectralRegion(3*u.um, 3*u.um)


def test_region_descending(simulated_spectra):
np.random.seed(42)

spectrum = simulated_spectra.s1_um_mJy_e1
uncertainty = StdDevUncertainty(0.1*np.random.random(len(spectrum.flux))*u.mJy)
spectrum.uncertainty = uncertainty

region = SpectralRegion(0.8*u.um, 0.6*u.um)

sub_spectrum = extract_region(spectrum, region)

sub_spectrum_flux_expected = np.array(FLUX_ARRAY)

assert quantity_allclose(sub_spectrum.flux.value, sub_spectrum_flux_expected)


def test_descending_spectral_axis(simulated_spectra):
spectrum = simulated_spectra.s1_um_mJy_e1_desc

sub_spectrum_flux_expected = np.array(FLUX_ARRAY[::-1])

region = SpectralRegion(0.8*u.um, 0.6*u.um)
sub_spectrum = extract_region(spectrum, region)

assert quantity_allclose(sub_spectrum.flux.value, sub_spectrum_flux_expected)

region = SpectralRegion(0.6*u.um, 0.8*u.um)
sub_spectrum = extract_region(spectrum, region)

assert quantity_allclose(sub_spectrum.flux.value, sub_spectrum_flux_expected)


def test_region_two_sub(simulated_spectra):
np.random.seed(42)

Expand Down
18 changes: 9 additions & 9 deletions specutils/utils/wcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,6 @@ def gwcs_from_array(array):
"""
orig_array = u.Quantity(array)

# TODO: Input arrays must be strictly ascending. This is not always the
# case for a spectral axis (e.g. when in frequency space). Thus, we
# convert to wavelength to create the wcs.
if orig_array.unit.physical_type != 'length' and \
orig_array.unit.is_equivalent(u.AA, equivalencies=u.spectral()):
array = orig_array.to(u.AA, equivalencies=u.spectral())

coord_frame = cf.CoordinateFrame(naxes=1,
axes_type=('SPECTRAL',),
axes_order=(0,))
Expand All @@ -209,8 +202,15 @@ def gwcs_from_array(array):

forward_transform = SpectralTabular1D(np.arange(len(array)),
lookup_table=array)
forward_transform.inverse = SpectralTabular1D(
array, lookup_table=np.arange(len(array)))
# If our spectral axis is in descending order, we have to flip the lookup
# table to be ascending in order for world_to_pixel to work.
if len(array) == 0 or array[-1] > array[0]:
forward_transform.inverse = SpectralTabular1D(
array, lookup_table=np.arange(len(array)))
else:
forward_transform.inverse = SpectralTabular1D(
array[::-1], lookup_table=np.arange(len(array))[::-1])


class SpectralGWCS(GWCS):
def pixel_to_world(self, *args, **kwargs):
Expand Down

0 comments on commit 54d01a3

Please sign in to comment.