Skip to content

Commit

Permalink
Moved the summary bars, contrast bars, and swarm bars into plot_tools
Browse files Browse the repository at this point in the history
  • Loading branch information
JAnns98 committed Apr 17, 2024
1 parent cb28e6c commit d4d62cf
Show file tree
Hide file tree
Showing 5 changed files with 362 additions and 146 deletions.
6 changes: 6 additions & 0 deletions dabest/_modidx.py
Expand Up @@ -84,12 +84,18 @@
'dabest.plot_tools.SwarmPlot.plot': ('API/plot_tools.html#swarmplot.plot', 'dabest/plot_tools.py'),
'dabest.plot_tools.check_data_matches_labels': ( 'API/plot_tools.html#check_data_matches_labels',
'dabest/plot_tools.py'),
'dabest.plot_tools.contrast_bars_plotter': ( 'API/plot_tools.html#contrast_bars_plotter',
'dabest/plot_tools.py'),
'dabest.plot_tools.error_bar': ('API/plot_tools.html#error_bar', 'dabest/plot_tools.py'),
'dabest.plot_tools.get_swarm_spans': ('API/plot_tools.html#get_swarm_spans', 'dabest/plot_tools.py'),
'dabest.plot_tools.halfviolin': ('API/plot_tools.html#halfviolin', 'dabest/plot_tools.py'),
'dabest.plot_tools.normalize_dict': ('API/plot_tools.html#normalize_dict', 'dabest/plot_tools.py'),
'dabest.plot_tools.sankeydiag': ('API/plot_tools.html#sankeydiag', 'dabest/plot_tools.py'),
'dabest.plot_tools.single_sankey': ('API/plot_tools.html#single_sankey', 'dabest/plot_tools.py'),
'dabest.plot_tools.summary_bars_plotter': ( 'API/plot_tools.html#summary_bars_plotter',
'dabest/plot_tools.py'),
'dabest.plot_tools.swarm_bars_plotter': ( 'API/plot_tools.html#swarm_bars_plotter',
'dabest/plot_tools.py'),
'dabest.plot_tools.swarmplot': ('API/plot_tools.html#swarmplot', 'dabest/plot_tools.py'),
'dabest.plot_tools.width_determine': ('API/plot_tools.html#width_determine', 'dabest/plot_tools.py')},
'dabest.plotter': {'dabest.plotter.effectsize_df_plotter': ('API/plotter.html#effectsize_df_plotter', 'dabest/plotter.py')}}}
160 changes: 159 additions & 1 deletion dabest/plot_tools.py
Expand Up @@ -5,7 +5,8 @@

# %% auto 0
__all__ = ['halfviolin', 'get_swarm_spans', 'error_bar', 'check_data_matches_labels', 'normalize_dict', 'width_determine',
'single_sankey', 'sankeydiag', 'swarmplot', 'SwarmPlot']
'single_sankey', 'sankeydiag', 'summary_bars_plotter', 'contrast_bars_plotter', 'swarm_bars_plotter',
'swarmplot', 'SwarmPlot']

# %% ../nbs/API/plot_tools.ipynb 4
import math
Expand All @@ -17,6 +18,7 @@
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.axes as axes
import matplotlib.patches as mpatches
from collections import defaultdict
from typing import List, Tuple, Dict, Iterable, Union
from pandas.api.types import CategoricalDtype
Expand Down Expand Up @@ -778,6 +780,162 @@ def sankeydiag(
ax.set_xticks([0, 1])
ax.set_xticklabels(sankey_ticks)

def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object,
float_contrast: bool,summary_bars_kwargs: dict, ci_type: str,
ticks_to_plot: list, color_col: str, swarm_colors: list,
proportional: bool, is_paired: bool):
"""
Add summary bars to the contrast plot.
Parameters
----------
summary_bars : list
List of indices of the contrast objects to plot summary bars for.
results : object (Dataframe)
Dataframe of contrast object comparisons.
ax_to_plot : object
Matplotlib axis object to plot on.
float_contrast : bool
Whether the DABEST plot uses Gardner-Altman or Cummings.
summary_bars_kwargs : dict
Keyword arguments for the summary bars.
ci_type : str
Type of confidence interval to plot.
ticks_to_plot : list
List of indices of the contrast objects.
color_col : str
Column name of the color column.
swarm_colors : list
List of colors used in the plot.
proportional : bool
Whether the data is proportional.
is_paired : bool
Whether the data is paired.
"""
# Begin checks
if not isinstance(summary_bars, list):
raise TypeError("summary_bars must be a list of indices (ints).")
if not all(isinstance(i, int) for i in summary_bars):
raise TypeError("summary_bars must be a list of indices (ints).")
if any(i >= len(results) for i in summary_bars):
raise ValueError("Index {} chosen is out of range for the contrast objects.".format([i for i in summary_bars if i >= len(results)]))
if float_contrast:
raise ValueError("summary_bars cannot be used with Gardner-Altman plots.")
# End checks
else:
summary_xmin, summary_xmax = ax_to_plot.get_xlim()
summary_bars_colors = [summary_bars_kwargs.get('color')]*(len(summary_bars)+1) if summary_bars_kwargs.get('color') is not None else ['black']*(max(summary_bars)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors
summary_bars_kwargs.pop('color')
for summary_index in summary_bars:
if ci_type == "bca":
summary_ci_low = results.bca_low[summary_index]
summary_ci_high = results.bca_high[summary_index]
else:
summary_ci_low = results.pct_low[summary_index]
summary_ci_high = results.pct_high[summary_index]

summary_color = summary_bars_colors[ticks_to_plot[summary_index]]

ax_to_plot.add_patch(mpatches.Rectangle((summary_xmin,summary_ci_low),summary_xmax+1,
summary_ci_high-summary_ci_low, zorder=-2, color=summary_color, **summary_bars_kwargs))


def contrast_bars_plotter(results: object, ax_to_plot: object, swarm_plot_ax: object,
ticks_to_plot: list, contrast_bars_kwargs: dict, color_col: str,
swarm_colors: list, show_mini_meta: bool, mini_meta_delta: object,
show_delta2: bool, delta_delta: object, proportional: bool, is_paired: bool):
"""
Add contrast bars to the contrast plot.
Parameters
----------
results : object (Dataframe)
Dataframe of contrast object comparisons.
ax_to_plot : object
Matplotlib axis object to plot on.
swarm_plot_ax : object (ax)
Matplotlib axis object of the swarm plot.
ticks_to_plot : list
List of indices of the contrast objects.
contrast_bars_kwargs : dict
Keyword arguments for the contrast bars.
color_col : str
Column name of the color column.
swarm_colors : list
List of colors used in the plot.
show_mini_meta : bool
Whether to show the mini meta-analysis.
mini_meta_delta : object
Mini meta-analysis object.
show_delta2 : bool
Whether to show the delta-delta.
delta_delta : object
delta-delta object.
proportional : bool
Whether the data is proportional.
is_paired : bool
Whether the data is paired.
"""
contrast_means = []
for j, tick in enumerate(ticks_to_plot):
contrast_means.append(results.difference[j])

contrast_bars_colors = [contrast_bars_kwargs.get('color')]*(len(ticks_to_plot)+1) if contrast_bars_kwargs.get('color') is not None else ['black']*(max(ticks_to_plot)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors
contrast_bars_kwargs.pop('color')
for contrast_bars_x,contrast_bars_y in zip(ticks_to_plot, contrast_means):
ax_to_plot.add_patch(mpatches.Rectangle((contrast_bars_x-0.25,0),0.5, contrast_bars_y, zorder=-1, color=contrast_bars_colors[contrast_bars_x], **contrast_bars_kwargs))

if show_mini_meta:
ax_to_plot.add_patch(mpatches.Rectangle((max(swarm_plot_ax.get_xticks())+2-0.25,0),0.5, mini_meta_delta.difference, zorder=-1, color='black', **contrast_bars_kwargs))

if show_delta2:
ax_to_plot.add_patch(mpatches.Rectangle((max(swarm_plot_ax.get_xticks())+2-0.25,0),0.5, delta_delta.difference, zorder=-1, color='black', **contrast_bars_kwargs))

def swarm_bars_plotter(plot_data: object, xvar: str, yvar: str, ax: object,
swarm_bars_kwargs: dict, color_col: str, swarm_colors: list, is_paired: bool):
"""
Add bars to the raw data plot.
Parameters
----------
plot_data : object (Dataframe)
Dataframe of the plot data.
xvar : str
Column name of the x variable.
yvar : str
Column name of the y variable.
ax : object
Matplotlib axis object to plot on.
swarm_bars_kwargs : dict
Keyword arguments for the swarm bars.
color_col : str
Column name of the color column.
swarm_colors : list
List of colors used in the plot.
is_paired : bool
Whether the data is paired.
"""

# if is_paired:
# swarm_bar_xlocs_adjustleft = {'right': -0.2, 'left': -0.2, 'center': -0.2}
# swarm_bar_xlocs_adjustright = {'right': -0.1, 'left': -0.1, 'center': -0.1}
# else:
# swarm_bar_xlocs_adjustleft = {'right': 0, 'left': -0.4, 'center': -0.2}
# swarm_bar_xlocs_adjustright = {'right': -0.1, 'left': -0.1, 'center': -0.1}

if isinstance(plot_data[xvar].dtype, pd.CategoricalDtype):
swarm_bars_order = pd.unique(plot_data[xvar]).categories
else:
swarm_bars_order = pd.unique(plot_data[xvar])

swarm_means = plot_data.groupby(xvar)[yvar].mean().reindex(index=swarm_bars_order)
swarm_bars_colors = [swarm_bars_kwargs.get('color')]*(len(swarm_bars_order)+1) if swarm_bars_kwargs.get('color') is not None else ['black']*(len(swarm_bars_order)+1) if color_col is not None or is_paired else swarm_colors
swarm_bars_kwargs.pop('color')
for swarm_bars_x,swarm_bars_y,c in zip(np.arange(0,len(swarm_bars_order)+1,1), swarm_means, swarm_bars_colors):
ax.add_patch(mpatches.Rectangle((swarm_bars_x-0.25,0),
0.5, swarm_bars_y, zorder=-1,color=c,**swarm_bars_kwargs))


# %% ../nbs/API/plot_tools.ipynb 6
def swarmplot(
data: pd.DataFrame,
Expand Down
92 changes: 20 additions & 72 deletions dabest/plotter.py
Expand Up @@ -65,6 +65,9 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
error_bar,
sankeydiag,
swarmplot,
swarm_bars_plotter,
contrast_bars_plotter,
summary_bars_plotter,
)
from ._stats_tools.effsize import (
_compute_standardizers,
Expand Down Expand Up @@ -1594,100 +1597,45 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):

####################################################### END GRIDKEY MAIN CODE WIP

################################################### Swarm & Contrast Bars WIP
# Swarm Bars WIP
################################################### Swarm & Contrast & Summary Bars WIP

# Swarm bars WIP
swarm_bars = plot_kwargs["swarm_bars"]
default_swarm_bars_kwargs = {"color": None, "alpha": 0.1}
if plot_kwargs["swarm_bars_kwargs"] is None:
swarm_bars_kwargs = default_swarm_bars_kwargs
else:
swarm_bars_kwargs = merge_two_dicts(default_swarm_bars_kwargs, plot_kwargs["swarm_bars_kwargs"])

if swarm_bars and not proportional:
# if is_paired:
# swarm_bar_xlocs_adjustleft = {'right': -0.2, 'left': -0.2, 'center': -0.2}
# swarm_bar_xlocs_adjustright = {'right': -0.1, 'left': -0.1, 'center': -0.1}
# else:
# swarm_bar_xlocs_adjustleft = {'right': 0, 'left': -0.4, 'center': -0.2}
# swarm_bar_xlocs_adjustright = {'right': -0.1, 'left': -0.1, 'center': -0.1}

if isinstance(plot_data[xvar].dtype, pd.CategoricalDtype):
swarm_bars_order = pd.unique(plot_data[xvar]).categories
else:
swarm_bars_order = pd.unique(plot_data[xvar])

swarm_means = plot_data.groupby(xvar)[yvar].mean().reindex(index=swarm_bars_order)
swarm_bars_colors = [swarm_bars_kwargs.get('color')]*(len(swarm_bars_order)+1) if swarm_bars_kwargs.get('color') is not None else ['black']*(len(swarm_bars_order)+1) if color_col is not None or is_paired else swarm_colors
swarm_bars_kwargs.pop('color')
for swarm_bars_x,swarm_bars_y,c in zip(np.arange(0,len(swarm_bars_order)+1,1), swarm_means, swarm_bars_colors):
rawdata_axes.add_patch(mpatches.Rectangle((swarm_bars_x-0.25,0),
0.5, swarm_bars_y, zorder=-1,color=c,**swarm_bars_kwargs))

# Contrast Bars WIP
swarm_bars_plotter(plot_data=plot_data, xvar=xvar, yvar=yvar, ax=rawdata_axes, swarm_bars_kwargs=swarm_bars_kwargs,
color_col=color_col, swarm_colors=swarm_colors, is_paired=is_paired)

# Contrast bars WIP
contrast_bars = plot_kwargs["contrast_bars"]
default_contrast_bars_kwargs = {"color": None, "alpha": 0.15}
if plot_kwargs["contrast_bars_kwargs"] is None:
contrast_bars_kwargs = default_contrast_bars_kwargs
else:
contrast_bars_kwargs = merge_two_dicts(default_contrast_bars_kwargs, plot_kwargs["contrast_bars_kwargs"])
if contrast_bars and not float_contrast:
contrast_means = []
for j, tick in enumerate(ticks_to_plot):
contrast_means.append(results.difference[j])

contrast_bars_colors = [contrast_bars_kwargs.get('color')]*(len(ticks_to_plot)+1) if contrast_bars_kwargs.get('color') is not None else ['black']*(max(ticks_to_plot)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors
contrast_bars_kwargs.pop('color')
for contrast_bars_x,contrast_bars_y in zip(ticks_to_plot, contrast_means):
contrast_axes.add_patch(mpatches.Rectangle((contrast_bars_x-0.25,0),0.5, contrast_bars_y, zorder=-1, color=contrast_bars_colors[contrast_bars_x], **contrast_bars_kwargs))

if show_mini_meta:
contrast_axes.add_patch(mpatches.Rectangle((max(rawdata_axes.get_xticks())+2-0.25,0),0.5, mini_meta_delta.difference, zorder=-1, color='black', **contrast_bars_kwargs))

if show_delta2:
contrast_axes.add_patch(mpatches.Rectangle((max(rawdata_axes.get_xticks())+2-0.25,0),0.5, delta_delta.difference, zorder=-1, color='black', **contrast_bars_kwargs))

################################################### Swarm & Contrast Bars WIP

################################################### Summary Bars WIP
contrast_bars_plotter(results=results, ax_to_plot=contrast_axes, swarm_plot_ax=rawdata_axes,ticks_to_plot=ticks_to_plot,
contrast_bars_kwargs=contrast_bars_kwargs, color_col=color_col, swarm_colors=swarm_colors, show_mini_meta=show_mini_meta,
mini_meta_delta=mini_meta_delta if show_mini_meta else None, show_delta2=show_delta2,
delta_delta=delta_delta if show_delta2 else None, proportional=proportional, is_paired=is_paired)

# Summary bars WIP
summary_bars = plot_kwargs["summary_bars"]
default_summary_bars_kwargs = {"color": None, "alpha": 0.15}
if plot_kwargs["summary_bars_kwargs"] is None:
summary_bars_kwargs = default_summary_bars_kwargs
else:
summary_bars_kwargs = merge_two_dicts(default_summary_bars_kwargs, plot_kwargs["summary_bars_kwargs"])

if summary_bars is not None:
if not isinstance(summary_bars, list):
raise TypeError("summary_bars must be a list of indices (ints).")
if not all(isinstance(i, int) for i in summary_bars):
raise TypeError("summary_bars must be a list of indices (ints).")
if any(i >= len(results) for i in summary_bars):
raise ValueError("Index {} chosen is out of range for the contrast objects.".format([i for i in summary_bars if i >= len(results)]))
if float_contrast:
raise ValueError("summary_bars cannot be used with Gardner-Altman plots.")
else:
print('Summary plots WIP')
summary_xmin, summary_xmax = contrast_axes.get_xlim()
summary_bars_colors = [summary_bars_kwargs.get('color')]*(len(summary_bars)+1) if summary_bars_kwargs.get('color') is not None else ['black']*(max(summary_bars)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors
summary_bars_kwargs.pop('color')
for summary_index in summary_bars:
print('Summary plot for contrast object:', summary_index)
if ci_type == "bca":
summary_ci_low = results.bca_low[summary_index]
summary_ci_high = results.bca_high[summary_index]
else:
summary_ci_low = results.pct_low[summary_index]
summary_ci_high = results.pct_high[summary_index]

summary_color = summary_bars_colors[ticks_to_plot[summary_index]]

contrast_axes.add_patch(mpatches.Rectangle((summary_xmin,summary_ci_low),summary_xmax+1,
summary_ci_high-summary_ci_low, zorder=-2, color=summary_color, **summary_bars_kwargs))


################################################### Summary Bars WIP
summary_bars_plotter(summary_bars=summary_bars, results=results, ax_to_plot=contrast_axes, float_contrast=float_contrast,
summary_bars_kwargs=summary_bars_kwargs, ci_type=ci_type, ticks_to_plot=ticks_to_plot, color_col=color_col,
swarm_colors=swarm_colors, proportional=proportional, is_paired=is_paired)

################################################### Swarm & Contrast & Summary Bars WIP END

# Make sure no stray ticks appear!
rawdata_axes.xaxis.set_ticks_position("bottom")
Expand Down

0 comments on commit d4d62cf

Please sign in to comment.