Skip to content

Commit

Permalink
Merge pull request #149 from tompollard/tp/typehints
Browse files Browse the repository at this point in the history
Minor formatting changes (typehints)
  • Loading branch information
tompollard committed Apr 24, 2023
2 parents 673254f + 84e00d4 commit 92fb976
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 36 deletions.
6 changes: 3 additions & 3 deletions requirements.txt
Expand Up @@ -4,6 +4,6 @@ numpy==1.23.2
openpyxl==3.1.2
pandas==2.0.0
pytest-cov==3.0.0
scipy==1.9.0
statsmodels==0.13.2
tabulate==0.8.10
scipy==1.10.1
statsmodels==0.13.5
tabulate==0.9.0
6 changes: 3 additions & 3 deletions setup.py
Expand Up @@ -53,9 +53,9 @@
install_requires=[
'numpy>=1.19.1',
'pandas>=1.4.3',
'scipy>=1.7.0',
'statsmodels>=0.12.1',
'tabulate>=0.8.10',
'scipy>=1.10.1',
'statsmodels>=0.13.5',
'tabulate>=0.9.0',
'Jinja2==3.1.2',
'openpyxl==3.1.2'
],
Expand Down
76 changes: 46 additions & 30 deletions tableone/tableone.py
Expand Up @@ -2,7 +2,7 @@
The tableone package is used for creating "Table 1" summary statistics for
research papers.
"""
from typing import Optional, Union
from typing import Optional, Tuple, Union
import warnings

import numpy as np
Expand All @@ -18,7 +18,7 @@
warnings.simplefilter('always', DeprecationWarning)


def load_dataset(name: str):
def load_dataset(name: str) -> pd.DataFrame:
"""
Load an example dataset from the online repository (requires internet).
Expand Down Expand Up @@ -60,7 +60,7 @@ class InputError(Exception):
pass


class TableOne(object):
class TableOne:
"""
If you use the tableone package, please cite:
Expand Down Expand Up @@ -200,7 +200,8 @@ class TableOne(object):
...
"""
def __init__(self, data: pd.DataFrame, columns: Optional[list] = None,
def __init__(self, data: pd.DataFrame,
columns: Optional[list] = None,
categorical: Optional[list] = None,
groupby: Optional[str] = None,
nonnormal: Optional[list] = None,
Expand Down Expand Up @@ -397,20 +398,23 @@ def __init__(self, data: pd.DataFrame, columns: Optional[list] = None,

# create overall tables if required
if self._categorical and self._groupby and overall:
self.cat_describe_all = self._create_cat_describe(data, False,
['Overall'])
self.cat_describe_all = self._create_cat_describe(data=data,
groupby=None,
groupbylvls=['Overall'])

if self._continuous and self._groupby and overall:
self.cont_describe_all = self._create_cont_describe(data, False)
self.cont_describe_all = self._create_cont_describe(data=data,
groupby=None)

# create descriptive tables
if self._categorical:
self.cat_describe = self._create_cat_describe(data, self._groupby,
self._groupbylvls)
self.cat_describe = self._create_cat_describe(data=data,
groupby=self._groupby,
groupbylvls=self._groupbylvls)

if self._continuous:
self.cont_describe = self._create_cont_describe(data,
self._groupby)
self.cont_describe = self._create_cont_describe(data=data,
groupby=self._groupby)

# compute standardized mean differences
if self._smd:
Expand Down Expand Up @@ -439,13 +443,13 @@ def __init__(self, data: pd.DataFrame, columns: Optional[list] = None,
if display_all:
self._set_display_options()

def __str__(self):
def __str__(self) -> str:
return self.tableone.to_string() + self._generate_remarks('\n')

def __repr__(self):
def __repr__(self) -> str:
return self.tableone.to_string() + self._generate_remarks('\n')

def _repr_html_(self):
def _repr_html_(self) -> str:
return self.tableone._repr_html_() + self._generate_remarks('<br />')

def _set_display_options(self):
Expand All @@ -465,7 +469,7 @@ def _set_display_options(self):
option.""".format(k)
warnings.warn(msg)

def tabulate(self, headers=None, tablefmt='grid', **kwargs):
def tabulate(self, headers=None, tablefmt='grid', **kwargs) -> str:
"""
Pretty-print tableone data. Wrapper for the Python 'tabulate' library.
Expand Down Expand Up @@ -500,7 +504,7 @@ def tabulate(self, headers=None, tablefmt='grid', **kwargs):

return tabulate(df, headers=headers, tablefmt=tablefmt, **kwargs)

def _generate_remarks(self, newline='\n'):
def _generate_remarks(self, newline='\n') -> str:
"""
Generate a series of remarks that the user should consider
when interpreting the summary statistics.
Expand Down Expand Up @@ -546,7 +550,7 @@ def _generate_remarks(self, newline='\n'):

return msg

def _detect_categorical_columns(self, data):
def _detect_categorical_columns(self, data) -> list:
"""
Detect categorical columns if they are not specified.
Expand Down Expand Up @@ -783,7 +787,7 @@ def _normality(self, x):

def _tukey(self, x, threshold):
"""
Count outliers according to Tukey's rule.
Find outliers according to Tukey's rule.
Where Q1 is the lower quartile and Q3 is the upper quartile,
an outlier is an observation outside of the range:
Expand All @@ -806,21 +810,21 @@ def _tukey(self, x, threshold):

return outliers

def _outliers(self, x):
def _outliers(self, x) -> int:
"""
Compute number of outliers
"""
outliers = self._tukey(x, threshold=1.5)
return np.size(outliers)

def _far_outliers(self, x):
def _far_outliers(self, x) -> int:
"""
Compute number of "far out" outliers
"""
outliers = self._tukey(x, threshold=3.0)
return np.size(outliers)

def _t1_summary(self, x):
def _t1_summary(self, x: pd.Series) -> str:
"""
Compute median [IQR] or mean (Std) for the input series.
Expand Down Expand Up @@ -867,7 +871,9 @@ def _t1_summary(self, x):
f = '{{:.{}f}} ({{:.{}f}})'.format(n, n)
return f.format(np.nanmean(x.values), self._std(x))

def _create_cont_describe(self, data, groupby):
def _create_cont_describe(self,
data: pd.DataFrame,
groupby: Optional[str] = None) -> pd.DataFrame:
"""
Describe the continuous data.
Expand Down Expand Up @@ -937,7 +943,10 @@ def _create_cont_describe(self, data, groupby):

return df_cont

def _format_cat(self, row, col):
def _format_cat(self, row, col) -> str:
"""
Format values to n decimal places.
"""
var = row.name[0]
if var in self._decimals:
n = self._decimals[var]
Expand All @@ -946,7 +955,9 @@ def _format_cat(self, row, col):
f = '{{:.{}f}}'.format(n)
return f.format(row[col])

def _create_cat_describe(self, data, groupby, groupbylvls):
def _create_cat_describe(self, data: pd.DataFrame,
groupby: Optional[str] = None,
groupbylvls: Optional[list] = None) -> pd.DataFrame:
"""
Describe the categorical data.
Expand Down Expand Up @@ -1054,7 +1065,7 @@ def _create_cat_describe(self, data, groupby, groupbylvls):

return df_cat

def _create_htest_table(self, data):
def _create_htest_table(self, data: pd.DataFrame) -> pd.DataFrame:
"""
Create a table containing P-Values for significance tests. Add features
of the distributions and the P-Values to the dataframe.
Expand Down Expand Up @@ -1119,7 +1130,7 @@ def _create_htest_table(self, data):

return df

def _create_smd_table(self, data):
def _create_smd_table(self, data: pd.DataFrame) -> pd.DataFrame:
"""
Create a table containing pairwise Standardized Mean Differences
(SMDs).
Expand Down Expand Up @@ -1180,8 +1191,13 @@ def _create_smd_table(self, data):

return df

def _p_test(self, v, grouped_data, is_continuous, is_categorical,
is_normal, min_observed, catlevels):
def _p_test(self, v: str,
grouped_data: dict,
is_continuous: bool,
is_categorical: bool,
is_normal: bool,
min_observed: int,
catlevels: list):
"""
Compute P-Values.
Expand Down Expand Up @@ -1267,7 +1283,7 @@ def _p_test(self, v, grouped_data, is_continuous, is_categorical,

return pval, ptest

def _create_cont_table(self, data, overall):
def _create_cont_table(self, data, overall) -> pd.DataFrame:
"""
Create tableone for continuous data.
Expand Down Expand Up @@ -1582,7 +1598,7 @@ def _create_tableone(self, data):

return table

def _create_row_labels(self):
def _create_row_labels(self) -> dict:
"""
Take the original labels for rows. Rename if alternative labels are
provided. Append label suffix if label_suffix is True.
Expand Down

0 comments on commit 92fb976

Please sign in to comment.