Skip to content

Commit

Permalink
dotplot add more control of x-axis, y-axis order, #175
Browse files Browse the repository at this point in the history
  • Loading branch information
zqfang committed Dec 13, 2022
1 parent a3024c8 commit a4a31c5
Showing 1 changed file with 78 additions and 29 deletions.
107 changes: 78 additions & 29 deletions gseapy/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import operator
import sys
import warnings
from collections.abc import Iterable
from typing import Iterable, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sch
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.category import UnitData
from matplotlib.colors import Normalize
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
Expand Down Expand Up @@ -69,7 +72,7 @@ def __init__(
ofname: Optional[str] = None,
**kwargs,
):
self.title = "Heatmap of the Analyzed Geneset" if title is None else title
self.title = "" if title is None else title
self.figsize = figsize
self.xticklabels = xticklabels
self.yticklabels = yticklabels
Expand All @@ -80,7 +83,7 @@ def __init__(
df = zscore(df, axis=z_score)
df = df.iloc[::-1]
self._df = df
self.cbar_title = "Scaled Exp" if z_score is None else "Z-Score"
self.cbar_title = "Norm.Exp" if z_score is None else "Z-Score"
self.cmap = cmap
if cmap is None:
self.cmap = SciPalette.create_colormap() # navyblue2darkred
Expand Down Expand Up @@ -511,7 +514,9 @@ def __init__(
x: Optional[str] = None,
y: str = "Term",
hue: str = "Adjusted P-value",
size_scale: float = 5.0,
dot_scale: float = 5.0,
x_order: Optional[List[str]] = None,
y_order: Optional[List[str]] = None,
thresh: float = 0.05,
n_terms: int = 10,
title: str = "",
Expand All @@ -525,12 +530,14 @@ def __init__(
self.marker = kwargs["marker"]
self.y = y
self.x = x
self.x_order = x_order
self.y_order = y_order
self.hue = str(hue)
self.colname = str(hue)
self.figsize = figsize
self.cmap = cmap
self.ofname = ofname
self.scale = size_scale
self.scale = dot_scale
self.title = title
self.n_terms = n_terms
self.thresh = thresh
Expand Down Expand Up @@ -593,6 +600,31 @@ def process(self, df):
df = df.assign(Hits_ratio=temp.iloc[:, 0] / temp.iloc[:, 1])
return df

def get_y_order(
    self, method: str = "single", metric: str = "euclidean"
) -> List[str]:
    """Return the categorical order used for the y-axis.

    If ``self.y_order`` was supplied, it is returned as a list (a bare
    string is treated as a single category). Otherwise ``self._df`` is
    pivoted into a matrix indexed by ``self.y`` with ``self.x`` as columns
    and ``self.colname`` as values (missing cells filled with 0), the rows
    are clustered with ``scipy.cluster.hierarchy.linkage``, and the
    reversed dendrogram leaf order is returned.

    :param method: linkage method forwarded to ``sch.linkage``.
    :param metric: distance metric forwarded to ``sch.linkage``.
    :return: list of y-axis category labels in plotting order.
    """
    # A str is itself Iterable; keep it as one label instead of leaking a
    # non-list return value to the caller.
    if isinstance(self.y_order, str):
        return [self.y_order]
    if isinstance(self.y_order, Iterable):
        return list(self.y_order)

    mat = self._df.pivot(
        index=self.y,
        columns=self.x,
        values=self.colname,
    ).fillna(0)

    # linkage() needs at least two observations; with zero or one row
    # there is nothing to cluster, so the pivot order is the answer.
    if mat.shape[0] < 2:
        return list(mat.index)

    links = sch.linkage(mat, method=method, metric=metric)
    dendro = sch.dendrogram(
        links,
        orientation="left",
        no_plot=True,
        distance_sort="descending",
    )
    # reverse the leaf order to make the view better
    leaves = dendro["leaves"][::-1]
    return list(mat.index[leaves])

def get_ax(self):
"""
setup figure axes
Expand Down Expand Up @@ -634,7 +666,10 @@ def set_x(self):

return x, xlabel

def scatter(self, outer_ring=False):
def scatter(
self,
outer_ring: bool = False,
):
"""
build scatter
"""
Expand All @@ -656,13 +691,18 @@ def scatter(self, outer_ring=False):
# if self.x is None:
x, xlabel = self.set_x()
y = self.y
# set x, y order
xunits = UnitData(self.x_order) if self.x_order else None
yunits = UnitData(self.get_y_order()) if self.x else None

# outer ring
if outer_ring:
smax = df["area"].max()
# TODO:
# Matplotlib BUG: when setting edge colors,
# there's the center of scatter could not aligned.
# Instead, I have to add more dots in the plot to get the ring
# Must set backend to TKcario... to fix it
# Instead, I just add more dots in the plot to get the ring
blk_sc = ax.scatter(
x=x,
y=y,
Expand All @@ -671,6 +711,8 @@ def scatter(self, outer_ring=False):
c="black",
data=df,
marker=self.marker,
xunits=xunits, # set x categorical order
yunits=yunits, # set y categorical order
zorder=0,
)
wht_sc = ax.scatter(
Expand All @@ -681,6 +723,8 @@ def scatter(self, outer_ring=False):
c="white",
data=df,
marker=self.marker,
xunits=xunits, # set x categorical order
yunits=yunits, # set y categorical order
zorder=1,
)
# data = np.array(rg.get_offsets()) # get data coordinates
Expand All @@ -696,6 +740,8 @@ def scatter(self, outer_ring=False):
vmin=vmin,
vmax=vmax,
marker=self.marker,
xunits=xunits, # set x categorical order
yunits=yunits, # set y categorical order
zorder=2,
)
ax.set_xlabel(xlabel, fontsize=14, fontweight="bold")
Expand Down Expand Up @@ -730,6 +776,7 @@ def scatter(self, outer_ring=False):
)
ax.set_title(self.title, fontsize=20, fontweight="bold")
self.add_colorbar(sc)

return ax

def add_colorbar(self, sc):
Expand Down Expand Up @@ -894,7 +941,10 @@ def to_edgelist(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
def dotplot(
df: pd.DataFrame,
column: str = "Adjusted P-value",
group: Optional[str] = None,
x: Optional[str] = None,
y: str = "Term",
x_order: Optional[List[str]] = None,
y_order: Optional[List[str]] = None,
title: str = "",
cutoff: float = 0.05,
top_term: int = 10,
Expand All @@ -908,41 +958,45 @@ def dotplot(
show_ring: bool = False,
**kwargs,
):
"""Visualize GSEApy Results.
"""Visualize GSEApy Results with categorical scatterplot
When multiple datasets exist in the input dataframe, the `group` argument is your friend.
:param df: GSEApy DataFrame results.
:param column: column name in `df` to map the dot colors. Default: Adjusted P-value
:param group: group by the variable in `df` that will produce categorical scatterplot.
:param title: figure title
:param cutoff: terms with `column` value < cut-off are shown. Work only for
:param column: column name in `df` that map the dot colors. Default: Adjusted P-value.
:param x: Categorical variable in `df` that map the x-axis data. Default: None.
:param y: Categorical variable in `df` that map the y-axis data. Default: Term.
:param x_order: X-axis order to plot the `x` categorical levels. Default: None.
:param y_order: Y-axis order to plot the `y` categorical levels. Default: None.
:param title: Figure title.
:param cutoff: Terms with `column` value < cut-off are shown. Work only for
("Adjusted P-value", "P-value", "NOM p-val", "FDR q-val")
:param top_term: number of enriched terms to show.
:param top_term: Number of enriched terms to show.
:param size: float, scale the dot size to get proper visualization.
:param figsize: tuple, matplotlib figure size.
:param cmap: matplotlib colormap for mapping the `column` semantic.
:param ofname: output file name. If None, don't save figure
:param marker: the matplotlib.markers. See https://matplotlib.org/stable/api/markers_api.html
:param show_ring bool: whether to show outer ring.
:param cmap: Matplotlib colormap for mapping the `column` semantic.
:param ofname: Output file name. If None, don't save figure
:param marker: The matplotlib.markers. See https://matplotlib.org/stable/api/markers_api.html
:param show_ring bool: Whether to draw outer ring.
:return: matplotlib.Axes. return None if given ofname.
Only terms with `column` <= `cut-off` are plotted.
"""

dot = DotPlot(
df=df,
x=group,
y="Term",
x=x,
y=y,
x_order=x_order,
y_order=y_order,
hue=column,
title=title,
thresh=cutoff,
n_terms=int(top_term),
size_scale=size,
dot_scale=size,
figsize=figsize,
cmap=cmap,
ofname=ofname,
marker=marker,
**kwargs,
)
ax = dot.scatter(outer_ring=show_ring)

Expand All @@ -964,7 +1018,7 @@ def dotplot(
def ringplot(
df: pd.DataFrame,
column: str = "Adjusted P-value",
group: Optional[str] = None,
x: Optional[str] = None,
title: str = "",
cutoff: float = 0.05,
top_term: int = 10,
Expand All @@ -981,7 +1035,7 @@ def ringplot(
"""ringplot is deprecated, use dotplot instead
:param df: GSEApy DataFrame results.
:param group: the old `x`. Group by the variable in `df` that will produce categorical scatterplot.
:param x: Group by the variable in `df` that will produce categorical scatterplot.
:param column: column name in `df` to map the dot colors. Default: Adjusted P-value
:param title: figure title
:param cutoff: terms with `column` value < cut-off are shown. Work only for
Expand All @@ -998,12 +1052,7 @@ def ringplot(
Only terms with `column` <= `cut-off` are plotted.
"""
warnings.warn("ringplot is deprecated; use dotplot instead", DeprecationWarning, 2)
if "x" in kwargs:
warnings.warn("x is deprecated; use group", DeprecationWarning, 2)
kwargs["group"] = kwargs["x"]
del kwargs["x"]
ax = dotplot(df, **kwargs)
return ax
return


def barplot(
Expand Down

0 comments on commit a4a31c5

Please sign in to comment.