Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-do of plot_quick #1106

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
73 changes: 73 additions & 0 deletions specutils/spectra/spectrum1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,3 +768,76 @@
result = "<Spectrum1D({})>".format(inner_str)

return result

def plot(self, ax=None, x_name='spectral axis', y_name='flux',
set_quantity_support=True, **kwargs):
"""
Visualize this spectrum using matplotlib in "histogram style".

Parameters
----------
ax : `matplotlib.axes.Axes` or None
The axis to plot this figure into. If None, use the current
``pyplot`` axes (which will create a new figure if none exists).
x_name : str or None
The name to use for the x axis (units will be automatically added)
or None to not set the x axis label.
y_name : str or None
The name to use for the y axis (units will be automatically added)
or None to not set the y axis label.
set_quantity_support : bool
If True, call `astropy.visualization.quantity_support` to ensure
that the quantities in the plot are properly settable.

kwargs are passed into `~matplotlib.axes.Axes.plot`, except for
``drawstyle`` or ``ds``.

Returns
-------
ax : `matplotlib.axes.Axes`
Either ``ax``, or the newly created axes object (if the ``ax``
parameter is None).
"""
# import is intentionally inside the method to make matplotlib an
# "optional" dependency
from matplotlib import pyplot as plt
from astropy.visualization import quantity_support

if set_quantity_support:
quantity_support()

if 'drawstyle' in kwargs or 'ds' in kwargs:
raise TypeError("cannot set draw style in a spectrum's plot_quick")

Check warning on line 810 in specutils/spectra/spectrum1d.py

View check run for this annotation

Codecov / codecov/patch

specutils/spectra/spectrum1d.py#L810

Added line #L810 was not covered by tests

kwargs['drawstyle'] = 'steps-post'

if len(self.shape) != 1:
nspecdim = len(self.shape) - 1
indexing_hint = 'spec[' + ', '.join(['0']*nspecdim) + ']'
raise ValueError(f'plot_quick can only be used on 1d spectra. To '

Check warning on line 817 in specutils/spectra/spectrum1d.py

View check run for this annotation

Codecov / codecov/patch

specutils/spectra/spectrum1d.py#L815-L817

Added lines #L815 - L817 were not covered by tests
'get the first spectrum, try {indexing_hint}')

if ax is None:
ax = plt.gca()

# TODO: replace below with self.bin_edges once it is correct
mid_bin_edges = (self.spectral_axis[1:] + self.spectral_axis[:-1])/2
bin_edges = np.concatenate([(self.spectral_axis[0]*2-mid_bin_edges[0]).ravel(),
mid_bin_edges,
(self.spectral_axis[-1]*2-mid_bin_edges[-1]).ravel()])

# for a plot with steps-post, the last horizontal line requires a repeat
# of the last flux value
extended_flux = np.concatenate([self.flux, [self.flux[-1]]])

ax.plot(bin_edges, extended_flux, **kwargs)

if x_name is not None:
sa_unit = self.spectral_axis.unit.to_string(format='latex_inline')
ax.set_xlabel(x_name + f' [{sa_unit}]')

if y_name is not None:
flux_unit = self.flux.unit.to_string(format='latex_inline')
ax.set_ylabel(y_name + f' [{flux_unit}]')

return ax
13 changes: 13 additions & 0 deletions specutils/tests/test_spectrum1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
from .conftest import remote_access
from ..spectra import Spectrum1D

try:
import matplotlib
HAS_MATPLOTLIB = True
except ImportError:
HAS_MATPLOTLIB = False


def test_empty_spectrum():
spec = Spectrum1D(spectral_axis=[]*u.um,
Expand Down Expand Up @@ -538,3 +544,10 @@ def test_spectral_axis_direction():
wave = [3, 2, 1] * u.nm
spec1d = Spectrum1D(spectral_axis=wave, flux=flux)
assert spec1d.spectral_axis_direction == 'decreasing'


@pytest.mark.skipif('not HAS_MATPLOTLIB')
def test_plot():
spec_single_flux = Spectrum1D([1, 2] * u.Jy, [3, 4] * u.nm)
ax = spec_single_flux.plot()
assert isinstance(ax, matplotlib.axes.Axes)