Skip to content

Commit

Permalink
fixed gene filtering with sample size <=3, #253
Browse files Browse the repository at this point in the history
  • Loading branch information
zqfang committed Mar 20, 2024
1 parent 0d01bb8 commit ca08aa0
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions gseapy/gsea.py
Expand Up @@ -75,7 +75,9 @@ def load_data(self) -> Tuple[pd.DataFrame, Dict]:
"""pre-processed the data frame.new filtering methods will be implement here."""
exprs = self._load_data(self.data)
exprs = self._check_data(exprs)
print(exprs)
exprs, cls_dict = self._filter_data(exprs)

return exprs, cls_dict

def _map_classes(self, sample_names: List[str]) -> Dict[str, Any]:
Expand All @@ -101,13 +103,22 @@ def _filter_data(self, df: pd.DataFrame) -> pd.DataFrame:
# drop gene which std == 0 in all samples
# compatible to py3.7
major, minor, _ = [int(i) for i in pd.__version__.split(".")]
# handle cases for samples < 3, use mean
if (major == 1 and minor < 5) or (major < 1):
# fix numeric_only error
df_std = df.groupby(by=cls_dict, axis=1).std()
df_std = df.groupby(by=cls_dict, axis=1).std(ddof=0)
else:
df_std = df.groupby(by=cls_dict, axis=1).std(numeric_only=True)
df = df[df_std.sum(axis=1) > 0]
df = df + 1e-08 # we don't like zeros!!!
df_std = df.groupby(by=cls_dict, axis=1).std(numeric_only=True, ddof=0)

print(df)
# remove rows that are all zeros !
df = df.loc[df.abs().sum(axis=1) > 0, :]
# remove rows that std are zeros for sample size >= 3 in each group
if all(map(lambda a: a[1] >= 3, Counter(cls_dict.values()).items())):
df = df[df_std.abs().sum(axis=1) > 0]
df = df + 1e-08 # we don't like zeros in denominator !!!
# data frame must have length > 1
assert df.shape[0] > 1

return df, cls_dict

Expand Down Expand Up @@ -194,6 +205,7 @@ def calc_metric(
if ser.isna().sum() > 0:
self._logger.warning("Invalid value encountered in log2, and dumped.")
ser = ser.dropna()
assert len(ser) > 1
else:
logging.error("Please provide correct method name!!!")
raise LookupError("Input method: %s is not supported" % method)
Expand Down Expand Up @@ -261,8 +273,6 @@ def run(self):
self._logger.info("Parsing data files for GSEA.............................")
# select correct expression genes and values.
dat, cls_dict = self.load_data()
# data frame must have length > 1
assert len(dat) > 1
# filtering out gene sets and build gene sets dictionary
gmt = self.load_gmt(gene_list=dat.index.values, gmt=self.gene_sets)
self.gmt = gmt
Expand Down

0 comments on commit ca08aa0

Please sign in to comment.