Skip to content

Commit

Permalink
Add ability to specify which vars to aggregate
Browse files Browse the repository at this point in the history
  • Loading branch information
qiemem committed Jul 23, 2023
1 parent 3ef2a0f commit 8264a71
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
7 changes: 4 additions & 3 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,19 +1185,20 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:

def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None:

grouping_vars = [v for v in PROPERTIES if v not in "xy"]
grouping_vars += ["col", "row", "group"]

pair_vars = spec._pair_spec.get("structure", {})

for layer in layers:

data = layer["data"]
mark = layer["mark"]
stat = layer["stat"]

if stat is None:
continue
target_vars = getattr(stat, "target_vars", "xy")

grouping_vars = [v for v in PROPERTIES if v not in target_vars]
grouping_vars += ["col", "row", "group"]

iter_axes = itertools.product(*[
pair_vars.get(axis, [axis]) for axis in "xy"
Expand Down
12 changes: 8 additions & 4 deletions seaborn/_stats/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import ClassVar, Callable
from typing import ClassVar, Callable, Iterable

import pandas as pd
from pandas import DataFrame
Expand All @@ -21,6 +21,9 @@ class Agg(Stat):
----------
func : str or callable
Name of a :class:`pandas.Series` method or a vector -> scalar function.
target_vars : list of strings
Variables to perform the aggregation on. Defaults to x or y, depending on
orientation.
See Also
--------
Expand All @@ -32,18 +35,19 @@ class Agg(Stat):
"""
func: str | Callable[[Vector], float] = "mean"
target_vars: Iterable[str] = ("x", "y")

group_by_orient: ClassVar[bool] = True

def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

var = {"x": "y", "y": "x"}.get(orient)
vars = [v for v in self.target_vars if v != orient]
res = (
groupby
.agg(data, {var: self.func})
.dropna(subset=[var])
.agg(data, {var: self.func for var in vars})
.dropna(subset=vars)
.reset_index(drop=True)
)
return res
Expand Down

0 comments on commit 8264a71

Please sign in to comment.