Skip to content

Commit

Permalink
refactor data parsing, now can set pos and neg mannually #221
Browse files Browse the repository at this point in the history
  • Loading branch information
zqfang committed Oct 25, 2023
1 parent 14ca305 commit a169925
Showing 1 changed file with 46 additions and 24 deletions.
70 changes: 46 additions & 24 deletions gseapy/gsea.py
Expand Up @@ -6,7 +6,7 @@
import os
import xml.etree.ElementTree as ET
from collections import Counter
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(
verbose=verbose,
)
self.data = data
self.classes = classes
# self.classes = classes
self.permutation_type = permutation_type
self.method = method
self.min_size = min_size
Expand All @@ -65,10 +65,12 @@ def __init__(
self.seed = seed
self.ranking = None
self._noplot = no_plot
self.pheno_pos = "pos"
self.pheno_neg = "neg"
# phenotype labels parsing
self.load_classes(classes)

def load_data(self, cls_vec: List[str]) -> Tuple[pd.DataFrame, Dict]:
def load_data(
self, groups: Union[List[str], Dict[str, Any]]
) -> Tuple[pd.DataFrame, Dict]:
"""pre-processed the data frame.new filtering methods will be implement here."""
# read data in
if isinstance(self.data, pd.DataFrame):
Expand All @@ -88,6 +90,14 @@ def load_data(self, cls_vec: List[str]) -> Tuple[pd.DataFrame, Dict]:
else:
raise Exception("Error parsing gene expression DataFrame!")

exprs = self._check_data(exprs)
exprs, cls_dict = self._filter_data(exprs)
return exprs, cls_dict

def _check_data(self, exprs: pd.DataFrame) -> pd.DataFrame:
"""
check NAs, duplicates.
"""
if exprs.isnull().any().sum() > 0:
self._logger.warning("Input data contains NA, filled NA with 0")
exprs.dropna(how="all", inplace=True) # drop rows with all NAs
Expand All @@ -102,12 +112,29 @@ def load_data(self, cls_vec: List[str]) -> Tuple[pd.DataFrame, Dict]:
"Found duplicated gene names, values averaged by gene names!"
)
df = df.groupby(level=0).mean()
return df

def _map_classes(self, sample_names: List[str]) -> Dict[str, Any]:
"""
update
"""
cls_dict = self.groups
if isinstance(self.groups, dict):
# update groups
self.groups = [cls_dict[c] for c in sample_names]
else:
cls_dict = {k: v for k, v in zip(sample_names, self.groups)}
return cls_dict

def _filter_data(self, df: pd.DataFrame) -> pd.DataFrame:
"""
filter data rows with std == 0
"""
# in case the description column is numeric
if len(cls_vec) == (df.shape[1] - 1):
if len(self.groups) == (df.shape[1] - 1):
df = df.iloc[:, 1:]
cls_dict = self._map_classes(df.columns)
# drop gene which std == 0 in all samples
cls_dict = {k: v for k, v in zip(df.columns, cls_vec)}
# compatible to py3.7
major, minor, _ = [int(i) for i in pd.__version__.split(".")]
if (major == 1 and minor < 5) or (major < 1):
Expand All @@ -120,13 +147,13 @@ def load_data(self, cls_vec: List[str]) -> Tuple[pd.DataFrame, Dict]:

return df, cls_dict

def calculate_metric(
def calc_metric(
self,
df: pd.DataFrame,
method: str,
pos: str,
neg: str,
classes: Dict[str, List[str]],
classes: Dict[str, str],
ascending: bool,
) -> Tuple[List[int], pd.Series]:
"""The main function to rank an expression table. works for 2d array.
Expand Down Expand Up @@ -210,13 +237,11 @@ def calculate_metric(
# descending order
return ser_ind[::-1], ser[::-1]

def load_classes(
self,
):
def load_classes(self, classes: Union[str, List[str], Dict[str, Any]]):
"""Parse group (classes)"""
if isinstance(self.classes, dict):
if isinstance(classes, dict):
# check number of samples
class_values = Counter(self.classes.values())
class_values = Counter(classes.values())
s = []
for c, v in sorted(class_values.items(), key=lambda item: item[1]):
if v < 3:
Expand All @@ -226,12 +251,12 @@ def load_classes(
self.pheno_neg = s[1]
# n_pos = class_values[pos]
# n_neg = class_values[neg]
return
self.groups = classes
else:
pos, neg, cls_vector = gsea_cls_parser(self.classes)
pos, neg, cls_vector = gsea_cls_parser(classes)
self.pheno_pos = pos
self.pheno_neg = neg
return cls_vector
self.groups = cls_vector

# @profile
def run(self):
Expand All @@ -257,11 +282,8 @@ def run(self):

# Start Analysis
self._logger.info("Parsing data files for GSEA.............................")
# phenotype labels parsing
cls_vector = self.load_classes()
# select correct expression genes and values.
dat, cls_dict = self.load_data(cls_vector)
self.cls_dict = cls_dict
dat, cls_dict = self.load_data(self.groups)
# data frame must have length > 1
assert len(dat) > 1
# filtering out gene sets and build gene sets dictionary
Expand All @@ -275,7 +297,7 @@ def run(self):
# compute ES, NES, pval, FDR, RES
if self.permutation_type == "gene_set":
# ranking metrics calculation.
idx, dat2 = self.calculate_metric(
idx, dat2 = self.calc_metric(
df=dat,
method=self.method,
pos=self.pheno_pos,
Expand All @@ -301,7 +323,7 @@ def run(self):
gsum.indices = indices # only accept [[]]
else: # phenotype permutation
group = list(
map(lambda x: True if x == self.pheno_pos else False, cls_vector)
map(lambda x: True if x == self.pheno_pos else False, self.groups)
)
gsum = gsea_rs(
dat.index.values.tolist(),
Expand All @@ -324,7 +346,7 @@ def run(self):
self.ranking = pd.Series(gsum.rankings[0], index=dat.index[gsum.indices[0]])
# reorder datarame for heatmap
# self._heatmat(df=dat.loc[dat2.index], classes=cls_vector)
self._heatmat(df=dat.iloc[gsum.indices[0]], classes=cls_vector)
self._heatmat(df=dat.iloc[gsum.indices[0]], classes=self.groups)
# write output and plotting
self.to_df(gsum.summaries, gmt, self.ranking)
self._logger.info("Congratulations. GSEApy ran successfully.................\n")
Expand Down

0 comments on commit a169925

Please sign in to comment.