Skip to content

Commit

Permalink
Add: New test images and api changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Lucas1213WZY committed Apr 10, 2024
1 parent ef46eda commit 99eeb08
Show file tree
Hide file tree
Showing 104 changed files with 112 additions and 68 deletions.
4 changes: 3 additions & 1 deletion dabest/_modidx.py
Expand Up @@ -64,7 +64,9 @@
'dabest.forest_plot': { 'dabest.forest_plot.extract_plot_data': ( 'API/forest_plot.html#extract_plot_data',
'dabest/forest_plot.py'),
'dabest.forest_plot.forest_plot': ('API/forest_plot.html#forest_plot', 'dabest/forest_plot.py'),
'dabest.forest_plot.load_plot_data': ('API/forest_plot.html#load_plot_data', 'dabest/forest_plot.py')},
'dabest.forest_plot.load_plot_data': ('API/forest_plot.html#load_plot_data', 'dabest/forest_plot.py'),
'dabest.forest_plot.map_effect_attribute': ( 'API/forest_plot.html#map_effect_attribute',
'dabest/forest_plot.py')},
'dabest.misc_tools': { 'dabest.misc_tools.get_varname': ('API/misc_tools.html#get_varname', 'dabest/misc_tools.py'),
'dabest.misc_tools.merge_two_dicts': ('API/misc_tools.html#merge_two_dicts', 'dabest/misc_tools.py'),
'dabest.misc_tools.print_greeting': ('API/misc_tools.html#print_greeting', 'dabest/misc_tools.py'),
Expand Down
81 changes: 52 additions & 29 deletions dabest/forest_plot.py
@@ -1,7 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/forest_plot.ipynb.

# %% auto 0
__all__ = ['load_plot_data', 'extract_plot_data', 'forest_plot']
__all__ = ['load_plot_data', 'extract_plot_data', 'map_effect_attribute', 'forest_plot']

# %% ../nbs/API/forest_plot.ipynb 5
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -72,28 +72,42 @@ def extract_plot_data(contrast_plot_data, contrast_type):

return bootstraps, differences, bcalows, bcahighs

def map_effect_attribute(attribute_key):
# Check if the attribute key exists in the dictionary
effect_attr_map = {
"mean_diff": "Mean Difference",
"median_diff": "Median Difference",
"cliffs_delta": "Cliffs Delta",
"cohens_d": "Cohens d",
"hedges_g": "Hedges g",
"delta_g": "Delta g"
}
if attribute_key in effect_attr_map:
return effect_attr_map[attribute_key]
else:
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.") # Return a default value or message if the key is not found

def forest_plot(
contrasts: List,
selected_indices: Optional[List] = None,
contrast_type: str = "delta2",
xticklabels: Optional[List] = None,
effect_size: str = "mean_diff",
contrast_labels: List[str] = None,
ylabel: str = "value",
ylabel: str = "effect size",
plot_elements_to_extract: Optional[List] = None,
title: str = "ΔΔ Forest",
custom_palette: Optional[Union[dict, list, str]] = None,
fontsize: int = 20,
fontsize: int = 12,
title_font_size: int =16,
violin_kwargs: Optional[dict] = None,
marker_size: int = 20,
ci_line_width: float = 2.5,
zero_line_width: int = 1,
desat_violin: float = 1,
remove_spines: bool = True,
ax: Optional[plt.Axes] = None,
additional_plotting_kwargs: Optional[dict] = None,
rotation_for_xlabels: int = 45,
alpha_violin_plot: float = 0.4,
alpha_violin_plot: float = 0.8,
horizontal: bool = False # New argument for horizontal orientation
)-> plt.Figure:
"""
Expand All @@ -106,11 +120,9 @@ def forest_plot(
selected_indices : Optional[List], default=None
Indices of specific contrasts to plot, if not plotting all.
analysis_type : str
the type of analysis (e.g., 'delta2', 'minimeta').
xticklabels : Optional[List], default=None
Custom labels for the x-axis ticks.
the type of analysis (e.g., 'delta2', 'mini_meta').
effect_size : str
Type of effect size to plot (e.g., 'mean_diff', 'median_diff').
Type of effect size to plot (e.g., 'mean_diff', 'median_diff', `cliffs_delta`,`cohens_d``, and `hedges_g`).
contrast_labels : List[str]
Labels for each contrast.
ylabel : str
Expand All @@ -125,14 +137,14 @@ def forest_plot(
Custom color palette for the plot.
fontsize : int
Font size for text elements in the plot.
title_font_size: int =16
Font size for text of plot title.
violin_kwargs : Optional[dict], default=None
Additional arguments for violin plot customization.
marker_size : int
Marker size for plotting mean differences or effect sizes.
ci_line_width : float
Width of confidence interval lines.
zero_line_width : int
Width of the line indicating zero effect size.
remove_spines : bool, default=False
If True, removes top and right plot spines.
ax : Optional[plt.Axes], default=None
Expand Down Expand Up @@ -161,14 +173,13 @@ def forest_plot(
if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):
raise TypeError("The `selected_indices` must be a list of integers or `None`.")

# For the 'contrast_type' parameter
if not isinstance(contrast_type, str):
raise TypeError("The `contrast_type` argument must be a string.")

if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):
raise TypeError("The `xticklabels` must be a list of strings or `None`.")

raise TypeError("The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.")

# For the 'effect_size' parameter
if not isinstance(effect_size, str):
raise TypeError("The `effect_size` argument must be a string.")
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`.")

if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):
raise TypeError("The `contrast_labels` must be a list of strings or `None`.")
Expand All @@ -191,9 +202,6 @@ def forest_plot(
if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:
raise TypeError("`ci_line_width` must be a positive integer or float.")

if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:
raise TypeError("`zero_line_width` must be a positive integer or float.")

if not isinstance(remove_spines, bool):
raise TypeError("`remove_spines` must be a boolean value.")

Expand All @@ -209,6 +217,8 @@ def forest_plot(
if not isinstance(horizontal, bool):
raise TypeError("`horizontal` must be a boolean value.")

if (effect_size and isinstance(effect_size, str)):
ylabel = map_effect_attribute(effect_size)
# Load plot data
contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)

Expand Down Expand Up @@ -250,7 +260,7 @@ def forest_plot(
if custom_palette:
if isinstance(custom_palette, dict):
violin_colors = [
custom_palette.get(c, sns.color_palette()[0]) for c in contrasts
custom_palette.get(c, sns.color_palette()[0]) for c in contrast_labels
]
elif isinstance(custom_palette, list):
violin_colors = custom_palette[: len(contrasts)]
Expand All @@ -262,12 +272,18 @@ def forest_plot(
f"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette."
)
else:
violin_colors = sns.color_palette()[: len(contrasts)]
violin_colors = sns.color_palette(n_colors=len(contrasts))

violin_colors = [sns.desaturate(color, desat_violin) for color in violin_colors]

for patch, color in zip(v["bodies"], violin_colors):
patch.set_facecolor(color)
patch.set_alpha(alpha_violin_plot)

if horizontal:
ax.plot([0, 0], [0, len(contrasts)+1], 'k', linewidth = 1)
else:
ax.plot([0, len(contrasts)+1], [0, 0], 'k', linewidth = 1)

# Flipping the axes for plotting based on 'horizontal'
for k in range(1, len(contrasts) + 1):
if horizontal:
Expand All @@ -280,19 +296,26 @@ def forest_plot(
# Adjusting labels, ticks, and limits based on 'horizontal'
if horizontal:
ax.set_yticks(range(1, len(contrasts) + 1))
ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)
ax.set_xlabel(ylabel, fontsize=fontsize)
ax.set_ylim([0.7, len(contrasts) + 0.5])
else:
ax.set_xticks(range(1, len(contrasts) + 1))
ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
ax.set_ylabel(ylabel, fontsize=fontsize)
ax.set_xlim([0.7, len(contrasts) + 0.5])

# Setting the title and adjusting spines as before
ax.set_title(title, fontsize=fontsize)
ax.set_title(title, fontsize=title_font_size)
if remove_spines:
for spine in ax.spines.values():
spine.set_visible(False)

if horizontal:
ax.spines['left'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
else:
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['right'].set_visible(False)
# Apply additional customizations if provided
if additional_plotting_kwargs:
ax.set(**additional_plotting_kwargs)
Expand Down
79 changes: 51 additions & 28 deletions nbs/API/forest_plot.ipynb
Expand Up @@ -133,28 +133,42 @@
" \n",
" return bootstraps, differences, bcalows, bcahighs\n",
"\n",
"def map_effect_attribute(attribute_key):\n",
" # Check if the attribute key exists in the dictionary\n",
" effect_attr_map = {\n",
" \"mean_diff\": \"Mean Difference\",\n",
" \"median_diff\": \"Median Difference\",\n",
" \"cliffs_delta\": \"Cliffs Delta\",\n",
" \"cohens_d\": \"Cohens d\",\n",
" \"hedges_g\": \"Hedges g\",\n",
" \"delta_g\": \"Delta g\"\n",
" }\n",
" if attribute_key in effect_attr_map:\n",
" return effect_attr_map[attribute_key]\n",
" else:\n",
" raise TypeError(\"The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.\") # Return a default value or message if the key is not found\n",
"\n",
"def forest_plot(\n",
" contrasts: List,\n",
" selected_indices: Optional[List] = None,\n",
" contrast_type: str = \"delta2\",\n",
" xticklabels: Optional[List] = None,\n",
" effect_size: str = \"mean_diff\",\n",
" contrast_labels: List[str] = None,\n",
" ylabel: str = \"value\",\n",
" ylabel: str = \"effect size\",\n",
" plot_elements_to_extract: Optional[List] = None,\n",
" title: str = \"ΔΔ Forest\",\n",
" custom_palette: Optional[Union[dict, list, str]] = None,\n",
" fontsize: int = 20,\n",
" fontsize: int = 12,\n",
" title_font_size: int =16,\n",
" violin_kwargs: Optional[dict] = None,\n",
" marker_size: int = 20,\n",
" ci_line_width: float = 2.5,\n",
" zero_line_width: int = 1,\n",
" desat_violin: float = 1,\n",
" remove_spines: bool = True,\n",
" ax: Optional[plt.Axes] = None,\n",
" additional_plotting_kwargs: Optional[dict] = None,\n",
" rotation_for_xlabels: int = 45,\n",
" alpha_violin_plot: float = 0.4,\n",
" alpha_violin_plot: float = 0.8,\n",
" horizontal: bool = False # New argument for horizontal orientation\n",
")-> plt.Figure:\n",
" \"\"\" \n",
Expand All @@ -167,11 +181,9 @@
" selected_indices : Optional[List], default=None\n",
" Indices of specific contrasts to plot, if not plotting all.\n",
" analysis_type : str\n",
" the type of analysis (e.g., 'delta2', 'minimeta').\n",
" xticklabels : Optional[List], default=None\n",
" Custom labels for the x-axis ticks.\n",
" the type of analysis (e.g., 'delta2', 'mini_meta').\n",
" effect_size : str\n",
" Type of effect size to plot (e.g., 'mean_diff', 'median_diff').\n",
" Type of effect size to plot (e.g., 'mean_diff', 'median_diff', `cliffs_delta`,`cohens_d``, and `hedges_g`).\n",
" contrast_labels : List[str]\n",
" Labels for each contrast.\n",
" ylabel : str\n",
Expand All @@ -186,14 +198,14 @@
" Custom color palette for the plot.\n",
" fontsize : int\n",
" Font size for text elements in the plot.\n",
" title_font_size: int =16\n",
" Font size for text of plot title.\n",
" violin_kwargs : Optional[dict], default=None\n",
" Additional arguments for violin plot customization.\n",
" marker_size : int\n",
" Marker size for plotting mean differences or effect sizes.\n",
" ci_line_width : float\n",
" Width of confidence interval lines.\n",
" zero_line_width : int\n",
" Width of the line indicating zero effect size.\n",
" remove_spines : bool, default=False\n",
" If True, removes top and right plot spines.\n",
" ax : Optional[plt.Axes], default=None\n",
Expand Down Expand Up @@ -222,14 +234,13 @@
" if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):\n",
" raise TypeError(\"The `selected_indices` must be a list of integers or `None`.\")\n",
" \n",
" # For the 'contrast_type' parameter\n",
" if not isinstance(contrast_type, str):\n",
" raise TypeError(\"The `contrast_type` argument must be a string.\")\n",
" \n",
" if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):\n",
" raise TypeError(\"The `xticklabels` must be a list of strings or `None`.\")\n",
" \n",
" raise TypeError(\"The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.\")\n",
"\n",
" # For the 'effect_size' parameter\n",
" if not isinstance(effect_size, str):\n",
" raise TypeError(\"The `effect_size` argument must be a string.\")\n",
" raise TypeError(\"The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`.\")\n",
" \n",
" if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):\n",
" raise TypeError(\"The `contrast_labels` must be a list of strings or `None`.\")\n",
Expand All @@ -252,9 +263,6 @@
" if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:\n",
" raise TypeError(\"`ci_line_width` must be a positive integer or float.\")\n",
" \n",
" if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:\n",
" raise TypeError(\"`zero_line_width` must be a positive integer or float.\")\n",
" \n",
" if not isinstance(remove_spines, bool):\n",
" raise TypeError(\"`remove_spines` must be a boolean value.\")\n",
" \n",
Expand All @@ -270,6 +278,8 @@
" if not isinstance(horizontal, bool):\n",
" raise TypeError(\"`horizontal` must be a boolean value.\")\n",
"\n",
" if (effect_size and isinstance(effect_size, str)):\n",
" ylabel = map_effect_attribute(effect_size)\n",
" # Load plot data\n",
" contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)\n",
"\n",
Expand Down Expand Up @@ -311,7 +321,7 @@
" if custom_palette:\n",
" if isinstance(custom_palette, dict):\n",
" violin_colors = [\n",
" custom_palette.get(c, sns.color_palette()[0]) for c in contrasts\n",
" custom_palette.get(c, sns.color_palette()[0]) for c in contrast_labels\n",
" ]\n",
" elif isinstance(custom_palette, list):\n",
" violin_colors = custom_palette[: len(contrasts)]\n",
Expand All @@ -323,12 +333,18 @@
" f\"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette.\"\n",
" )\n",
" else:\n",
" violin_colors = sns.color_palette()[: len(contrasts)]\n",
" violin_colors = sns.color_palette(n_colors=len(contrasts))\n",
"\n",
" violin_colors = [sns.desaturate(color, desat_violin) for color in violin_colors]\n",
" \n",
" for patch, color in zip(v[\"bodies\"], violin_colors):\n",
" patch.set_facecolor(color)\n",
" patch.set_alpha(alpha_violin_plot)\n",
"\n",
" if horizontal:\n",
" ax.plot([0, 0], [0, len(contrasts)+1], 'k', linewidth = 1)\n",
" else:\n",
" ax.plot([0, len(contrasts)+1], [0, 0], 'k', linewidth = 1)\n",
" \n",
" # Flipping the axes for plotting based on 'horizontal'\n",
" for k in range(1, len(contrasts) + 1):\n",
" if horizontal:\n",
Expand All @@ -341,19 +357,26 @@
" # Adjusting labels, ticks, and limits based on 'horizontal'\n",
" if horizontal:\n",
" ax.set_yticks(range(1, len(contrasts) + 1))\n",
" ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)\n",
" ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)\n",
" ax.set_xlabel(ylabel, fontsize=fontsize)\n",
" ax.set_ylim([0.7, len(contrasts) + 0.5])\n",
" else:\n",
" ax.set_xticks(range(1, len(contrasts) + 1))\n",
" ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)\n",
" ax.set_ylabel(ylabel, fontsize=fontsize)\n",
" ax.set_xlim([0.7, len(contrasts) + 0.5])\n",
"\n",
" # Setting the title and adjusting spines as before\n",
" ax.set_title(title, fontsize=fontsize)\n",
" ax.set_title(title, fontsize=title_font_size)\n",
" if remove_spines:\n",
" for spine in ax.spines.values():\n",
" spine.set_visible(False)\n",
"\n",
" if horizontal:\n",
" ax.spines['left'].set_visible(False)\n",
" ax.spines['right'].set_visible(False)\n",
" ax.spines['top'].set_visible(False)\n",
" else:\n",
" ax.spines['top'].set_visible(False)\n",
" ax.spines['bottom'].set_visible(False)\n",
" ax.spines['right'].set_visible(False)\n",
" # Apply additional customizations if provided\n",
" if additional_plotting_kwargs:\n",
" ax.set(**additional_plotting_kwargs)\n",
Expand Down

0 comments on commit 99eeb08

Please sign in to comment.