Adding model with cycle consistency and VampPrior #2421

Open
wants to merge 23 commits into base: main

Changes from 7 commits

Commits (23)
a582251 add model and tests (Hrovatin, Jan 14, 2024)
bb359ac update documentation (Hrovatin, Jan 14, 2024)
14e41f1 move embedding to device (Hrovatin, Jan 14, 2024)
5863b1f Merge branch 'main' into main (martinkim0, Jan 19, 2024)
3f49266 pr comments (Hrovatin, Jan 21, 2024)
6605682 Merge branch 'main' of https://github.com/Hrovatin/scvi-tools (Hrovatin, Jan 21, 2024)
3319fc8 Merge branch 'main' into main (martinkim0, Jan 22, 2024)
3a67d1c Merge branch 'main' into main (martinkim0, Feb 5, 2024)
5edad83 updates (Hrovatin, Feb 7, 2024)
a4b080e Merge branch 'main' of https://github.com/Hrovatin/scvi-tools (Hrovatin, Feb 7, 2024)
a24ef28 Merge branch 'main' into main (martinkim0, Feb 20, 2024)
9b05bca [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 20, 2024)
8c25dba Merge branch 'main' into main (martinkim0, Mar 11, 2024)
c5f5c37 Update scvi/external/sysvi/_base_components.py (martinkim0, Mar 15, 2024)
5b4838c Update scvi/external/sysvi/_base_components.py (martinkim0, Mar 15, 2024)
c885e20 Update scvi/external/sysvi/_base_components.py (martinkim0, Mar 15, 2024)
661bbc6 Update scvi/external/sysvi/_model.py (martinkim0, Mar 15, 2024)
9a49d24 Update scvi/external/sysvi/_model.py (martinkim0, Mar 15, 2024)
3622eee Update scvi/external/sysvi/_model.py (martinkim0, Mar 15, 2024)
e4c1ef9 Update scvi/external/sysvi/_model.py (martinkim0, Mar 15, 2024)
0f7bd06 Update scvi/external/sysvi/_model.py (martinkim0, Mar 15, 2024)
9e0cba9 Update scvi/external/sysvi/_model.py (martinkim0, Mar 15, 2024)
54f5734 Update scvi/external/sysvi/_model.py (martinkim0, Mar 15, 2024)
2 changes: 2 additions & 0 deletions scvi/external/__init__.py
@@ -6,6 +6,7 @@
from .scbasset import SCBASSET
from .solo import SOLO
from .stereoscope import RNAStereoscope, SpatialStereoscope
from .sysvi import SysVI
from .tangram import Tangram

__all__ = [
@@ -19,4 +20,5 @@
"SCBASSET",
"POISSONVI",
"ContrastiveVI",
"SysVI",
]
3 changes: 3 additions & 0 deletions scvi/external/sysvi/__init__.py
@@ -0,0 +1,3 @@
from ._model import SysVI

__all__ = ["SysVI"]
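
With these __init__.py changes the model becomes importable as scvi.external.SysVI. A minimal usage sketch follows, assuming SysVI exposes the standard scvi-tools model API (setup_anndata, train, get_latent_representation); the keyword argument names are illustrative only, since the actual signatures live in scvi/external/sysvi/_model.py, which is not shown in this hunk.

import scanpy as sc
from scvi.external import SysVI

adata = sc.read_h5ad("adata.h5ad")  # hypothetical input file
# Keyword names below are illustrative; see SysVI.setup_anndata in _model.py for the real signature.
SysVI.setup_anndata(adata, batch_key="system")
model = SysVI(adata)
model.train()
latent = model.get_latent_representation()
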
347 changes: 347 additions & 0 deletions scvi/external/sysvi/_base_components.py
@@ -0,0 +1,347 @@
from __future__ import annotations

from collections import OrderedDict
from typing import Literal

import torch
from torch.distributions import Normal
from torch.nn import (
BatchNorm1d,
Dropout,
LayerNorm,
Linear,
Module,
Parameter,
ReLU,
Sequential,
)


class Embedding(Module):
"""Module for obtaining embedding of categorical covariates

Parameters
----------
size
N categories
cov_embed_dims
Dimensions of embedding
normalize
Apply layer normalization
"""

def __init__(self, size, cov_embed_dims: int = 10, normalize: bool = True):
super().__init__()

self.normalize = normalize

self.embedding = torch.nn.Embedding(size, cov_embed_dims)

if self.normalize:
# TODO this could probably be implemented more efficiently, as the embedding gives the same result for every sample in
# a given class. However, if we have many balanced classes there won't be many repetitions within a minibatch.
self.layer_norm = torch.nn.LayerNorm(
cov_embed_dims, elementwise_affine=False
)

def forward(self, x):
x = self.embedding(x)
if self.normalize:
x = self.layer_norm(x)

return x
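
A minimal sketch of how this Embedding module is used on integer-encoded covariates; the category count, embedding size, and codes below are made up for illustration:

import torch

embed = Embedding(size=3, cov_embed_dims=10, normalize=True)  # e.g. 3 systems/batches
codes = torch.tensor([0, 2, 1, 0])  # integer category codes for a minibatch of 4 cells
out = embed(codes)  # shape (4, 10), layer-normalized per sample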


class EncoderDecoder(Module):
"""Module that can be used as probabilistic encoder or decoder

Based on inputs and optional covariates predicts output mean and var

Parameters
----------
n_input
Number of input features.
n_output
Number of output features.
n_cov
Number of covariates.
n_hidden
Number of nodes in each hidden layer.
n_layers
Number of hidden layers.
var_eps
See :class:`~scvi.external.sysvi.nn.VarEncoder`
var_mode
See :class:`~scvi.external.sysvi.nn.VarEncoder`
sample
Whether to return samples from the predicted distribution.
kwargs
Passed to :class:`~scvi.external.sysvi.nn.Layers`
"""

def __init__(
self,
n_input: int,
n_output: int,
n_cov: int,
n_hidden: int = 256,
n_layers: int = 3,
var_eps: float = 1e-4,
var_mode: str = "feature",
sample: bool = False,
**kwargs,
):
super().__init__()
self.sample = sample

self.var_eps = var_eps

self.decoder_y = Layers(
n_in=n_input,
n_cov=n_cov,
n_out=n_hidden,
n_hidden=n_hidden,
n_layers=n_layers,
**kwargs,
)

self.mean_encoder = Linear(n_hidden, n_output)
self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode, eps=var_eps)

def forward(self, x, cov: torch.Tensor | None = None):
y = self.decoder_y(x=x, cov=cov)
# TODO: handle inappropriate edge-case values better than with nan_to_num, or at least warn
Reviewer comment (Contributor): Have there been edge case values other than NaNs?

y_m = torch.nan_to_num(self.mean_encoder(y))
y_v = self.var_encoder(y, x_m=y_m)

outputs = {"y_m": y_m, "y_v": y_v}

# Sample from latent distribution
if self.sample:
y = Normal(y_m, y_v.sqrt()).rsample()
outputs["y"] = y

return outputs
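
A toy sketch of the EncoderDecoder interface with made-up dimensions, to illustrate the returned dictionary; with sample=True the output also contains a reparameterized draw from Normal(y_m, sqrt(y_v)):

import torch

enc = EncoderDecoder(n_input=1000, n_output=15, n_cov=2, n_hidden=256, n_layers=2, sample=True)
x = torch.randn(8, 1000)  # minibatch of 8 cells (toy data)
cov = torch.randn(8, 2)   # covariate representation for the same 8 cells
out = enc(x, cov=cov)
# out["y_m"], out["y_v"], and out["y"] all have shape (8, 15); y_v is strictly positive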


class Layers(Module):
Reviewer comment (Contributor): Since this class is similar to the existing implementation of FCLayers, could you refactor Layers so that it subclasses FCLayers and overrides methods as needed? This will make it clearer what's different about this implementation as well as remove duplicate code in inject_into_layer and set_online_update_hooks, as it looks like these methods are identical to the original.

"""A helper class to build fully-connected layers for a neural network.

Adapted from scVI's FCLayers to use covariates more flexibly.

Parameters
----------
n_in
The dimensionality of the main input
n_out
The dimensionality of the output
n_cov
Dimensionality of covariates.
If there are no covariates, this should be set to None;
in that case, covariates will not be used.
n_layers
The number of fully-connected hidden layers
n_hidden
The number of nodes per hidden layer
dropout_rate
Dropout rate to apply to each of the hidden layers
use_batch_norm
Whether to have `BatchNorm` layers or not
use_layer_norm
Whether to have `LayerNorm` layers or not
use_activation
Whether to have layer activation or not
bias
Whether to learn bias in linear layers or not
inject_covariates
Whether to inject covariates in each layer, or just the first.
activation_fn
Which activation function to use
"""

def __init__(
self,
n_in: int,
n_out: int,
n_cov: int | None = None,
n_layers: int = 1,
n_hidden: int = 128,
dropout_rate: float = 0.1,
use_batch_norm: bool = True,
use_layer_norm: bool = False,
use_activation: bool = True,
bias: bool = True,
inject_covariates: bool = True,
activation_fn: Module = ReLU,
):
super().__init__()

self.inject_covariates = inject_covariates
self.n_cov = n_cov if n_cov is not None else 0

layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [n_out]

self.fc_layers = Sequential(
OrderedDict(
[
(
f"Layer {i}",
Sequential(
Linear(
n_in + self.n_cov * self.inject_into_layer(i),
n_out,
bias=bias,
),
# non-default params come from defaults in original Tensorflow implementation
BatchNorm1d(n_out, momentum=0.01, eps=0.001)
if use_batch_norm
else None,
LayerNorm(n_out, elementwise_affine=False)
if use_layer_norm
else None,
activation_fn() if use_activation else None,
Dropout(p=dropout_rate) if dropout_rate > 0 else None,
),
)
for i, (n_in, n_out) in enumerate(
zip(layers_dim[:-1], layers_dim[1:])
)
]
)
)

def inject_into_layer(self, layer_num) -> bool:
"""Helper to determine if covariates should be injected."""
user_cond = layer_num == 0 or (layer_num > 0 and self.inject_covariates)
return user_cond

def set_online_update_hooks(self, hook_first_layer=True):
self.hooks = []

def _hook_fn_weight(grad):
new_grad = torch.zeros_like(grad)
if self.n_cov > 0:
new_grad[:, -self.n_cov :] = grad[:, -self.n_cov :]
return new_grad

def _hook_fn_zero_out(grad):
return grad * 0

for i, layers in enumerate(self.fc_layers):
for layer in layers:
if i == 0 and not hook_first_layer:
continue
if isinstance(layer, Linear):
if self.inject_into_layer(i):
w = layer.weight.register_hook(_hook_fn_weight)
else:
w = layer.weight.register_hook(_hook_fn_zero_out)
self.hooks.append(w)
b = layer.bias.register_hook(_hook_fn_zero_out)
self.hooks.append(b)

def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None):
"""
Forward computation on ``x``.

Parameters
----------
x
tensor of values with shape ``(n_in,)``
cov
tensor of covariate values with shape ``(n_cov,)`` or None

Returns
-------
:py:class:`torch.Tensor`
tensor of shape ``(n_out,)``

"""
for i, layers in enumerate(self.fc_layers):
for layer in layers:
if layer is not None:
if isinstance(layer, BatchNorm1d):
if x.dim() == 3:
x = torch.cat(
[(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0
)
else:
x = layer(x)
else:
# Injection of covariates
if (
self.n_cov > 0
and isinstance(layer, Linear)
and self.inject_into_layer(i)
):
x = torch.cat((x, cov), dim=-1)
x = layer(x)
return x
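
Unlike FCLayers, which injects one-hot encoded categorical covariates passed as a cat_list, Layers takes a single pre-assembled covariate tensor and concatenates it before each Linear layer it injects into. A short sketch with made-up sizes:

import torch

layers = Layers(n_in=50, n_out=32, n_cov=3, n_layers=2, n_hidden=64, inject_covariates=True)
x = torch.randn(16, 50)
cov = torch.randn(16, 3)  # already-embedded / continuous covariates
h = layers(x, cov=cov)    # cov is concatenated before every Linear layer; shape (16, 32)

no_cov = Layers(n_in=50, n_out=32, n_cov=None, n_layers=2, n_hidden=64)
h2 = no_cov(torch.randn(16, 50))  # with n_cov=None, covariates are never used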


class VarEncoder(Module):
"""Encode variance (strictly positive).

Parameters
----------
n_input
Number of input dimensions, used if mode is sample_feature
n_output
Number of variances to predict
mode
How to compute the variance:
'sample_feature' - learned per sample and feature
'feature' - learned per feature, constant across samples
'linear' - linear with respect to the input mean, var = a1 * mean + a0;
not recommended due to the poor implementation of the positivity constraint
Reviewer comment on lines +284 to +289 (Contributor): Could you format this similar to how it is done here?

eps
Reviewer comment (Contributor): Needs documentation.

"""

def __init__(
self,
n_input: int,
n_output: int,
mode: Literal["sample_feature", "feature", "linear"],
eps: float = 1e-4,
):
super().__init__()

self.eps = eps
self.mode = mode
if self.mode == "sample_feature":
self.encoder = Linear(n_input, n_output)
elif self.mode == "feature":
self.var_param = Parameter(torch.zeros(1, n_output))
elif self.mode == "linear":
self.var_param_a1 = Parameter(torch.tensor([1.0]))
self.var_param_a0 = Parameter(torch.tensor([self.eps]))
else:
raise ValueError("Mode not recognised.")
self.activation = torch.exp
Reviewer comment (Contributor): Probably a good idea to make this an adjustable parameter, as we've experienced numerical stability issues with the exponential.


def forward(self, x: torch.Tensor, x_m: torch.Tensor):
"""Forward pass through model

Parameters
----------
x
Used to encode var if mode is sample_feature; dim = n_samples x n_input
x_m
Used to predict var instead of x if mode is linear; dim = n_samples x 1

Returns
-------
Predicted var
"""
# Force values to be non-NaN - TODO: come up with a better way to do so
if self.mode == "sample_feature":
v = self.encoder(x)
v = (
torch.nan_to_num(self.activation(v)) + self.eps
) # Ensure that var is strictly positive
elif self.mode == "feature":
v = self.var_param.expand(x.shape[0], -1) # Broadcast to input size
v = (
torch.nan_to_num(self.activation(v)) + self.eps
) # Ensure that var is strictly positive
elif self.mode == "linear":
v = self.var_param_a1 * x_m.detach().clone() + self.var_param_a0
# TODO: come up with a better way to constrain this to be positive while keeping a linear relationship
# Could an activation be used for a log-linear relationship?
v = torch.clamp(torch.nan_to_num(v), min=self.eps)
return v
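
A small usage sketch of the two recommended VarEncoder modes with made-up dimensions; x_m is only consulted in the 'linear' mode but is part of the forward signature either way. If the activation is made configurable as suggested above, torch.nn.functional.softplus would be one more numerically stable alternative to torch.exp (an assumption, not part of this PR):

import torch

ve = VarEncoder(n_input=64, n_output=15, mode="sample_feature")
h = torch.randn(8, 64)    # hidden representation from the encoder/decoder trunk
x_m = torch.randn(8, 15)  # predicted mean, only used by mode="linear"
v = ve(h, x_m=x_m)        # strictly positive variance, shape (8, 15)

ve_feat = VarEncoder(n_input=64, n_output=15, mode="feature")
v_feat = ve_feat(h, x_m=x_m)  # one learned variance per feature, broadcast to (8, 15)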