Skip to content

Commit

Permalink
add support to barplot color dict input, #224
Browse files Browse the repository at this point in the history
  • Loading branch information
zqfang committed Sep 29, 2023
1 parent ab76367 commit ae4ac3f
Showing 1 changed file with 30 additions and 12 deletions.
42 changes: 30 additions & 12 deletions gseapy/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,22 @@ def add_colorbar(self, sc):
for key, spine in cbar.ax.spines.items():
spine.set_visible(False)

def _parse_colors(self, color=None):
"""
parse colors for groups
"""
# map color to group
if isinstance(color, dict):
return list(color.values())
# get default color cycle
if (not isinstance(color, str)) and hasattr(color, "__len__"):
_colors = list(color)
else:
# get current matplotlib color cycle
prop_cycle = plt.rcParams["axes.prop_cycle"]
_colors = prop_cycle.by_key()["color"]
return _colors

def barh(self, color=None, group=None, ax=None):
"""
Barplot
Expand All @@ -956,31 +972,31 @@ def barh(self, color=None, group=None, ax=None):
bar.set_ylabel("")
bar.set_title(self.title, fontsize=24, fontweight="bold")
bar.xaxis.set_major_locator(MaxNLocator(nbins=5, integer=True))

# get default color cycle
if (not isinstance(color, str)) and hasattr(color, "__len__"):
_colors = list(color)
else:
prop_cycle = plt.rcParams["axes.prop_cycle"]
_colors = prop_cycle.by_key()["color"]
colors = _colors
#
_colors = self._parse_colors(color=color)
# remove old legend first
bar.legend_.remove()
if (group is not None) and (group in self.data.columns):
num_grp = self.data[group].value_counts(sort=False)
# set colors for each bar (groupby hue)
# set colors for each bar (groupby hue) using full length
colors = []
legend_elements = []
for i, n in enumerate(num_grp):
# cycle _colors if num_grp > len(_colors)
c = _colors[i % len(_colors)]
# group_label
label = num_grp.index[i]
# if input color is a dict with keys in group
if isinstance(color, dict) and label in color:
c = color[label]
# expand the length to match bars
colors += [c] * n
ele = Line2D(
xdata=[0],
ydata=[0],
marker="o",
color="w",
label=num_grp.index[i],
label=label,
markerfacecolor=c,
markersize=8,
)
Expand All @@ -993,6 +1009,7 @@ def barh(self, color=None, group=None, ax=None):
bbox_to_anchor=(1.02, 0.5),
frameon=False,
)

# update color of bars
for j, b in enumerate(ax.patches):
c = colors[j % len(colors)]
Expand Down Expand Up @@ -1210,7 +1227,7 @@ def barplot(
cutoff: float = 0.05,
top_term: int = 10,
figsize: Tuple[float, float] = (4, 6),
color: Union[str, List[str]] = "salmon",
color: Union[str, List[str], Dict[str, str]] = "salmon",
ofname: Optional[str] = None,
**kwargs,
):
Expand All @@ -1225,7 +1242,8 @@ def barplot(
("Adjusted P-value", "P-value", "NOM p-val", "FDR q-val")
:param top_term: number of top enriched terms grouped by `hue` are shown.
:param figsize: tuple, matplotlib figsize.
:param color: color or list of matplotlib.colors. Must be reconigzed by matplotlib.
:param color: color or list or dict of matplotlib.colors. Must be reconigzed by matplotlib.
if dict input, dict keys must be found in the `group`
:param ofname: output file name. If None, don't save figure
:return: matplotlib.Axes. return None if given ofname.
Expand Down

0 comments on commit ae4ac3f

Please sign in to comment.