IID representation learning

Official PyTorch implementation for the following manuscript:

Towards IID representation learning and its application on biomedical data, MIDL 2022.
Jiqing Wu, Inti Zlobec, Maxime W Lafarge, Yukun He and Viktor Koelzer.

Due to the heterogeneity of real-world data, the widely accepted independent and identically distributed (IID) assumption has been criticized in recent studies on causality. In this paper, we argue that instead of being a questionable assumption, IID is a fundamental task-relevant property that needs to be learned. Consider k independent random vectors X^1, ..., X^k. We elaborate on how a variety of different causal questions can be reformulated to learning a task-relevant function f that induces IID among Z^i = f(X^i), which we term IID representation learning.
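In symbols, the task can be sketched as follows (a paraphrase of the manuscript's formulation, not a verbatim quote; the notation Z^i = f(X^i) follows the abstract):

    % Sketch of the IID representation learning task (paraphrased)
    % X^1, ..., X^k: independent random vectors, possibly differently distributed
    \text{learn a task-relevant } f \quad \text{s.t.} \quad
    Z^i = f(X^i), \; i = 1, \ldots, k, \quad
    Z^1, \ldots, Z^k \ \text{are IID}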

For proof of concept, we examine IID representation learning on Out-of-Distribution (OOD) generalization tasks. Concretely, using the representation obtained via the learned function that induces IID, we predict molecular characteristics (molecular prediction) on two biomedical datasets with real-world distribution shifts introduced by (a) preanalytical variation and (b) sampling protocol. To enable reproducibility and comparison to state-of-the-art (SOTA) methods, we follow the OOD benchmarking guidelines recommended by WILDS. Compared to the SOTA baselines supported in WILDS, the results confirm the superior performance of IID representation learning on OOD tasks.


Overall illustration of the proposed IID representation learning. Left: Restyle encoder and StyleGAN decoder. Right: the downstream molecular predictor (ERM + IID representation).


WILDS Leaderboard

Installation

This implementation mainly depends on two backends: Restyle and WILDS. Consequently, once the dependencies for Restyle and WILDS are installed, the environment for our code is correctly configured.

Prerequisites
This implementation has been successfully tested under the following configuration; the version ranges in parentheses should also work:

  • Ubuntu 20.04 (>=18.04)
  • Nvidia driver 470.86 (>=450.xx)
  • CUDA 11.3 (>=11.0)
  • Python 3.9 (>=3.7)
  • Miniconda
  • Docker
  • Nvidia-Docker2

After the prerequisites are satisfied, you can either build the Docker image or install the Conda dependencies.

Docker Installation
Assuming Docker and the Nvidia-Docker2 toolkit are correctly installed, build the Docker image as follows:

cd /Path/To/IID_representation_learning/
docker build -t iid .

Conda Installation
Assuming a local Conda environment is installed and activated, and CUDA is installed under the /usr/local/cuda folder (please do not install CUDA locally as suggested by the PyTorch Conda installation, as this may cause errors when StyleGAN calls the ninja package), install the dependencies as follows:

pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

pip install wilds matplotlib opencv-python ninja

Alternatively, you can try creating a Conda environment from environment.yml.
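A quick sanity check that the GPU build of PyTorch is active (a minimal sketch; the expected version strings assume the pip install above):

    # Confirm the CUDA-enabled PyTorch installation
    import torch
    print(torch.__version__)          # expect 1.10.1+cu113
    print(torch.version.cuda)         # expect 11.3
    print(torch.cuda.is_available())  # expect True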

Data Preparation

To run the experiments investigated in the paper, we need to download the RxRx1 dataset and some pretrained models:

  1. Download RxRx1 (a short verification sketch follows this list)

    git clone https://github.com/p-lambda/wilds.git
    cd wilds/
    python wilds/download_datasets.py --root_dir /Path/To/Data --datasets rxrx1
    
  2. Create the restyle/pretrained_models/ folder and download the pretrained models referenced below (e.g., the StyleGAN weights rxrx1_100000.pt) to this folder
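As a quick check that step 1 succeeded, the dataset can be loaded through the public wilds API (a minimal sketch; /Path/To/Data matches the --root_dir above):

    # Load RxRx1 via the wilds package to confirm the download
    from wilds import get_dataset

    dataset = get_dataset(dataset="rxrx1", root_dir="/Path/To/Data")
    train_set = dataset.get_subset("train")  # the official training split
    print(len(train_set))                    # number of training images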

Run the Experiments

If you are using the Docker image, launch it first as shown below.

Launch the Docker Image

# /Path/To/.cache/torch/hub/checkpoints stores the pretrained weights of some
# standard backbones that may be used during training
sudo docker run -it \
    -v /Path/To/.cache/torch/hub/checkpoints:/root/.cache/torch/hub/checkpoints/ \
    -v /Path/To/IID_representation_learning:/root/IID_representation_learning \
    -v /Path/To/Data:/root/Data \
    --gpus '"device=0"' iid

The Training of Molecular Prediction

cd /Path/To/IID_representation_learning/

python main.py -m train \
               -e 150 \
               --seed 0 \
               --lr cosine,1.5e-4,90,6e-5,150,0 \
               --backbone resnet50 \
               --data_type rxrx1 \
               --classes 1139 \
               --split_scheme official \
               --batch_size 24 \
               --eval_batch_size 64 \
               --data /Path/To/Data/ \
               --save /Path/To/Save/ \
               --checkpoint_path restyle/pretrained_models/ \
               --save-model \
               --loss-coef 0.15 \
               --noise \
               --style
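The --lr flag packs the learning-rate schedule into one string. The actual parser comes from the kaggle-rcic-1st code this repository borrows from; the sketch below is one plausible reading (an assumption, not the repository's exact implementation), interpreting cosine,1.5e-4,90,6e-5,150,0 as cosine interpolation between the anchor points (0, 1.5e-4), (90, 6e-5), and (150, 0):

    # Hypothetical reading of the --lr schedule string (not the repo's parser)
    import math

    def piecewise_cosine_lr(spec: str, epoch: float) -> float:
        """'cosine,lr0,e1,lr1,e2,lr2,...' -> cosine-interpolated lr at epoch."""
        parts = spec.split(",")
        assert parts[0] == "cosine"
        vals = [float(p) for p in parts[1:]]
        anchors = [(0.0, vals[0])] + [(vals[i], vals[i + 1])
                                      for i in range(1, len(vals), 2)]
        for (e0, lr0), (e1, lr1) in zip(anchors, anchors[1:]):
            if epoch <= e1:
                t = (epoch - e0) / (e1 - e0)  # position within this segment
                return lr1 + (lr0 - lr1) * (1 + math.cos(math.pi * t)) / 2
        return anchors[-1][1]

    print(piecewise_cosine_lr("cosine,1.5e-4,90,6e-5,150,0", 0.0))    # 1.5e-4
    print(piecewise_cosine_lr("cosine,1.5e-4,90,6e-5,150,0", 90.0))   # 6e-5
    print(piecewise_cosine_lr("cosine,1.5e-4,90,6e-5,150,0", 150.0))  # 0.0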

The Training of IID representation (Restyle auto-encoder)
First, add the data path string '/Path/to/Training/Data' to dataset_paths['wilds'] in restyle/configs/paths_config.py (see the sketch below), then run:
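For reference, the edit amounts to something like the following (a sketch; only the dataset_paths['wilds'] key is specified by this README, the other entries are placeholders):

    # restyle/configs/paths_config.py
    dataset_paths = {
        # ... other dataset entries ...
        'wilds': '/Path/to/Training/Data',  # root folder of the training data
    }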

cd /Path/To/IID_representation_learning/restyle

python scripts/train_restyle_psp.py \
    --dataset_type=rxrx1 \
    --split_scheme=official \
    --encoder_type=ResNetBackboneEncoder \
    --exp_dir=/Path/To/Output \
    --workers=8 \
    --batch_size=8 \
    --test_batch_size=8 \
    --test_workers=8 \
    --max_steps=100000 \
    --val_interval=5000 \
    --image_interval=5000 \
    --save_interval=10000 \
    --start_from_latent_avg \
    --lpips_lambda=0.2 \
    --l2_lambda=5 \
    --w_norm_lambda=0 \
    --id_lambda=0 \
    --moco_lambda=0.2 \
    --input_nc=6 \
    --n_iters_per_batch=1 \
    --output_size=256 \
    --stylegan_weights=pretrained_models/rxrx1_100000.pt

The Training of StyleGAN
See https://github.com/rosinality/stylegan2-pytorch for more details on creating the LMDB datasets at 256×256 resolution (such as RxRx1) and on running the training script.

Visualization

The Image Synthesis of StyleGAN

cd /Path/To/IID_representation_learning/

python main.py -m synth \
               --data_type rxrx1 \
               --classes 1139 \
               --split_scheme official \
               --batch_size 32 \
               --data /Path/To/Data/ \
               --save /Path/To/Save/ \
               --checkpoint_path restyle/pretrained_models/ 


Nonexistent fluorescence images synthesized by a StyleGAN trained on the RxRx1 training data.

The Reconstruction of Restyle Auto-encoder

cd /Path/To/IID_representation_learning/

python main.py -m recon \
               --data_type rxrx1 \
               --classes 1139 \
               --split_scheme official \
               --batch_size 16 \
               --eval_batch_size 16 \
               --data /Path/To/Data/ \
               --save /Path/To/Save/ \
               --checkpoint_path restyle/pretrained_models/


Visual comparison between ground-truth (red bounding box) and reconstructed RxRx1 images. After running the above command, additional post-processing steps are required to normalize the images along each channel and to zoom in on a small region for better visualization.
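The post-processing script itself is not part of this command; below is a minimal sketch of the two steps described above (per-channel min-max normalization and a zoom crop), with the crop box as an illustrative placeholder:

    # Assumed post-processing: per-channel normalization plus a zoom crop
    import numpy as np

    def normalize_and_zoom(img: np.ndarray, box=(0, 0, 64, 64)) -> np.ndarray:
        """img: HxWxC array; box: (top, left, height, width) region to keep."""
        out = img.astype(np.float32).copy()
        for c in range(out.shape[2]):
            ch = out[..., c]
            lo, hi = ch.min(), ch.max()
            out[..., c] = (ch - lo) / (hi - lo + 1e-8)  # scale channel to [0, 1]
        top, left, h, w = box
        return out[top:top + h, left:left + w]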

Acknowledgments

This implementation is built upon the Restyle and WILDS projects. Our code also borrows from kaggle-rcic-1st. We would like to convey our gratitude to all the authors of these projects.
