Skip to content

Commit

Permalink
feat: adjusted test messages, forest_plot function
Browse files Browse the repository at this point in the history
  • Loading branch information
Lucas1213WZY committed Apr 8, 2024
1 parent 6540b40 commit 51b44aa
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 40 deletions.
4 changes: 3 additions & 1 deletion dabest/_modidx.py
Expand Up @@ -64,7 +64,9 @@
'dabest.forest_plot': { 'dabest.forest_plot.extract_plot_data': ( 'API/forest_plot.html#extract_plot_data',
'dabest/forest_plot.py'),
'dabest.forest_plot.forest_plot': ('API/forest_plot.html#forest_plot', 'dabest/forest_plot.py'),
'dabest.forest_plot.load_plot_data': ('API/forest_plot.html#load_plot_data', 'dabest/forest_plot.py')},
'dabest.forest_plot.load_plot_data': ('API/forest_plot.html#load_plot_data', 'dabest/forest_plot.py'),
'dabest.forest_plot.map_effect_attribute': ( 'API/forest_plot.html#map_effect_attribute',
'dabest/forest_plot.py')},
'dabest.misc_tools': { 'dabest.misc_tools.get_varname': ('API/misc_tools.html#get_varname', 'dabest/misc_tools.py'),
'dabest.misc_tools.merge_two_dicts': ('API/misc_tools.html#merge_two_dicts', 'dabest/misc_tools.py'),
'dabest.misc_tools.print_greeting': ('API/misc_tools.html#print_greeting', 'dabest/misc_tools.py'),
Expand Down
43 changes: 34 additions & 9 deletions dabest/forest_plot.py
@@ -1,7 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/forest_plot.ipynb.

# %% auto 0
__all__ = ['load_plot_data', 'extract_plot_data', 'forest_plot']
__all__ = ['load_plot_data', 'extract_plot_data', 'map_effect_attribute', 'forest_plot']

# %% ../nbs/API/forest_plot.ipynb 5
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -72,14 +72,28 @@ def extract_plot_data(contrast_plot_data, contrast_type):

return bootstraps, differences, bcalows, bcahighs

def map_effect_attribute(attribute_key):
# Check if the attribute key exists in the dictionary
effect_attr_map = {
"mean_diff": "Mean Difference",
"median_diff": "Median Difference",
"cliffs_delta": "Cliffs Delta",
"cohens_d": "Cohens d",
"hedges_g": "Hedges g",
"delta_g": "Delta g"
}
if attribute_key in effect_attr_map:
return effect_attr_map[attribute_key]
else:
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.") # Return a default value or message if the key is not found

def forest_plot(
contrasts: List,
selected_indices: Optional[List] = None,
contrast_type: str = "delta2",
effect_size: str = "mean_diff",
contrast_labels: List[str] = None,
ylabel: str = "value",
ylabel: str = "effect size",
plot_elements_to_extract: Optional[List] = None,
title: str = "ΔΔ Forest",
custom_palette: Optional[Union[dict, list, str]] = None,
Expand Down Expand Up @@ -147,7 +161,7 @@ def forest_plot(
plt.Figure
The matplotlib figure object with the generated forest plot.
"""
from dabest.plot_tools import halfviolin
from .plot_tools import halfviolin

# Validate inputs
if contrasts is None:
Expand All @@ -159,11 +173,13 @@ def forest_plot(
if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):
raise TypeError("The `selected_indices` must be a list of integers or `None`.")

# For the 'contrast_type' parameter
if not isinstance(contrast_type, str):
raise TypeError("The `contrast_type` argument must be a string. Please choose from `delta2` and `minimeta`.")

raise TypeError("The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.")

# For the 'effect_size' parameter
if not isinstance(effect_size, str):
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.")
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`.")

if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):
raise TypeError("The `contrast_labels` must be a list of strings or `None`.")
Expand Down Expand Up @@ -201,6 +217,8 @@ def forest_plot(
if not isinstance(horizontal, bool):
raise TypeError("`horizontal` must be a boolean value.")

if (effect_size and isinstance(effect_size, str)):
ylabel = map_effect_attribute(effect_size)
# Load plot data
contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)

Expand Down Expand Up @@ -280,17 +298,24 @@ def forest_plot(
ax.set_yticks(range(1, len(contrasts) + 1))
ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)
ax.set_xlabel(ylabel, fontsize=fontsize)
ax.set_ylim([0.7, len(contrasts) + 0.5])
else:
ax.set_xticks(range(1, len(contrasts) + 1))
ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
ax.set_ylabel(ylabel, fontsize=fontsize)
ax.set_xlim([0.7, len(contrasts) + 0.5])

# Setting the title and adjusting spines as before
ax.set_title(title, fontsize=title_font_size)
if remove_spines:
for spine in ax.spines.values():
spine.set_visible(False)

if horizontal:
ax.spines['left'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
else:
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['right'].set_visible(False)
# Apply additional customizations if provided
if additional_plotting_kwargs:
ax.set(**additional_plotting_kwargs)
Expand Down
34 changes: 12 additions & 22 deletions nbs/API/forest_plot.ipynb
Expand Up @@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -22,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -32,7 +32,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -44,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -54,7 +54,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -67,7 +67,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -154,7 +154,7 @@
" contrast_type: str = \"delta2\",\n",
" effect_size: str = \"mean_diff\",\n",
" contrast_labels: List[str] = None,\n",
" ylabel: str = \"effect_size\",\n",
" ylabel: str = \"effect size\",\n",
" plot_elements_to_extract: Optional[List] = None,\n",
" title: str = \"ΔΔ Forest\",\n",
" custom_palette: Optional[Union[dict, list, str]] = None,\n",
Expand Down Expand Up @@ -234,11 +234,13 @@
" if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):\n",
" raise TypeError(\"The `selected_indices` must be a list of integers or `None`.\")\n",
" \n",
" # For the 'contrast_type' parameter\n",
" if not isinstance(contrast_type, str):\n",
" raise TypeError(\"The `contrast_type` argument must be a string. Please choose from `delta2` and `minimeta`.\")\n",
" \n",
" raise TypeError(\"The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.\")\n",
"\n",
" # For the 'effect_size' parameter\n",
" if not isinstance(effect_size, str):\n",
" raise TypeError(\"The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.\")\n",
" raise TypeError(\"The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`.\")\n",
" \n",
" if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):\n",
" raise TypeError(\"The `contrast_labels` must be a list of strings or `None`.\")\n",
Expand Down Expand Up @@ -388,18 +390,6 @@
"display_name": "python3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
Expand Down
8 changes: 3 additions & 5 deletions nbs/tests/data/mocked_data_test_forestplot.py
Expand Up @@ -40,21 +40,19 @@
"contrasts": dummy_contrasts, # Ensure this is a list of contrast objects.
"selected_indices": None, # Valid as None or a list of integers.
"contrast_type": "delta2", # Ensure it's a string and one of the allowed contrast types.
"xticklabels": None, # Valid as None or a list of strings.
"effect_size": "mean_diff", # Ensure it's a string.
"contrast_labels": ["Drug1"], # This should be a list of strings.
"ylabel": "Effect Size", # Ensure it's a string.
"plot_elements_to_extract": None, # No specific checks needed based on your tests.
"title": "ΔΔ Forest Plot", # Ensure it's a string.
#"plot_elements_to_extract": None, # No specific checks needed based on your tests.
#"title": "ΔΔ Forest Plot", # Ensure it's a string.
"custom_palette": None, # Valid as None, a dictionary, list, or string.
"fontsize": 20, # Ensure it's an integer or float.
"violin_kwargs": None, # No specific checks needed based on your tests.
"marker_size": 20, # Ensure it's a positive integer or float.
"ci_line_width": 2.5, # Ensure it's a positive integer or float.
"zero_line_width": 1, # Ensure it's a positive integer or float.
"remove_spines": True, # Ensure it's a boolean.
"additional_plotting_kwargs": None, # No specific checks needed based on your tests.
"rotation_for_xlabels": 45, # Ensure it's an integer or float between 0 and 360.
"alpha_violin_plot": 0.4, # Ensure it's a float between 0 and 1.
"alpha_violin_plot": 0.8, # Ensure it's a float between 0 and 1.
"horizontal": False, # Ensure it's a boolean.
}
6 changes: 3 additions & 3 deletions nbs/tests/test_forest_plot.py
Expand Up @@ -16,8 +16,8 @@ def test_forest_plot_no_input_parameters():
("contrasts", None, "The `contrasts` parameter cannot be None", ValueError),
("contrasts", [], "The `contrasts` argument must be a non-empty list.", ValueError),
("selected_indices", "not a list or None", "The `selected_indices` must be a list of integers or `None`.", TypeError),
("contrast_type", 123, "The `contrast_type` argument must be a string. Please choose from `delta2` and `minimeta`.", TypeError),
("effect_size", 456, "The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.", TypeError),
("contrast_type", 123, "The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.", TypeError),
("effect_size", 456, "The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`.", TypeError),
("contrast_labels", ["valid", 123], "The `contrast_labels` must be a list of strings or `None`.", TypeError),
("ylabel", 789, "The `ylabel` argument must be a string.", TypeError),
("custom_palette", 123, "The `custom_palette` must be either a dictionary, list, string, or `None`.", TypeError),
Expand All @@ -28,8 +28,8 @@ def test_forest_plot_no_input_parameters():
("rotation_for_xlabels", "right", "`rotation_for_xlabels` must be an integer or float between 0 and 360.", TypeError),
("alpha_violin_plot", "opaque", "`alpha_violin_plot` must be a float between 0 and 1.", TypeError),
("horizontal", "sideways", "`horizontal` must be a boolean value.", TypeError),
("contrast_type", "unknown", "Invalid contrast_type: unknown. Available options: [`delta2`, `mini_meta`]", ValueError),
])

def test_forest_plot_input_error_handling(param_name, param_value, error_msg, error_type):
# Setup: Define a base set of valid inputs to forest_plot
valid_inputs = default_forestplot_kwargs.copy()
Expand Down

0 comments on commit 51b44aa

Please sign in to comment.