
Survival Cluster Analysis (ACM CHIL 2020)

This repository contains the TensorFlow code to replicate the experiments in our paper Survival Cluster Analysis, accepted at the ACM Conference on Health, Inference, and Learning (ACM CHIL) 2020:

@inproceedings{chapfuwa2020survival, 
  title={Survival Cluster Analysis},
  author={Paidamoyo Chapfuwa and Chunyuan Li and Nikhil Mehta and Lawrence Carin and Ricardo Henao},
  booktitle={ACM Conference on Health, Inference, and Learning},
  year={2020}
}

Model


Illustration of Survival Cluster Analysis (SCA). The latent space has a mixture-of-distributions structure, illustrated as three mixture components. Observation x is mapped into its latent representation z via a deterministic encoding, which is then used to stochastically predict (via sampling) from the time-to-event distribution p(t|x).
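The encode-then-sample pipeline described above can be sketched in a few lines of NumPy. This is an illustrative toy, not the repository's implementation: the single-layer encoder and the log-normal sampling distribution are assumptions chosen only to show the deterministic-encoding / stochastic-prediction split.

```python
import numpy as np

rng = np.random.default_rng(0)

def encode(x, W, b):
    """Deterministic encoder: map covariates x to latent z.
    A hypothetical single-layer example, not the paper's network."""
    return np.tanh(x @ W + b)

def sample_times(z, n_samples=100):
    """Stochastically predict time-to-event by sampling.
    A log-normal whose location depends on z -- an illustrative
    choice, not the distribution used in the paper."""
    mu = z.sum(axis=-1, keepdims=True)      # location from latent code
    eps = rng.standard_normal((z.shape[0], n_samples))
    return np.exp(mu + 0.5 * eps)           # positive event times

x = rng.standard_normal((4, 10))            # 4 observations, 10 covariates
W = rng.standard_normal((10, 2))
b = np.zeros(2)
z = encode(x, W, b)                         # deterministic z = f(x)
t = sample_times(z)                         # samples from p(t | x)
print(z.shape, t.shape)                     # (4, 2) (4, 100)
```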

Risk

Cluster-specific Kaplan-Meier survival profiles for three clustering methods on the SLEEP dataset. Our model (SCA) identifies high-, medium- and low-risk individuals, demonstrating the need to account for time information, via a non-linear transformation of covariates, when clustering survival datasets.
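For readers unfamiliar with the Kaplan-Meier curves shown above, here is a minimal, self-contained estimator in plain NumPy. It is a textbook sketch for intuition; the repository's plots are produced by its own notebooks.

```python
import numpy as np

def kaplan_meier(times, events):
    """Kaplan-Meier survival estimate S(t) at each distinct event time.
    times: observed times; events: 1 = event occurred, 0 = censored."""
    order = np.argsort(times)
    times = np.asarray(times)[order]
    events = np.asarray(events)[order]
    event_times = np.unique(times[events == 1])
    surv, s = [], 1.0
    for t in event_times:
        d = np.sum((times == t) & (events == 1))  # events at time t
        n = np.sum(times >= t)                    # at risk just before t
        s *= 1.0 - d / n                          # product-limit update
        surv.append(s)
    return event_times, np.array(surv)

times = [1, 2, 3, 4, 5, 6]
events = [1, 0, 1, 1, 0, 1]       # two censored observations
t, s = kaplan_meier(times, events)
print(t.tolist())                 # [1, 3, 4, 6]
```

Plotting one such curve per cluster assignment gives the per-cluster risk profiles shown in the figure.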

Prerequisites

The code's dependencies are listed in requirements.txt; install them with:

pip install -r requirements.txt

Data

We consider the following datasets:

  • SUPPORT
  • Flchain
  • SEER
  • SLEEP: A subset of the Sleep Heart Health Study (SHHS), a multi-center cohort study implemented by the National Heart, Lung, and Blood Institute to determine the cardiovascular and other consequences of sleep-disordered breathing.
  • Framingham: A subset (Framingham Offspring) of the longitudinal study of heart disease dataset, initially for predicting 10-year risk for future coronary heart disease (CHD).
  • EHR: A large study from Duke University Health System centered around inpatient visits due to comorbidities in patients with Type-2 diabetes.

For convenience, we provide pre-processing scripts for all datasets (except EHR and Framingham). In addition, the data directory contains the downloaded Flchain and SUPPORT datasets.

Model Training

Please modify the training arguments as needed:

  • dataset: one of the four public datasets {flchain, support, seer, sleep}; the default is support
  • n_clusters: the upper bound K on the number of clusters; the default is 25
  • gamma_0: the Dirichlet process concentration parameter, selected from {2, 3, 4, 8}; the default is 2

 python train.py --dataset support --n_clusters 25 --gamma_0 2

  • The remaining hyper-parameter settings can be found in configs.py
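To build intuition for how gamma_0 shapes the clustering, a truncated stick-breaking construction of Dirichlet-process weights can be sketched as below. Larger gamma_0 spreads mass over more components, so more of the K = 25 candidate clusters carry non-trivial weight. This is an illustrative sketch of the standard construction, not the repository's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def stick_breaking(gamma_0, k_max=25):
    """Truncated stick-breaking weights for a Dirichlet process.
    k_max mirrors the n_clusters upper bound; gamma_0 is the
    concentration parameter (illustrative sketch only)."""
    betas = rng.beta(1.0, gamma_0, size=k_max)          # stick fractions
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - betas[:-1])])
    return betas * remaining                            # component weights

for g in (2, 8):
    w = stick_breaking(g)
    # count clusters that receive a non-trivial share of the mass
    print(g, int(np.sum(w > 0.01)))
```

With small gamma_0 the mass concentrates on a few clusters; with large gamma_0 it spreads out, which is why the paper sweeps gamma_0 over {2, 3, 4, 8}.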

Metrics and Visualizations

Once the networks are trained and the results are saved, we extract the following key results:

  • Training and evaluation metrics are logged in model.log
  • Epoch-based cost function plots can be found in the plots directory
  • NumPy files used to generate calibration and cluster plots are saved in the matrix directory
  • Run Calibration.ipynb to generate calibration results and Clustering.ipynb for clustering results
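The matrix-directory workflow is a plain NumPy save/load round trip: training saves arrays with np.save, and the notebooks reload them with np.load for plotting. The snippet below is a hypothetical sketch of that round trip; the actual file names and array contents are determined by the training script.

```python
import os
import tempfile
import numpy as np

# Hypothetical sketch: stand-in for the matrix directory and file name.
out_dir = tempfile.mkdtemp()
path = os.path.join(out_dir, "example_curve.npy")

surv = np.linspace(1.0, 0.0, 10)   # e.g. a saved survival curve
np.save(path, surv)                # what the training run would do
loaded = np.load(path)             # what the notebooks do before plotting
print(np.allclose(surv, loaded))   # True
```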

Acknowledgments

This work leverages the calibration framework from SFM and the accuracy objective from DATE. Contact Paidamoyo for issues relevant to this project.