-
Notifications
You must be signed in to change notification settings - Fork 45
/
forest_plot.py
323 lines (274 loc) · 12.7 KB
/
forest_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/forest_plot.ipynb.
# %% auto 0
# Public API: the names exported by `from <this module> import *`.
__all__ = ['load_plot_data', 'extract_plot_data', 'map_effect_attribute', 'forest_plot']
# %% ../nbs/API/forest_plot.ipynb 5
import matplotlib.pyplot as plt
# %matplotlib inline
import seaborn as sns
from typing import List, Optional, Union
# %% ../nbs/API/forest_plot.ipynb 6
def load_plot_data(
    contrasts: List, effect_size: str = "mean_diff", contrast_type: str = "delta2"
) -> List:
    """
    Collect the per-contrast result objects that a forest plot draws.

    For each contrast, the attribute named by ``effect_size`` is read. From
    that object, the attribute selected by ``contrast_type`` is returned:
    'delta2' reads ``delta_delta`` and 'mini_meta' reads ``mini_meta_delta``.

    Parameters
    ----------
    contrasts : List
        Contrast objects exposing effect-size attributes.
    effect_size : str
        One of 'mean_diff', 'median_diff', 'cliffs_delta', 'cohens_d',
        'hedges_g' or 'delta_g'.
    contrast_type : str
        Either 'delta2' or 'mini_meta'.

    Returns
    -------
    List
        One result object per contrast, in input order.

    Raises
    ------
    ValueError
        If ``effect_size`` or ``contrast_type`` is not supported.
    """
    valid_effects = frozenset(
        ("mean_diff", "median_diff", "cliffs_delta", "cohens_d", "hedges_g", "delta_g")
    )
    attribute_for_contrast = {"delta2": "delta_delta", "mini_meta": "mini_meta_delta"}

    # Both lookups happen before either check, so an unhashable argument
    # fails the same way regardless of which one is invalid.
    effect_is_valid = effect_size in valid_effects
    result_attr = attribute_for_contrast.get(contrast_type)

    if not effect_is_valid:
        raise ValueError(f"Invalid effect_size: {effect_size}")
    if result_attr is None:
        raise ValueError(f"Invalid contrast_type: {contrast_type}. Available options: [`delta2`, `mini_meta`]")

    plot_data = []
    for contrast in contrasts:
        effect_obj = getattr(contrast, effect_size)
        plot_data.append(getattr(effect_obj, result_attr))
    return plot_data
def extract_plot_data(contrast_plot_data, contrast_type):
    """
    Pull bootstrap samples, point estimates and BCa bounds out of result objects.

    Parameters
    ----------
    contrast_plot_data : list
        Result objects, one per contrast (as returned by ``load_plot_data``).
    contrast_type : str
        'mini_meta' reads ``bootstraps_weighted_delta``; any other value
        reads ``bootstraps_delta_delta``.

    Returns
    -------
    tuple of four lists
        ``(bootstraps, differences, bca_lows, bca_highs)``, each in input order.
    """
    if contrast_type == "mini_meta":
        bootstrap_attr = "bootstraps_weighted_delta"
    else:
        bootstrap_attr = "bootstraps_delta_delta"

    def _column(attr_name):
        # One full pass over the results per requested attribute.
        return [getattr(item, attr_name) for item in contrast_plot_data]

    return (
        _column(bootstrap_attr),
        _column("difference"),
        _column("bca_low"),
        _column("bca_high"),
    )
def map_effect_attribute(attribute_key):
    """
    Map an effect-size key to its human-readable axis label.

    Parameters
    ----------
    attribute_key : str
        One of 'mean_diff', 'median_diff', 'cliffs_delta', 'cohens_d',
        'hedges_g' or 'delta_g'.

    Returns
    -------
    str
        Display label, e.g. 'Mean Difference' for 'mean_diff'.

    Raises
    ------
    TypeError
        If ``attribute_key`` is not a supported key. TypeError (rather than
        ValueError) is kept so existing callers that catch it still work.
    """
    effect_attr_map = {
        "mean_diff": "Mean Difference",
        "median_diff": "Median Difference",
        "cliffs_delta": "Cliffs Delta",
        "cohens_d": "Cohens d",
        "hedges_g": "Hedges g",
        "delta_g": "Delta g",
    }
    try:
        return effect_attr_map[attribute_key]
    except KeyError:
        # Build the option list from the map itself so it can never drift
        # out of sync with the supported keys.
        valid = ", ".join(f"`{key}`" for key in effect_attr_map)
        raise TypeError(
            f"Invalid `effect_size` {attribute_key!r}. "
            f"Please choose from the following effect sizes: {valid}."
        ) from None
def forest_plot(
    contrasts: List,
    selected_indices: Optional[List] = None,
    contrast_type: str = "delta2",
    effect_size: str = "mean_diff",
    contrast_labels: Optional[List[str]] = None,
    ylabel: str = "effect size",
    plot_elements_to_extract: Optional[List] = None,
    title: str = "ΔΔ Forest",
    custom_palette: Optional[Union[dict, list, str]] = None,
    fontsize: int = 12,
    title_font_size: int = 16,
    violin_kwargs: Optional[dict] = None,
    marker_size: int = 20,
    ci_line_width: float = 2.5,
    desat_violin: float = 1,
    remove_spines: bool = True,
    ax: Optional[plt.Axes] = None,
    additional_plotting_kwargs: Optional[dict] = None,
    rotation_for_xlabels: int = 45,
    alpha_violin_plot: float = 0.8,
    horizontal: bool = False,
) -> plt.Figure:
    """
    Draw a forest plot of delta-delta or mini-meta effect sizes from contrast objects.

    Each contrast is drawn as a violin of its bootstrap distribution (trimmed
    by ``halfviolin``), a dot at the point estimate and a line spanning its
    BCa interval (``bca_low`` to ``bca_high``). A black reference line is
    drawn at zero.

    Parameters
    ----------
    contrasts : List
        Non-empty list of contrast objects (e.g. DABEST results). Each must
        expose the ``effect_size`` attribute read by ``load_plot_data``.
    selected_indices : Optional[List], default=None
        Validated to be a list or None, but currently not applied: every
        contrast in ``contrasts`` is plotted.
    contrast_type : str, default='delta2'
        Either 'delta2' or 'mini_meta'.
    effect_size : str, default='mean_diff'
        One of 'mean_diff', 'median_diff', 'cliffs_delta', 'cohens_d',
        'hedges_g' or 'delta_g'.
    contrast_labels : Optional[List[str]], default=None
        One label per contrast. When None, ticks are labelled '1', '2', ...
    ylabel : str, default='effect size'
        Label for the effect-size axis. If left at the default, it is
        replaced by the readable name of ``effect_size`` (e.g. 'Mean Difference').
    plot_elements_to_extract : Optional[List], default=None
        Currently accepted but not used by this function.
    title : str, default='ΔΔ Forest'
        Plot title.
    custom_palette : Optional[Union[dict, list, str]], default=None
        A dict mapping contrast label to color, a list of colors (one per
        contrast), or a Matplotlib colormap name. Falsy values use the
        default seaborn palette.
    fontsize : int, default=12
        Font size for tick labels and the effect-size axis label.
    title_font_size : int, default=16
        Font size for the title.
    violin_kwargs : Optional[dict], default=None
        Keyword arguments forwarded to ``Axes.violinplot``. The dict is
        copied, never mutated, and ``vert`` is always set from ``horizontal``.
        Defaults to ``widths=0.5, showextrema=False, showmedians=False``.
    marker_size : int, default=20
        Marker size for the point estimates; must be positive.
    ci_line_width : float, default=2.5
        Line width of the interval lines; must be positive.
    desat_violin : float, default=1
        Saturation factor passed to ``seaborn.desaturate`` for violin colors.
    remove_spines : bool, default=True
        If True, hide the top, bottom and right spines (vertical layout) or
        the top, left and right spines (horizontal layout).
    ax : Optional[plt.Axes], default=None
        Axes to draw on; a new figure and axes are created when None.
    additional_plotting_kwargs : Optional[dict], default=None
        Passed to ``ax.set(...)`` after everything else is drawn.
    rotation_for_xlabels : int, default=45
        Rotation of x tick labels in the vertical layout; must be in [0, 360].
    alpha_violin_plot : float, default=0.8
        Violin transparency in [0, 1].
    horizontal : bool, default=False
        If True, contrasts run along the y-axis and effect sizes along x.

    Returns
    -------
    plt.Figure
        The figure that contains the plotted axes.

    Raises
    ------
    ValueError
        For None/empty ``contrasts``, a label count that does not match the
        contrasts, an unknown ``contrast_type``, or an unrecognized palette name.
    TypeError
        For wrongly typed arguments, or (via ``map_effect_attribute``) an
        unknown ``effect_size``.
    """
    from .plot_tools import halfviolin

    # --- Validate inputs -------------------------------------------------
    if contrasts is None:
        raise ValueError("The `contrasts` parameter cannot be None")
    if not isinstance(contrasts, list) or not contrasts:
        raise ValueError("The `contrasts` argument must be a non-empty list.")
    if selected_indices is not None and not isinstance(selected_indices, list):
        raise TypeError("The `selected_indices` must be a list of integers or `None`.")
    if not isinstance(contrast_type, str):
        raise TypeError("The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.")
    if not isinstance(effect_size, str):
        raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, `hedges_g`, and `delta_g`.")
    if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):
        raise TypeError("The `contrast_labels` must be a list of strings or `None`.")
    if contrast_labels is not None and len(contrast_labels) != len(contrasts):
        raise ValueError("`contrast_labels` must match the number of `contrasts` if provided.")
    if not isinstance(ylabel, str):
        raise TypeError("The `ylabel` argument must be a string.")
    if custom_palette is not None and not isinstance(custom_palette, (dict, list, str)):
        raise TypeError("The `custom_palette` must be either a dictionary, list, string, or `None`.")
    if not isinstance(fontsize, (int, float)):
        raise TypeError("`fontsize` must be an integer or float.")
    if not isinstance(marker_size, (int, float)) or marker_size <= 0:
        raise TypeError("`marker_size` must be a positive integer or float.")
    if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:
        raise TypeError("`ci_line_width` must be a positive integer or float.")
    if not isinstance(remove_spines, bool):
        raise TypeError("`remove_spines` must be a boolean value.")
    if ax is not None and not isinstance(ax, plt.Axes):
        raise TypeError("`ax` must be a `matplotlib.axes.Axes` instance or `None`.")
    if not isinstance(rotation_for_xlabels, (int, float)) or not 0 <= rotation_for_xlabels <= 360:
        raise TypeError("`rotation_for_xlabels` must be an integer or float between 0 and 360.")
    # Accept ints such as 0 or 1 as well as floats (bool stays rejected).
    if (
        isinstance(alpha_violin_plot, bool)
        or not isinstance(alpha_violin_plot, (int, float))
        or not 0 <= alpha_violin_plot <= 1
    ):
        raise TypeError("`alpha_violin_plot` must be a float between 0 and 1.")
    if not isinstance(horizontal, bool):
        raise TypeError("`horizontal` must be a boolean value.")

    # map_effect_attribute also validates `effect_size` (raising TypeError for
    # unknown keys), so it always runs for a non-empty value. Its readable name
    # only replaces `ylabel` when the caller kept the default, so an explicit
    # label is no longer silently discarded.
    if effect_size:
        effect_label = map_effect_attribute(effect_size)
        if ylabel == "effect size":
            ylabel = effect_label

    # --- Gather data -----------------------------------------------------
    contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)
    bootstraps, differences, bcalows, bcahighs = extract_plot_data(
        contrast_plot_data, contrast_type
    )

    n_contrasts = len(contrasts)
    # set_xticklabels/set_yticklabels reject None, so fall back to 1-based
    # position labels when the caller gave none.
    if contrast_labels is not None:
        tick_labels = list(contrast_labels)
    else:
        tick_labels = [str(k) for k in range(1, n_contrasts + 1)]

    # --- Figure / axes ---------------------------------------------------
    if horizontal:
        fig_size = (4, 1.5 * n_contrasts)
    else:
        fig_size = (1.5 * n_contrasts, 4)
    if ax is None:
        fig, ax = plt.subplots(figsize=fig_size)
    else:
        fig = ax.figure

    # --- Violins ---------------------------------------------------------
    # Copy user kwargs so setting "vert" below never mutates the caller's dict.
    if violin_kwargs:
        violin_kwargs = dict(violin_kwargs)
    else:
        violin_kwargs = {
            "widths": 0.5,
            "showextrema": False,
            "showmedians": False,
        }
    violin_kwargs["vert"] = not horizontal
    v = ax.violinplot(bootstraps, **violin_kwargs)
    halfviolin(v, alpha=alpha_violin_plot, half="top" if horizontal else "right")

    # --- Colors ----------------------------------------------------------
    if custom_palette:
        if isinstance(custom_palette, dict):
            violin_colors = [
                custom_palette.get(c, sns.color_palette()[0]) for c in tick_labels
            ]
        elif isinstance(custom_palette, list):
            violin_colors = custom_palette[:n_contrasts]
        elif isinstance(custom_palette, str):
            if custom_palette in plt.colormaps():
                violin_colors = sns.color_palette(custom_palette, n_contrasts)
            else:
                raise ValueError(
                    f"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette."
                )
    else:
        violin_colors = sns.color_palette(n_colors=n_contrasts)
    violin_colors = [sns.desaturate(color, desat_violin) for color in violin_colors]
    for patch, color in zip(v["bodies"], violin_colors):
        patch.set_facecolor(color)
        patch.set_alpha(alpha_violin_plot)

    # --- Zero line, point estimates and intervals --------------------------
    if horizontal:
        ax.plot([0, 0], [0, n_contrasts + 1], "k", linewidth=1)
    else:
        ax.plot([0, n_contrasts + 1], [0, 0], "k", linewidth=1)

    for k, (diff, low, high) in enumerate(zip(differences, bcalows, bcahighs), start=1):
        if horizontal:
            ax.plot(diff, k, "k.", markersize=marker_size)
            ax.plot([low, high], [k, k], "k", linewidth=ci_line_width)
        else:
            ax.plot(k, diff, "k.", markersize=marker_size)
            ax.plot([k, k], [low, high], "k", linewidth=ci_line_width)

    # --- Ticks, labels, limits -------------------------------------------
    positions = range(1, n_contrasts + 1)
    if horizontal:
        ax.set_yticks(positions)
        ax.set_yticklabels(tick_labels, rotation=0, fontsize=fontsize)
        ax.set_xlabel(ylabel, fontsize=fontsize)
        ax.set_ylim([0.7, n_contrasts + 0.5])
    else:
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
        ax.set_ylabel(ylabel, fontsize=fontsize)
        ax.set_xlim([0.7, n_contrasts + 0.5])

    ax.set_title(title, fontsize=title_font_size)

    if remove_spines:
        hidden = ("left", "right", "top") if horizontal else ("top", "bottom", "right")
        for side in hidden:
            ax.spines[side].set_visible(False)

    if additional_plotting_kwargs:
        ax.set(**additional_plotting_kwargs)

    return fig