Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start overhaul for embedding sort order #2998

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
86 changes: 79 additions & 7 deletions scanpy/plotting/_tools/scatterplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections.abc as cabc
import inspect
import sys
import warnings
from collections.abc import Mapping, Sequence # noqa: TCH003
from copy import copy
from functools import partial
Expand Down Expand Up @@ -63,7 +64,12 @@
mask_obs: NDArray[np.bool_] | str | None = None,
gene_symbols: str | None = None,
use_raw: bool | None = None,
sort_order: bool = True,
sort_order: bool | Empty = Empty,
order_continuous: Literal["ascending", "descending"]
| None
| np.ndarray
| Empty = Empty,
order_categorical: None | np.ndarray = None,
edges: bool = False,
edges_width: float = 0.1,
edges_color: str | Sequence[float] | Sequence[str] = "grey",
Expand Down Expand Up @@ -140,6 +146,9 @@
raise ValueError("Groups and mask arguments are incompatible.")
if mask_obs is not None:
mask_obs = _check_mask(adata, mask_obs, "obs")
order_continuous = check_continuous_order(
order_continuous, sort_order, adata.shape[0]
)

# Figure out if we're using raw
if use_raw is None:
Expand Down Expand Up @@ -282,12 +291,23 @@

# Order points
order = slice(None)
if sort_order and value_to_plot is not None and color_type == "cont":
# Higher values plotted on top, null values on bottom
order = np.argsort(-color_vector, kind="stable")[::-1]
elif sort_order and color_type == "cat":
# Null points go on bottom
order = np.argsort(~pd.isnull(color_source_vector), kind="stable")
if value_to_plot is None:
pass
elif color_type == "cont" and order_continuous is not None:
if isinstance(order_continuous, np.ndarray):
order = order_continuous

Check warning on line 298 in scanpy/plotting/_tools/scatterplots.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_tools/scatterplots.py#L298

Added line #L298 was not covered by tests
elif order_continuous == "ascending":
order = np.argsort(color_source_vector, kind="stable")
elif order_continuous == "descending":
order = np.argsort(-color_source_vector, kind="stable")

Check warning on line 302 in scanpy/plotting/_tools/scatterplots.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_tools/scatterplots.py#L301-L302

Added lines #L301 - L302 were not covered by tests

elif color_type == "cat" and order_categorical is not None:
order = order_categorical

Check warning on line 305 in scanpy/plotting/_tools/scatterplots.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_tools/scatterplots.py#L305

Added line #L305 was not covered by tests
if (masked_entries := pd.isnull(color_source_vector)).any():
if isinstance(order, slice):
order = np.arange(adata.n_obs)
order = order[np.argsort(~masked_entries[order], kind="stable")]

# Set orders
if isinstance(size, np.ndarray):
size = np.array(size)[order]
Expand Down Expand Up @@ -466,6 +486,58 @@
return axs


def check_continuous_order(
order_continuous: Literal["ascending", "descending"] | None | np.ndarray | Empty,
sort_order: bool | Empty,
N: int,
) -> Literal["ascending", "descending"] | None | np.ndarray:
# Backwards compat
if sort_order is not Empty:
warnings.warn(

Check warning on line 496 in scanpy/plotting/_tools/scatterplots.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_tools/scatterplots.py#L496

Added line #L496 was not covered by tests
"The `sort_order` parameter is deprecated and will be removed in the future. "
"Please use `order_continuous` and `order_categorical` instead.",
FutureWarning,
stacklevel=2,
)
if order_continuous is not Empty:
raise ValueError(

Check warning on line 503 in scanpy/plotting/_tools/scatterplots.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_tools/scatterplots.py#L502-L503

Added lines #L502 - L503 were not covered by tests
"Cannot specify both `sort_order` and `order_continuous`. "
"Please use only `order_continuous`."
)
elif sort_order:
order_continuous = "ascending"

Check warning on line 508 in scanpy/plotting/_tools/scatterplots.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_tools/scatterplots.py#L507-L508

Added lines #L507 - L508 were not covered by tests
else:
order_continuous = None

Check warning on line 510 in scanpy/plotting/_tools/scatterplots.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_tools/scatterplots.py#L510

Added line #L510 was not covered by tests
elif order_continuous is Empty:
# Default path
order_continuous = "ascending"
elif isinstance(order_continuous, np.ndarray) and order_continuous.shape != (N,):
raise ValueError(

Check warning on line 515 in scanpy/plotting/_tools/scatterplots.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_tools/scatterplots.py#L514-L515

Added lines #L514 - L515 were not covered by tests
f"order_continuous array must have shape ({N},). Got shape {order_continuous.shape}."
)
elif order_continuous not in ["ascending", "descending", None]:
raise ValueError(

Check warning on line 519 in scanpy/plotting/_tools/scatterplots.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_tools/scatterplots.py#L518-L519

Added lines #L518 - L519 were not covered by tests
f"order_continuous must be 'ascending', 'descending', None, or an array of values. Got {order_continuous}."
)
return order_continuous


def check_categorical_order(
order_categorical: None | np.ndarray, N: int
) -> None | np.ndarray:
if order_categorical is None:
pass
elif isinstance(order_categorical, np.ndarray) and order_categorical.shape != (N,):
raise ValueError(

Check warning on line 531 in scanpy/plotting/_tools/scatterplots.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_tools/scatterplots.py#L528-L531

Added lines #L528 - L531 were not covered by tests
f"order_categorical array must have shape ({N},). Got shape {order_categorical.shape}."
)
else:
raise ValueError(

Check warning on line 535 in scanpy/plotting/_tools/scatterplots.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_tools/scatterplots.py#L535

Added line #L535 was not covered by tests
"order_categorical must be None or an array of values. Got {order_categorical}."
)
return order_categorical

Check warning on line 538 in scanpy/plotting/_tools/scatterplots.py

View check run for this annotation

Codecov / codecov/patch

scanpy/plotting/_tools/scatterplots.py#L538

Added line #L538 was not covered by tests


def _panel_grid(hspace, wspace, ncols, num_panels):
from matplotlib import gridspec

Expand Down