This repository collects Chainer implementations of state-of-the-art methods for unsupervised disentangled representation learning.
The implementations were evaluated quantitatively and qualitatively on the dSprites[13] and mpi3d_toy[14] datasets.
```
git clone git@github.com:pfnet-research/chainer_disentanglement_lib.git
cd chainer_disentanglement_lib
```
This repository is tested on Python 3.6.8.
Required Python packages are listed in requirements-gpu.txt.
Note that you may need to install additional libraries manually, such as CUDA.
```
pip install -r requirements-gpu.txt  # or requirements-cpu.txt if your machine does not have a GPU
```
The Dockerfile contains all dependencies needed for this repository.
```
# Don't miss the final dot!
docker build -t <your_image_name> -f Dockerfile-gpu .  # or Dockerfile-cpu if your machine does not have a GPU
```
The commands below should be run in the Docker container's console.
```
docker run -it -v `pwd`:/home/jovyan/chainer_disentanglement_lib <your_image_name> /bin/bash
```
Download the dSprites and mpi3d_toy datasets.
```
source bin/set_environ.sh
bash bin/download_dsprites.sh
```
You can change the model save path and the dataset path.
```
export OUTPUT_PATH=<your_output_path>
export DISENTANGLEMENT_LIB_DATA=<your_dataset_path>
```
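For illustration, a training script can resolve these environment variables along the following lines. This is a hypothetical sketch: the variable names match the README, but the fallback defaults are illustrative assumptions, not the repository's actual behavior.

```python
import os

def resolve_paths():
    # Read the save/dataset paths from the environment.
    # The "./out" and "./data" defaults are illustrative only.
    output_path = os.environ.get("OUTPUT_PATH", "./out")
    data_path = os.environ.get("DISENTANGLEMENT_LIB_DATA", "./data")
    return output_path, data_path

os.environ["OUTPUT_PATH"] = "/tmp/experiments"
output_path, data_path = resolve_paths()
print(output_path)  # /tmp/experiments
```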
You can set the experiment name and the dataset name to use.
```
# Enter the name of your experiment.
# The training script will create the directory ${OUTPUT_PATH}/${EVALUATION_NAME}.
export EVALUATION_NAME=dev_tmp
# Enter the name of the dataset to use.
export DATASET_NAME=dsprites_full
# export DATASET_NAME=mpi3d_toy
```
`DATASET_NAME` accepts other values if you have downloaded additional datasets. For more detail, follow the instructions in disentanglement_lib/disentanglement_lib/data/ground_truth/named_data.py and disentanglement_lib/bin/dlib_download_data.
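Dataset lookup by name typically follows a registry pattern. The sketch below shows the general idea in the style of named_data.py; the class and function names here are illustrative stand-ins, not the actual API of disentanglement_lib.

```python
# Hypothetical named-dataset registry; names and classes are illustrative.
_DATASETS = {}

def register(name):
    """Decorator that registers a dataset class under a string name."""
    def decorator(cls):
        _DATASETS[name] = cls
        return cls
    return decorator

@register("dsprites_full")
class DSpritesFull:
    """Stand-in for a ground-truth dataset with 5 generative factors."""
    num_factors = 5

def get_named_ground_truth_data(name):
    if name not in _DATASETS:
        raise ValueError(f"Unknown dataset name: {name!r}")
    return _DATASETS[name]()

dataset = get_named_ground_truth_data("dsprites_full")
print(dataset.num_factors)  # 5
```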
You can start training with train.py.
```
# Please set `--device -1` on CPU-only machines.
python3 train.py --vae FactorVAE --device 0 --training_steps 1000
```
Please see bin/example.sh for other options.
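FactorVAE[3] penalizes total correlation with a discriminator trained on latent codes whose dimensions are shuffled independently across the batch. Below is a minimal NumPy sketch of that "permute dims" step, for intuition only; it is not the repository's actual implementation.

```python
import numpy as np

def permute_dims(z, rng):
    # Shuffle each latent dimension independently across the batch, so the
    # shuffled codes approximate the product of the marginal distributions.
    z_perm = np.empty_like(z)
    for j in range(z.shape[1]):
        z_perm[:, j] = z[rng.permutation(z.shape[0]), j]
    return z_perm

rng = np.random.default_rng(0)
z = rng.normal(size=(8, 4))   # batch of 8 latent codes, 4 dimensions each
z_perm = permute_dims(z, rng)
print(z_perm.shape)  # (8, 4)
```

Each column of `z_perm` contains exactly the same values as the corresponding column of `z`, only in a different batch order.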
chainer_disentanglement_lib supports evaluating several disentanglement scores.
The evaluation script is derived from the Disentanglement Challenge Starter Kit, which in turn originates from disentanglement_lib.
Note that the evaluation process may take about one hour.
```
# This script evaluates the model in
# ${OUTPUT_PATH}/${EVALUATION_NAME}/representation.
# Metrics are computed on the ${DATASET_NAME} dataset.
# local_scores.json will be written to ${OUTPUT_PATH}/${EVALUATION_NAME}.
python3 local_evaluation.py
```
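Once evaluation finishes, the scores can be inspected programmatically. The sketch below is illustrative only: the file name and location follow the comments above, but the metric keys written here are made-up examples, not the script's exact output.

```python
import json
import os
import tempfile

# Stand-in for ${OUTPUT_PATH}/${EVALUATION_NAME}; the metric keys are
# hypothetical examples, not the actual contents of local_scores.json.
scores = {"factor_vae_metric": 0.69, "mig": 0.12}
out_dir = tempfile.mkdtemp()
path = os.path.join(out_dir, "local_scores.json")
with open(path, "w") as f:
    json.dump(scores, f)

with open(path) as f:
    loaded = json.load(f)
print(loaded["factor_vae_metric"])  # 0.69
```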
chainer_disentanglement_lib has been tested with the following environment:
- Docker base image: nvidia/cuda:10.0-cudnn7-runtime-ubuntu18.04
- Python 3.6.8
- chainer 6.3.0
- cupy-cuda100 6.3.0
- numpy 1.16.3
- Pillow 6.0.0
- For metric calculation:
  - tensorflow-gpu 1.13.1
  - tensorflow-probability 0.6.0
  - gin-config 0.1.4
The implementations of BetaVAE, FactorVAE, and DIPVAE-1/2 have been tested on the cars3d dataset. The results indicate that chainer_disentanglement_lib reproduces the results of disentanglement_lib's implementation.
The following table compares our implementation with disentanglement_lib.
Each model was trained with 50 different seeds and evaluated on 5 disentanglement metrics.
model | quantitative comparison on the cars3d dataset (left: ours, right: disentanglement_lib)
---|---
BetaVAE[2] |
FactorVAE[3] |
DIPVAE-1[5] |
DIPVAE-2[5] |
The implementation of JointVAE has been tested on the dSprites dataset.
As in the original JointVAE paper, outliers where the model collapsed to the mean were removed.
The model was trained with 10 different seeds and evaluated on 5 disentanglement metrics.
The results show that our implementation reproduces the original performance.
model | ours | reported average score in the original paper
---|---|---
JointVAE[6] | factor_vae_metric: 0.69 |
dSprites | mpi3d_toy
---|---

The leftmost image is the source image, and the second from the left is its reconstruction by the model.
The other animated images show latent traversals of the encoded latent vector.
In each latent traversal, only one dimension of the latent vector is varied from -1.5 to 1.5.
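The traversal procedure can be sketched as follows, assuming a trained decoder that maps latent vectors to images. The decoder here is a toy stand-in, not the repository's model.

```python
import numpy as np

def decode(z):
    # Toy "decoder": maps a (latent_dim,) vector to a fake 4x4 image.
    w = np.linspace(0.1, 1.0, z.size * 16).reshape(z.size, 16)
    return np.tanh(z @ w).reshape(4, 4)

def traverse(z0, dim, low=-1.5, high=1.5, steps=7):
    # Vary a single latent dimension over [low, high], keeping the
    # others fixed, and decode each intermediate latent vector.
    frames = []
    for v in np.linspace(low, high, steps):
        z = z0.copy()
        z[dim] = v  # change only one dimension
        frames.append(decode(z))
    return np.stack(frames)

frames = traverse(np.zeros(10), dim=3)
print(frames.shape)  # (7, 4, 4)
```

Stacking the decoded frames in order produces the animation shown above.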
Notice that some latent dimensions seem to represent "disentangled" factors of the dataset, namely x-position, y-position, rotation, shape, and scale.
(MIT License. Please see the LICENSE file for details.)
- [1] Kingma & Welling, 2014, Auto-Encoding Variational Bayes
- [2] Higgins et al., 2017, β-VAE: LEARNING BASIC VISUAL CONCEPTS WITH A CONSTRAINED VARIATIONAL FRAMEWORK
- [3] Kim & Mnih, 2018, Disentangling by Factorising
- [4] Chen et al., 2018, Isolating Sources of Disentanglement in Variational Autoencoders
- [5] Kumar et al., 2018, Variational Inference of Disentangled Latent Concepts from Unlabeled Observations
- [6] Dupont, 2018, Learning Disentangled Joint Continuous and Discrete Representations
- [7] Jeong & Song, 2019, Learning Discrete and Continuous Factors of Data via Alternating Disentanglement
- [8] Do & Tran, 2019, Theory and Evaluation Metrics for Learning Disentangled Representations
- [9] Suter et al., 2019, Robustly Disentangled Causal Mechanisms: Validating Deep Representations for Interventional Robustness
- [10] Eastwood & Williams, 2018, A FRAMEWORK FOR THE QUANTITATIVE EVALUATION OF DISENTANGLED REPRESENTATIONS
- [11] Locatello et al., 2018, Challenging Common Assumptions in the Unsupervised Learning of Disentangled Representations
- [12] Mathieu et al., 2019, Disentangling Disentanglement in Variational Autoencoders
- [13] Matthey et al., 2017, dSprites: Disentanglement testing Sprites dataset
- [14] Gondal et al., 2019, On the Transfer of Inductive Bias from Simulation to the Real World: a New Disentanglement Dataset