Generative modeling and representation learning through reconstruction

tomouellette/autoencodersplz


A variety of autoencoder-structured models for generative modeling and/or representation learning in PyTorch. The models are designed mainly for usability, extensibility, and research rather than as production implementations. But go ahead and train some models and reconstruct some things!

Installation

pip install autoencodersplz

Models

LinearAE

A fully-connected autoencoder with a linear/multi-layer perceptron encoder and decoder

Reducing the Dimensionality of Data with Neural Networks

import torch
from autoencodersplz.models import LinearAE

model = LinearAE(
    img_size = 224,
    in_chans = 3,
    hidden_layers = [64, 64],
    dropout_rate = 0,
    latent_dim = 16,
    beta = 0.1, # beta > 0 = variational
    max_temperature = 1000, # kld temperature annealing
    device = None
)

img = torch.rand(1, 3, 224, 224)

loss, reconstructed_img = model(img)
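
The beta and max_temperature comments above point at a KL-weighted (variational) objective with annealing. A minimal sketch of that idea in plain PyTorch follows. It is not the library's implementation: the function names, the linear warm-up schedule, and the reading of max_temperature as a number of training steps are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

def kld_weight(step, beta, max_temperature):
    # Assumed schedule: ramp linearly from 0 to beta over max_temperature steps.
    return beta * min(1.0, step / max_temperature)

def reparameterize(mu, logvar):
    # Sample z = mu + sigma * eps so gradients flow through mu and logvar.
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)

def vae_loss(x, x_hat, mu, logvar, weight):
    # Reconstruction term plus the weighted closed-form KL(N(mu, sigma^2) || N(0, I)).
    recon = F.mse_loss(x_hat, x)
    kld = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
    return recon + weight * kld

# Toy usage with hypothetical encoder outputs for a batch of 4 and latent_dim = 16.
mu, logvar = torch.zeros(4, 16), torch.zeros(4, 16)
z = reparameterize(mu, logvar)
x = torch.rand(4, 3, 8, 8)
loss = vae_loss(x, x, mu, logvar, kld_weight(step=500, beta=0.1, max_temperature=1000))
```

With beta = 0 the weight stays at zero and the objective reduces to a plain reconstruction loss, which is consistent with the `beta > 0 = variational` comment.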

LinearResidualAE

A fully-connected autoencoder with a linear/multi-layer perceptron residual network encoder and decoder

Skip Connections Eliminate Singularities

import torch
from autoencodersplz.models import LinearResidualAE

model = LinearResidualAE(
    img_size = 224,
    in_chans = 3,
    hidden_dim = [64, 64],
    blocks = [2, 2],
    dropout_rate = 0.1,
    with_batch_norm = False,
    latent_dim = 16,
    beta = 0.1, # beta > 0 = variational
    max_temperature = 1000, # kld temperature annealing
    device = None,
)

img = torch.rand(1, 3, 224, 224)

loss, reconstructed_img = model(img)

ConvResidualAE

A convolutional autoencoder with a ResNet encoder and symmetric decoder

Deep Residual Learning for Image Recognition

import torch
from autoencodersplz.models import ConvResidualAE

model = ConvResidualAE(
    img_size = 224,
    in_chans = 3,
    channels = [64, 128, 256, 512], 
    blocks = [2, 2, 2, 2], 
    latent_dim = 16,
    beta = 0, # beta > 0 = variational
    max_temperature = 1000, # kld temperature annealing
    upsample_mode = 'nearest', # interpolation method
    device = None,
)

img = torch.rand(1, 3, 224, 224)

loss, reconstructed_img = model(img)

VQVAE

A vector-quantized variational autoencoder with a ResNet encoder and symmetric decoder

Neural Discrete Representation Learning

import torch
from autoencodersplz.models import VQVAE

model = VQVAE(
    img_size = 224,
    in_chans = 3,
    channels = [64, 128, 256, 512],
    blocks = [2, 2, 2, 2],
    codebook_size = 256,
    codebook_dim = 8,
    use_cosine_sim = True,
    kmeans_init = True,
    commitment_weight = 0.5,
    upsample_mode = 'nearest',
    vq_kwargs = {},
)

img = torch.rand(1, 3, 224, 224)

loss, reconstructed_img = model(img)
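
For intuition, the quantization step at the heart of a VQ-VAE can be sketched in a few lines of PyTorch. This is a simplified illustration rather than the code used by VQVAE: it uses Euclidean nearest-neighbour lookup (the configuration above sets use_cosine_sim = True instead), omits k-means initialization and codebook updates, and the function name vector_quantize is hypothetical.

```python
import torch
import torch.nn.functional as F

def vector_quantize(z, codebook, commitment_weight=0.5):
    # z: (N, D) encoder outputs, codebook: (K, D) learnable code vectors.
    distances = (z.unsqueeze(1) - codebook.unsqueeze(0)).pow(2).sum(dim=-1)  # (N, K)
    indices = distances.argmin(dim=1)  # nearest code per input vector
    quantized = codebook[indices]
    # Commitment loss pulls encoder outputs towards their assigned codes.
    commit_loss = commitment_weight * F.mse_loss(z, quantized.detach())
    # Straight-through estimator: forward uses the code, backward copies gradients to z.
    quantized = z + (quantized - z).detach()
    return quantized, indices, commit_loss

# Toy usage matching codebook_size = 256 and codebook_dim = 8 from the example above.
codebook = torch.randn(256, 8)
z = torch.randn(10, 8, requires_grad=True)
quantized, indices, commit_loss = vector_quantize(z, codebook)
```

The straight-through step is what lets the decoder's reconstruction gradient reach the encoder even though the argmin lookup itself is not differentiable.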

FSQVAE

A finite-scalar quantized variational autoencoder with a ResNet encoder and symmetric decoder

Finite Scalar Quantization: VQ-VAE Made Simple

import torch
from autoencodersplz.models import FSQVAE

model = FSQVAE(
    img_size = 224,
    in_chans = 3,
    channels = [64, 128, 256, 512],
    blocks = [2, 2, 2, 2],
    levels = [8, 6, 5],
    upsample_mode = 'nearest'
)

img = torch.rand(1, 3, 224, 224)

loss, reconstructed_img = model(img)
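
In finite scalar quantization, each latent channel is rounded to a small fixed number of values, so levels = [8, 6, 5] describes a 3-channel latent with an implicit codebook of 8 * 6 * 5 = 240 entries. The sketch below illustrates the idea under simplifying assumptions: it bounds each channel with a sigmoid rather than the shifted tanh used in the paper, the function names are hypothetical, and it is not the code inside FSQVAE.

```python
import math
import torch

def fsq_quantize(z, levels):
    # z: (..., len(levels)); channel i is snapped to one of levels[i] integer values.
    num_levels = torch.tensor(levels, dtype=z.dtype)
    bounded = torch.sigmoid(z) * (num_levels - 1)  # squash channel i into [0, levels[i] - 1]
    rounded = torch.round(bounded)
    # Straight-through estimator: rounded values forward, gradient of `bounded` backward.
    return bounded + (rounded - bounded).detach()

def codes_to_index(codes, levels):
    # Mixed-radix encoding of the per-channel integers into a single code index.
    basis = torch.tensor([math.prod(levels[:i]) for i in range(len(levels))])
    return (codes.round().long() * basis).sum(dim=-1)

levels = [8, 6, 5]
codebook_size = math.prod(levels)  # implicit codebook size
z = torch.randn(2, 14, 14, len(levels))
indices = codes_to_index(fsq_quantize(z, levels), levels)
```

Because the codebook is implicit, there are no code vectors to learn and no commitment loss to tune, which is the simplification the paper's title refers to.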

MAE

A masked autoencoder with a vision transformer encoder and decoder

Masked Autoencoders Are Scalable Vision Learners

import torch
from autoencodersplz.models import MAE

model = MAE(
    img_size = 224,
    patch_size = 16,
    in_chans = 3,
    mask_ratio = 0.5,
    embed_dim = 768,
    depth = 12,
    num_heads = 12,
    mlp_ratio = 4,
    pre_norm = False,
    decoder_embed_dim = 768,
    decoder_depth = 12,
    decoder_num_heads = 12,
    norm_layer = torch.nn.LayerNorm,
    patch_norm_layer = torch.nn.LayerNorm,
    post_norm_layer = torch.nn.LayerNorm,
)

img = torch.rand(1, 3, 224, 224)

loss, reconstructed_img = model(img)
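
With img_size = 224 and patch_size = 16 the image is split into (224 / 16)^2 = 196 patches, so mask_ratio = 0.5 hides 98 of them from the encoder. The per-sample random masking described in the MAE paper can be sketched as follows. The function name random_masking is hypothetical and the snippet is an illustration, not the library's internal code.

```python
import torch

def random_masking(tokens, mask_ratio):
    # tokens: (batch, num_patches, dim) patch embeddings.
    B, N, D = tokens.shape
    num_keep = int(N * (1 - mask_ratio))
    # Sorting random noise yields an independent random permutation per sample.
    ids_shuffle = torch.rand(B, N).argsort(dim=1)
    ids_keep = ids_shuffle[:, :num_keep]
    # Gather only the visible patches; the encoder never sees the masked ones.
    visible = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
    # Binary mask in the original patch order: 0 = visible, 1 = masked.
    mask = torch.ones(B, N)
    mask.scatter_(1, ids_keep, 0.0)
    return visible, mask

num_patches = (224 // 16) ** 2
tokens = torch.randn(2, num_patches, 8)  # small embedding dim to keep the demo light
visible, mask = random_masking(tokens, mask_ratio=0.5)
```

The mask is typically used afterwards to restrict the reconstruction loss to the hidden patches.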

MAEMix

A masked autoencoder with an MLP-Mixer encoder and decoder

MLP-Mixer: An all-MLP Architecture for Vision

import torch
from autoencodersplz.models import MAEMix

model = MAEMix(
    img_size = 224,
    patch_size = 16,
    in_chans = 3,
    mask_ratio = 0.5,
    embed_dim = 768,
    depth = 12,
    mlp_ratio = 4,
    decoder_embed_dim = 768,
    decoder_depth = 12,
)

img = torch.rand(1, 3, 224, 224)

loss, reconstructed_img = model(img)
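
An MLP-Mixer replaces self-attention with two MLPs per block: one mixes information across tokens, the other across channels. A minimal sketch of one such block is shown below. The class name MixerBlock is hypothetical, and applying a single mlp_ratio to both MLPs is an assumption made for brevity (the paper uses separate hidden widths), so it should not be read as the layer used inside MAEMix.

```python
import torch
import torch.nn as nn

class MixerBlock(nn.Module):
    def __init__(self, num_tokens, embed_dim, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        # Token-mixing MLP acts along the token axis.
        self.token_mlp = nn.Sequential(
            nn.Linear(num_tokens, num_tokens * mlp_ratio),
            nn.GELU(),
            nn.Linear(num_tokens * mlp_ratio, num_tokens),
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        # Channel-mixing MLP acts along the embedding axis.
        self.channel_mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(embed_dim * mlp_ratio, embed_dim),
        )

    def forward(self, x):
        # x: (batch, tokens, channels); transpose so the first MLP mixes tokens.
        x = x + self.token_mlp(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        x = x + self.channel_mlp(self.norm2(x))
        return x

# 98 tokens corresponds to the visible half of 196 patches at mask_ratio = 0.5.
block = MixerBlock(num_tokens=98, embed_dim=64)
out = block(torch.rand(2, 98, 64))
```

Unlike attention, the token-mixing MLP has a fixed input width, so the number of tokens it sees must stay constant, which holds here because a fixed mask ratio always leaves the same number of visible patches.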

IJEPA

An image-based joint-embedding predictive architecture (thanks to Yiran for porting this implementation)

Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture

import torch
from autoencodersplz.models import IJEPA

model = IJEPA(
    img_size = 224,
    patch_size = 16,
    in_chans = 3,
    embed_dim = 768,
    depth = 12,
    num_heads = 12,
    mlp_ratio = 4,
    embed_dim_predictor = 384,
    predictor_depth = 12,
    num_targets = 4,
    target_aspect_ratio = 0.75,
    target_scale = 0.2,
    context_aspect_ratio = 1.,
    context_scale = 0.9
)

img = torch.rand(1, 3, 224, 224)

loss, reconstructed_img = model(img)

Training

Basic

The Trainer class enables basic training on a single CPU or GPU for any model in the autoencodersplz library. If you pass a path as the output_dir argument, Trainer also automatically saves the autoencoder model, the backbone/encoder, the losses, and a visualization of the training process (.gif).

from autoencodersplz.trainers import Trainer

trainer = Trainer(
    autoencoder, # any autoencodersplz model instance, e.g. LinearAE(...)
    train = train_dataloader,
    valid = valid_dataloader,
    epochs = 128,
    learning_rate = 5e-4,
    betas = (0.9, 0.95),
    weight_decay = 0.05,
    patience = 10,
    scheduler = 'plateau',
    save_backbone = True,
    show_plots = False,
    output_dir = 'training_run/',
    device = None,
)

trainer.fit()

By default, Trainer uses an AdamW optimizer and either a CosineDecay ('cosine') or ReduceLROnPlateau ('plateau') scheduler. To use a different optimizer or scheduler, assign it to the .optimizer or .scheduler attribute (constructing the optimizer from trainer.model.parameters()) before calling trainer.fit().
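
Swapping the optimizer and scheduler might look like the sketch below. To keep the snippet runnable on its own, a SimpleNamespace stands in for an already constructed Trainer; the stand-in object, the choice of SGD, and the choice of StepLR are all illustrative assumptions, and only the .model, .optimizer, and .scheduler attribute names come from the description above.

```python
from types import SimpleNamespace
import torch

# Stand-in for a constructed autoencodersplz Trainer (assumption, for illustration only).
# In practice this would be the `trainer` object created with Trainer(...).
trainer = SimpleNamespace(model=torch.nn.Linear(8, 8), optimizer=None, scheduler=None)

# Build the replacement optimizer from the trainer's own model parameters.
trainer.optimizer = torch.optim.SGD(trainer.model.parameters(), lr=1e-2, momentum=0.9)

# Attach the scheduler to the new optimizer, not the one it replaces.
trainer.scheduler = torch.optim.lr_scheduler.StepLR(trainer.optimizer, step_size=10, gamma=0.5)

# trainer.fit()  # on a real Trainer, training would now use the replacements
```

How Trainer calls scheduler.step() internally is not described here, so a scheduler whose step signature differs from the defaults should be checked against the source before a long run.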

Lightning

To make it easier to scale to multi-GPU/distributed training, all autoencodersplz models are configured for use with PyTorch Lightning. Each model is set up with a default optimizer and scheduler and can be passed directly to the PyTorch Lightning trainer. See the example below.

import lightning.pytorch as pl
from autoencodersplz.models import FSQVAE

model = FSQVAE(
    img_size = 28,
    in_chans = 1,
    channels = [8, 16],
    blocks = [1, 1],
    levels = [8],
    upsample_mode = 'nearest',
    learning_rate = 1e-3,
    factor = 0.1,
    patience = 30,
    min_lr = 1e-6
) 

trainer = pl.Trainer(accelerator='gpu', devices=4, max_epochs=256) # the gpus argument was removed in Lightning 2.0

trainer.fit(model, train_dataloader, valid_dataloader)

Examples

Basic usage

Here's a basic example of training a fully connected autoencoder on MNIST. The data is downloaded and loaded and then the autoencoder is fit. The training info is logged to the output directory (training/) and a GIF of the training routine is generated for visual inspection.

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from autoencodersplz.models import LinearAE
from autoencodersplz.trainers import Trainer

train_loader = DataLoader(
    MNIST(root='data/', train=True, download=True, transform=ToTensor()),
    batch_size = 32,
    shuffle = True,
)

test_loader = DataLoader(
    MNIST(root='data/', train=False, download=True, transform=ToTensor()),
    batch_size = 32,
    shuffle = False,
)

model = LinearAE(
    img_size = 28,
    in_chans = 1,
    hidden_layers = [256, 128],
    dropout_rate = 0,
    latent_dim = 32,
    beta = 0,
)

trainer = Trainer(
    model, 
    train_loader, 
    test_loader, 
    epochs = 32, 
    learning_rate = 1e-3, 
    output_dir = 'training/'
)
        
trainer.fit()

Future additions

References

@article{hinton2006reducing,
    title = {Reducing the Dimensionality of Data with Neural Networks},
    author = {Geoffrey Hinton and Ruslan Salakhutdinov},
    url = {https://doi.org/10.1126/science.1127647},
    year = {2006},
}
@article{orhan2018skip,
    title = {Skip Connections Eliminate Singularities},
    author = {Emin Orhan and Xaq Pitkow},
    url = {https://arxiv.org/abs/1701.09175},
    year = {2018},    
}
@article{he2015deep,
    title = {Deep Residual Learning for Image Recognition}, 
    author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun},
    url = {https://arxiv.org/abs/1512.03385},
    year = {2016},
}
@misc{oord2018neural,
    title = {Neural Discrete Representation Learning},
    author = {Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
    url = {https://arxiv.org/abs/1711.00937},
    year = {2017},
}
@misc{mentzer2023finite,
    title = {Finite Scalar Quantization: VQ-VAE Made Simple}, 
    author = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},    
    url = {https://arxiv.org/abs/2309.15505},
    year = {2023},
}
@misc{he2021masked,
    title = {Masked Autoencoders Are Scalable Vision Learners}, 
    author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
    url = {https://arxiv.org/abs/2111.06377},
    year = {2021},
}
@misc{tolstikhin2021mlpmixer,
    title = {MLP-Mixer: An all-MLP Architecture for Vision}, 
    author = {Ilya Tolstikhin and Neil Houlsby and Alexander Kolesnikov and Lucas Beyer and Xiaohua Zhai and Thomas Unterthiner and Jessica Yung and Andreas Steiner and Daniel Keysers and Jakob Uszkoreit and Mario Lucic and Alexey Dosovitskiy},
    url = {https://arxiv.org/abs/2105.01601},
    year = {2021},
}
@misc{assran2023selfsupervised,
    title = {Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture}, 
    author = {Mahmoud Assran and Quentin Duval and Ishan Misra and Piotr Bojanowski and Pascal Vincent and Michael Rabbat and Yann LeCun and Nicolas Ballas},
    url = {https://arxiv.org/abs/2301.08243},
    year = {2023},
}