
greenelab/shared-latent-space


Shared Latent Space Variational Autoencoders

Motivation for Applying Shared Latent Space VAEs in Biology

Variational autoencoders (VAEs) are machine learning models that learn the distribution of the data on a latent space manifold. This lets the model be trained without labels, by learning to recreate its input data, and also lets it create new data that resembles the original data by picking points on the latent space manifold. Shared latent space VAEs find relationships between two different domains and allow transformations between them. They achieve this by linking the latent space manifold between two different encoders and decoders. This is particularly useful in biology, where different data types can serve as different 'views' of the same biological problem. The ability to transform between domains also lets us move from one data type to another.
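To make the idea concrete, here is a minimal NumPy sketch of the forward pass of a shared latent space VAE. It is not the repository's implementation (that lives in shared_vae_class.py); the linear encoders and decoders, the 784-dimensional inputs, the 2-dimensional latent space, and the random untrained weights are all assumptions chosen for illustration. Each domain has its own encoder and decoder, but both use the same latent space, so encoding with domain A's encoder and decoding with domain B's decoder performs a cross-domain transformation.

```python
import numpy as np

rng = np.random.default_rng(0)


def init_linear(n_in, n_out):
    # Small random weights: an untrained stand-in for a learned dense layer.
    return rng.normal(scale=0.1, size=(n_in, n_out)), np.zeros(n_out)


class LinearEncoder:
    """Maps inputs to the mean and log-variance of a Gaussian in the shared latent space."""

    def __init__(self, n_in, n_latent):
        self.w_mu, self.b_mu = init_linear(n_in, n_latent)
        self.w_lv, self.b_lv = init_linear(n_in, n_latent)

    def __call__(self, x):
        return x @ self.w_mu + self.b_mu, x @ self.w_lv + self.b_lv


class LinearDecoder:
    """Maps latent points back to one domain's data space."""

    def __init__(self, n_latent, n_out):
        self.w, self.b = init_linear(n_latent, n_out)

    def __call__(self, z):
        # Sigmoid keeps outputs in (0, 1), like normalized pixel values.
        return 1.0 / (1.0 + np.exp(-(z @ self.w + self.b)))


def sample_latent(mu, log_var):
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps


def kl_to_standard_normal(mu, log_var):
    # KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions, one value per sample.
    return -0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var), axis=1)


n_latent = 2
enc_a, dec_a = LinearEncoder(784, n_latent), LinearDecoder(n_latent, 784)  # domain A
enc_b, dec_b = LinearEncoder(784, n_latent), LinearDecoder(n_latent, 784)  # domain B

x_a = rng.random((5, 784))          # placeholder batch from domain A
mu, log_var = enc_a(x_a)
z = sample_latent(mu, log_var)      # point in the shared latent space
recon_a = dec_a(z)                  # A -> A reconstruction
a_to_b = dec_b(z)                   # A -> B transformation via the shared space
kl = kl_to_standard_normal(mu, log_var)

# Generation: pick latent points directly and decode them into both domains.
z_new = rng.standard_normal((3, n_latent))
gen_a, gen_b = dec_a(z_new), dec_b(z_new)
```

In a trained model, the weights would be learned by minimizing each domain's reconstruction error plus the KL term, which is what pulls both encoders onto one shared latent manifold.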

Diagram of Model


Usage

Computational Environment

All libraries and packages are handled by conda and specified in the environment.yml file. To build and activate this environment, run:

#conda version 4.4.10
conda env create --force --file environment.yml

conda activate shared-latenet-space

Running the Model

The pickle files for MNIST and ICVL are both too big to include in the repo. They are publicly available, and I will later provide an easy way to download them that is consistent across platforms. You should consider changing the layout of the model, which is controlled in the initialization of the model_parameters object in main_file.py. You can also control the noise level, which is passed as a parameter to the train function, also located in main_file.py.
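The README does not say how the noise level is applied during training. One common approach in autoencoder training, assumed here, is to corrupt the inputs with Gaussian noise while still reconstructing the clean inputs (a denoising autoencoder). The helper add_input_noise below is hypothetical and is not the train function in main_file.py; it also assumes inputs are scaled to [0, 1].

```python
import numpy as np


def add_input_noise(x, noise_level, rng=None):
    """Hypothetical helper: corrupt x with Gaussian noise of std noise_level.

    Assumes inputs are scaled to [0, 1], so the noisy result is clipped back
    into that range. With noise_level=0 the input is returned unchanged.
    """
    if noise_level < 0:
        raise ValueError("noise_level must be non-negative")
    rng = np.random.default_rng() if rng is None else rng
    noisy = x + rng.normal(scale=noise_level, size=x.shape)
    return np.clip(noisy, 0.0, 1.0)


rng = np.random.default_rng(0)
clean = rng.random((4, 10))
noisy = add_input_noise(clean, noise_level=0.1, rng=rng)
# A denoising setup trains on (noisy -> clean) pairs so the model learns to ignore the noise.
```

Raising the noise level makes reconstruction harder but tends to push the model toward more robust latent features; a level of 0 reduces to ordinary training.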

Changing Datasets

If the dataset is one of the already included ones, change which implementation of DataSetInfoAbstract is used when defining the DataSetInfo object. You should also consider changing the model layout, as outlined in Running the Model.

Adding More Datasets

Adding more datasets is fully supported. Create your own implementation of the DataSetInfoAbstract abstract class, including methods for loading and visualizing. The loading method must return two training sets and two testing sets. The visualization method does not have to return anything, and it may do nothing at all: you don't need to implement a visualization, but you must declare the method to comply with the interface.
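As an illustration of that contract, here is a minimal sketch using Python's abc module. The DataSetInfoAbstract below is a simplified stand-in for the repository's class, and the method names load and visualize, the return order (a_train, b_train, a_test, b_test), and the ToyPairedData dataset are all assumptions made for this example; check DataSetInfoAbstractClass.py for the real signatures.

```python
from abc import ABC, abstractmethod

import numpy as np


class DataSetInfoAbstract(ABC):
    """Simplified stand-in for the repository's abstract class (method names assumed)."""

    @abstractmethod
    def load(self):
        """Return (a_train, b_train, a_test, b_test): two training and two testing sets."""

    @abstractmethod
    def visualize(self, *args, **kwargs):
        """Visualize model output; an implementation may do nothing and return None."""


class ToyPairedData(DataSetInfoAbstract):
    """Hypothetical dataset whose domain B is domain A with its features reversed."""

    def __init__(self, n_train=100, n_test=20, n_features=10, seed=0):
        self.n_train = n_train
        self.n_test = n_test
        self.n_features = n_features
        self.seed = seed

    def load(self):
        rng = np.random.default_rng(self.seed)
        a = rng.random((self.n_train + self.n_test, self.n_features))
        b = a[:, ::-1]  # the paired "view" of each sample in the second domain
        n = self.n_train
        return a[:n], b[:n], a[n:], b[n:]

    def visualize(self, *args, **kwargs):
        # No visualization: declaring the method is enough to satisfy the interface.
        return None


data = ToyPairedData()
a_train, b_train, a_test, b_test = data.load()
```

Leaving out visualize entirely would keep the subclass abstract, and Python would raise a TypeError on instantiation; that is why the method must be declared even when it does nothing.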

Files

main_file.py

This is the file you should run. It calls the other files to load and format the data, and it uses shared_vae_class.py to create the model, train it, and generate data. As work continues, this file will become more general and easier to work with. For now, if you are using your own data, create a file that implements the DataSetInfoAbstract abstract class, then instantiate that class in main_file.py; the rest of the file knows how to interact with it.

To run the file, open command line and enter:

python main_file.py

shared_vae_class.py

This file contains the main class, which holds the model. It has functions to compile, train, and generate from the model. The model takes a series of parameters that control the size of its layers and similar settings. The model structure is currently very rigid, but this may change. Five different models are built inside this class for training purposes, but they are hidden from the user. The generate function calls the visualize function of the DataSetInfoAbstract implementation, which produces an image to help visualize how the model is working.

model_objects.py

This file contains the model_parameters class which is fed to the shared_vae_class when it is initialized.

DataSetInfoAbstractClass.py

This file defines an abstract class used for any dataset-specific functions, such as loading and visualizing. These functions are abstract, so each dataset must provide its own implementation.

ICVL.py

This is a specific implementation of the DataSetInfoAbstract abstract class for the ICVL data. It contains a load function, which loads the data from a pickle file, and a visualize function, which produces images of the depth maps and knuckle maps. The draw_hands function draws all of the lines between the joints in the hand, as given by the dataset.
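A sketch of the kind of work draw_hands does: turning a hand pose into the line segments between connected joints. The joint ordering here is an assumption (16 joints, with the palm at index 0 followed by three joints per finger, from root to tip, for five fingers); verify it against the actual ICVL label order before relying on it. hand_edges and hand_segments are hypothetical helpers, not functions from ICVL.py.

```python
import numpy as np


def hand_edges(n_fingers=5, joints_per_finger=3):
    """Return (parent, child) joint-index pairs for a palm-rooted hand skeleton.

    Assumed order: joint 0 is the palm, then joints_per_finger joints per finger,
    listed from the finger root to the fingertip.
    """
    edges = []
    for finger in range(n_fingers):
        root = 1 + finger * joints_per_finger
        edges.append((0, root))  # palm -> finger root
        for j in range(root, root + joints_per_finger - 1):
            edges.append((j, j + 1))  # along the finger toward the tip
    return edges


def hand_segments(joints, edges):
    """Convert an (n_joints, dims) array into an (n_edges, 2, dims) array of segments."""
    joints = np.asarray(joints, dtype=float)
    return np.stack([joints[[i, j]] for i, j in edges])


# Made-up pose standing in for one ICVL label: 16 joints in 3D.
pose = np.arange(16 * 3, dtype=float).reshape(16, 3)
segments = hand_segments(pose, hand_edges())
```

For a 2D overlay on a depth map, segments[:, :, :2] has the (N, 2, 2) shape that matplotlib.collections.LineCollection accepts, so all of the bones can be drawn with a single ax.add_collection call.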

MNIST.py

This is a specific implementation of the DataSetInfoAbstract abstract class for the MNIST data. It contains a load function, which loads the data from a pickle file, and a visualize function, which produces images of the regular MNIST digits and the inverse MNIST digits.
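The README does not define "inverse MNIST"; a natural reading, assumed here, is the same digit with its pixel intensities flipped, which gives two perfectly paired domains. The hypothetical helpers below build such a pair from raw 28x28 uint8 digit images and place one pair side by side for display (for example with matplotlib's imshow). They are not the functions in MNIST.py, and no particular pickle layout is assumed.

```python
import numpy as np


def make_inverse_pair(images):
    """Return (regular, inverse) flattened float arrays scaled to [0, 1].

    Assumes images is a uint8 array of shape (n, 28, 28) with values 0..255.
    """
    regular = images.reshape(len(images), -1).astype(np.float32) / 255.0
    inverse = 1.0 - regular  # dark strokes on a light background
    return regular, inverse


def side_by_side(regular_vec, inverse_vec, shape=(28, 28)):
    """Place one regular digit next to its inverse as a single 2D image."""
    return np.hstack([regular_vec.reshape(shape), inverse_vec.reshape(shape)])


# Tiny random batch standing in for real MNIST digits.
fake = np.random.default_rng(0).integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
regular, inverse = make_inverse_pair(fake)
panel = side_by_side(regular[0], inverse[0])
```

Because every regular digit has an exactly matching inverse, this pairing is a convenient sanity check for a shared latent space model: a good model should map a digit to its inverse almost perfectly.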