Skip to content

Commit

Permalink
Revert "revert: back to Johnathan's changes - feat: changing the y la…
Browse files Browse the repository at this point in the history
…bels into effect sizes passed to the function and implemented more compact layout"

This reverts commit 6540b40.
  • Loading branch information
Lucas1213WZY committed Apr 8, 2024
1 parent dc243a5 commit ef46eda
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 90 deletions.
46 changes: 24 additions & 22 deletions dabest/forest_plot.py
Expand Up @@ -77,23 +77,23 @@ def forest_plot(
contrasts: List,
selected_indices: Optional[List] = None,
contrast_type: str = "delta2",
xticklabels: Optional[List] = None,
effect_size: str = "mean_diff",
contrast_labels: List[str] = None,
ylabel: str = "value",
plot_elements_to_extract: Optional[List] = None,
title: str = "ΔΔ Forest",
custom_palette: Optional[Union[dict, list, str]] = None,
fontsize: int = 12,
title_font_size: int =16,
fontsize: int = 20,
violin_kwargs: Optional[dict] = None,
marker_size: int = 20,
ci_line_width: float = 2.5,
desat_violin: float = 1,
zero_line_width: int = 1,
remove_spines: bool = True,
ax: Optional[plt.Axes] = None,
additional_plotting_kwargs: Optional[dict] = None,
rotation_for_xlabels: int = 45,
alpha_violin_plot: float = 0.8,
alpha_violin_plot: float = 0.4,
horizontal: bool = False # New argument for horizontal orientation
)-> plt.Figure:
"""
Expand All @@ -106,9 +106,11 @@ def forest_plot(
selected_indices : Optional[List], default=None
Indices of specific contrasts to plot, if not plotting all.
analysis_type : str
the type of analysis (e.g., 'delta2', 'mini_meta').
the type of analysis (e.g., 'delta2', 'minimeta').
xticklabels : Optional[List], default=None
Custom labels for the x-axis ticks.
effect_size : str
Type of effect size to plot (e.g., 'mean_diff', 'median_diff', `cliffs_delta`,`cohens_d``, and `hedges_g`).
Type of effect size to plot (e.g., 'mean_diff', 'median_diff').
contrast_labels : List[str]
Labels for each contrast.
ylabel : str
Expand All @@ -123,14 +125,14 @@ def forest_plot(
Custom color palette for the plot.
fontsize : int
Font size for text elements in the plot.
title_font_size: int =16
Font size for text of plot title.
violin_kwargs : Optional[dict], default=None
Additional arguments for violin plot customization.
marker_size : int
Marker size for plotting mean differences or effect sizes.
ci_line_width : float
Width of confidence interval lines.
zero_line_width : int
Width of the line indicating zero effect size.
remove_spines : bool, default=False
If True, removes top and right plot spines.
ax : Optional[plt.Axes], default=None
Expand All @@ -147,7 +149,7 @@ def forest_plot(
plt.Figure
The matplotlib figure object with the generated forest plot.
"""
from dabest.plot_tools import halfviolin
from .plot_tools import halfviolin

# Validate inputs
if contrasts is None:
Expand All @@ -160,10 +162,13 @@ def forest_plot(
raise TypeError("The `selected_indices` must be a list of integers or `None`.")

if not isinstance(contrast_type, str):
raise TypeError("The `contrast_type` argument must be a string. Please choose from `delta2` and `minimeta`.")
raise TypeError("The `contrast_type` argument must be a string.")

if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):
raise TypeError("The `xticklabels` must be a list of strings or `None`.")

if not isinstance(effect_size, str):
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.")
raise TypeError("The `effect_size` argument must be a string.")

if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):
raise TypeError("The `contrast_labels` must be a list of strings or `None`.")
Expand All @@ -186,6 +191,9 @@ def forest_plot(
if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:
raise TypeError("`ci_line_width` must be a positive integer or float.")

if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:
raise TypeError("`zero_line_width` must be a positive integer or float.")

if not isinstance(remove_spines, bool):
raise TypeError("`remove_spines` must be a boolean value.")

Expand Down Expand Up @@ -242,7 +250,7 @@ def forest_plot(
if custom_palette:
if isinstance(custom_palette, dict):
violin_colors = [
custom_palette.get(c, sns.color_palette()[0]) for c in contrast_labels
custom_palette.get(c, sns.color_palette()[0]) for c in contrasts
]
elif isinstance(custom_palette, list):
violin_colors = custom_palette[: len(contrasts)]
Expand All @@ -254,18 +262,12 @@ def forest_plot(
f"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette."
)
else:
violin_colors = sns.color_palette(n_colors=len(contrasts))
violin_colors = sns.color_palette()[: len(contrasts)]

violin_colors = [sns.desaturate(color, desat_violin) for color in violin_colors]

for patch, color in zip(v["bodies"], violin_colors):
patch.set_facecolor(color)
patch.set_alpha(alpha_violin_plot)
if horizontal:
ax.plot([0, 0], [0, len(contrasts)+1], 'k', linewidth = 1)
else:
ax.plot([0, len(contrasts)+1], [0, 0], 'k', linewidth = 1)


# Flipping the axes for plotting based on 'horizontal'
for k in range(1, len(contrasts) + 1):
if horizontal:
Expand All @@ -278,15 +280,15 @@ def forest_plot(
# Adjusting labels, ticks, and limits based on 'horizontal'
if horizontal:
ax.set_yticks(range(1, len(contrasts) + 1))
ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)
ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
ax.set_xlabel(ylabel, fontsize=fontsize)
else:
ax.set_xticks(range(1, len(contrasts) + 1))
ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
ax.set_ylabel(ylabel, fontsize=fontsize)

# Setting the title and adjusting spines as before
ax.set_title(title, fontsize=title_font_size)
ax.set_title(title, fontsize=fontsize)
if remove_spines:
for spine in ax.spines.values():
spine.set_visible(False)
Expand Down
99 changes: 33 additions & 66 deletions nbs/API/forest_plot.ipynb
Expand Up @@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -22,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -32,7 +32,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -44,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -54,7 +54,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -67,7 +67,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -133,42 +133,28 @@
" \n",
" return bootstraps, differences, bcalows, bcahighs\n",
"\n",
"def map_effect_attribute(attribute_key):\n",
" # Check if the attribute key exists in the dictionary\n",
" effect_attr_map = {\n",
" \"mean_diff\": \"Mean Difference\",\n",
" \"median_diff\": \"Median Difference\",\n",
" \"cliffs_delta\": \"Cliffs Delta\",\n",
" \"cohens_d\": \"Cohens d\",\n",
" \"hedges_g\": \"Hedges g\",\n",
" \"delta_g\": \"Delta g\"\n",
" }\n",
" if attribute_key in effect_attr_map:\n",
" return effect_attr_map[attribute_key]\n",
" else:\n",
" raise TypeError(\"The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.\") # Return a default value or message if the key is not found\n",
"\n",
"def forest_plot(\n",
" contrasts: List,\n",
" selected_indices: Optional[List] = None,\n",
" contrast_type: str = \"delta2\",\n",
" xticklabels: Optional[List] = None,\n",
" effect_size: str = \"mean_diff\",\n",
" contrast_labels: List[str] = None,\n",
" ylabel: str = \"effect_size\",\n",
" ylabel: str = \"value\",\n",
" plot_elements_to_extract: Optional[List] = None,\n",
" title: str = \"ΔΔ Forest\",\n",
" custom_palette: Optional[Union[dict, list, str]] = None,\n",
" fontsize: int = 12,\n",
" title_font_size: int =16,\n",
" fontsize: int = 20,\n",
" violin_kwargs: Optional[dict] = None,\n",
" marker_size: int = 20,\n",
" ci_line_width: float = 2.5,\n",
" desat_violin: float = 1,\n",
" zero_line_width: int = 1,\n",
" remove_spines: bool = True,\n",
" ax: Optional[plt.Axes] = None,\n",
" additional_plotting_kwargs: Optional[dict] = None,\n",
" rotation_for_xlabels: int = 45,\n",
" alpha_violin_plot: float = 0.8,\n",
" alpha_violin_plot: float = 0.4,\n",
" horizontal: bool = False # New argument for horizontal orientation\n",
")-> plt.Figure:\n",
" \"\"\" \n",
Expand All @@ -181,9 +167,11 @@
" selected_indices : Optional[List], default=None\n",
" Indices of specific contrasts to plot, if not plotting all.\n",
" analysis_type : str\n",
" the type of analysis (e.g., 'delta2', 'mini_meta').\n",
" the type of analysis (e.g., 'delta2', 'minimeta').\n",
" xticklabels : Optional[List], default=None\n",
" Custom labels for the x-axis ticks.\n",
" effect_size : str\n",
" Type of effect size to plot (e.g., 'mean_diff', 'median_diff', `cliffs_delta`,`cohens_d``, and `hedges_g`).\n",
" Type of effect size to plot (e.g., 'mean_diff', 'median_diff').\n",
" contrast_labels : List[str]\n",
" Labels for each contrast.\n",
" ylabel : str\n",
Expand All @@ -198,14 +186,14 @@
" Custom color palette for the plot.\n",
" fontsize : int\n",
" Font size for text elements in the plot.\n",
" title_font_size: int =16\n",
" Font size for text of plot title.\n",
" violin_kwargs : Optional[dict], default=None\n",
" Additional arguments for violin plot customization.\n",
" marker_size : int\n",
" Marker size for plotting mean differences or effect sizes.\n",
" ci_line_width : float\n",
" Width of confidence interval lines.\n",
" zero_line_width : int\n",
" Width of the line indicating zero effect size.\n",
" remove_spines : bool, default=False\n",
" If True, removes top and right plot spines.\n",
" ax : Optional[plt.Axes], default=None\n",
Expand Down Expand Up @@ -235,10 +223,13 @@
" raise TypeError(\"The `selected_indices` must be a list of integers or `None`.\")\n",
" \n",
" if not isinstance(contrast_type, str):\n",
" raise TypeError(\"The `contrast_type` argument must be a string. Please choose from `delta2` and `minimeta`.\")\n",
" raise TypeError(\"The `contrast_type` argument must be a string.\")\n",
" \n",
" if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):\n",
" raise TypeError(\"The `xticklabels` must be a list of strings or `None`.\")\n",
" \n",
" if not isinstance(effect_size, str):\n",
" raise TypeError(\"The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.\")\n",
" raise TypeError(\"The `effect_size` argument must be a string.\")\n",
" \n",
" if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):\n",
" raise TypeError(\"The `contrast_labels` must be a list of strings or `None`.\")\n",
Expand All @@ -261,6 +252,9 @@
" if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:\n",
" raise TypeError(\"`ci_line_width` must be a positive integer or float.\")\n",
" \n",
" if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:\n",
" raise TypeError(\"`zero_line_width` must be a positive integer or float.\")\n",
" \n",
" if not isinstance(remove_spines, bool):\n",
" raise TypeError(\"`remove_spines` must be a boolean value.\")\n",
" \n",
Expand All @@ -276,8 +270,6 @@
" if not isinstance(horizontal, bool):\n",
" raise TypeError(\"`horizontal` must be a boolean value.\")\n",
"\n",
" if (effect_size and isinstance(effect_size, str)):\n",
" ylabel = map_effect_attribute(effect_size)\n",
" # Load plot data\n",
" contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)\n",
"\n",
Expand Down Expand Up @@ -319,7 +311,7 @@
" if custom_palette:\n",
" if isinstance(custom_palette, dict):\n",
" violin_colors = [\n",
" custom_palette.get(c, sns.color_palette()[0]) for c in contrast_labels\n",
" custom_palette.get(c, sns.color_palette()[0]) for c in contrasts\n",
" ]\n",
" elif isinstance(custom_palette, list):\n",
" violin_colors = custom_palette[: len(contrasts)]\n",
Expand All @@ -331,18 +323,12 @@
" f\"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette.\"\n",
" )\n",
" else:\n",
" violin_colors = sns.color_palette(n_colors=len(contrasts))\n",
" violin_colors = sns.color_palette()[: len(contrasts)]\n",
"\n",
" violin_colors = [sns.desaturate(color, desat_violin) for color in violin_colors]\n",
" \n",
" for patch, color in zip(v[\"bodies\"], violin_colors):\n",
" patch.set_facecolor(color)\n",
" patch.set_alpha(alpha_violin_plot)\n",
" if horizontal:\n",
" ax.plot([0, 0], [0, len(contrasts)+1], 'k', linewidth = 1)\n",
" else:\n",
" ax.plot([0, len(contrasts)+1], [0, 0], 'k', linewidth = 1)\n",
" \n",
"\n",
" # Flipping the axes for plotting based on 'horizontal'\n",
" for k in range(1, len(contrasts) + 1):\n",
" if horizontal:\n",
Expand All @@ -355,26 +341,19 @@
" # Adjusting labels, ticks, and limits based on 'horizontal'\n",
" if horizontal:\n",
" ax.set_yticks(range(1, len(contrasts) + 1))\n",
" ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)\n",
" ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)\n",
" ax.set_xlabel(ylabel, fontsize=fontsize)\n",
" ax.set_ylim([0.7, len(contrasts) + 0.5])\n",
" else:\n",
" ax.set_xticks(range(1, len(contrasts) + 1))\n",
" ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)\n",
" ax.set_ylabel(ylabel, fontsize=fontsize)\n",
" ax.set_xlim([0.7, len(contrasts) + 0.5])\n",
"\n",
" # Setting the title and adjusting spines as before\n",
" ax.set_title(title, fontsize=title_font_size)\n",
" ax.set_title(title, fontsize=fontsize)\n",
" if remove_spines:\n",
" if horizontal:\n",
" ax.spines['left'].set_visible(False)\n",
" ax.spines['right'].set_visible(False)\n",
" ax.spines['top'].set_visible(False)\n",
" else:\n",
" ax.spines['top'].set_visible(False)\n",
" ax.spines['bottom'].set_visible(False)\n",
" ax.spines['right'].set_visible(False)\n",
" for spine in ax.spines.values():\n",
" spine.set_visible(False)\n",
"\n",
" # Apply additional customizations if provided\n",
" if additional_plotting_kwargs:\n",
" ax.set(**additional_plotting_kwargs)\n",
Expand All @@ -388,18 +367,6 @@
"display_name": "python3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
Expand Down

0 comments on commit ef46eda

Please sign in to comment.