-
Notifications
You must be signed in to change notification settings - Fork 342
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding model with cycle consistency and VampPrior #2421
base: main
Are you sure you want to change the base?
Changes from all commits
a582251
bb359ac
14e41f1
5863b1f
3f49266
6605682
3319fc8
3a67d1c
5edad83
a4b080e
a24ef28
9b05bca
8c25dba
c5f5c37
5b4838c
c885e20
661bbc6
9a49d24
3622eee
e4c1ef9
0f7bd06
9e0cba9
54f5734
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from ._base_components import Layers, VarEncoder | ||
from ._model import SysVI | ||
from ._module import SysVAE | ||
|
||
__all__ = ["SysVI", "VarEncoder", "Layers", "SysVAE"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,345 @@ | ||
from __future__ import annotations

from collections import OrderedDict
from collections.abc import Callable
from typing import Literal

import torch
from torch.distributions import Normal
from torch.nn import (
    BatchNorm1d,
    Dropout,
    LayerNorm,
    Linear,
    Module,
    Parameter,
    ReLU,
    Sequential,
)
|
||
|
||
class Embedding(Module):
    """Learnable lookup table that embeds categorical covariates.

    Parameters
    ----------
    size
        Number of distinct categories.
    cov_embed_dims
        Width of the embedding vector produced for each category.
    normalize
        Whether to pass the embeddings through a (non-affine) layer normalization.
    """

    def __init__(self, size: int, cov_embed_dims: int = 10, normalize: bool = True):
        super().__init__()

        self.normalize = normalize
        self.embedding = torch.nn.Embedding(size, cov_embed_dims)

        # The normalization layer only exists when it was requested.
        # TODO: normalizing per sample is redundant because every sample of a given
        # category gets the same embedding; with many balanced categories, however,
        # a minibatch rarely repeats a category, so the saving would be small.
        if normalize:
            self.layer_norm = torch.nn.LayerNorm(cov_embed_dims, elementwise_affine=False)

    def forward(self, x):
        """Return the (optionally layer-normalized) embedding of each index in ``x``."""
        embedded = self.embedding(x)
        return self.layer_norm(embedded) if self.normalize else embedded
|
||
|
||
class EncoderDecoder(Module):
    """Probabilistic network usable as either an encoder or a decoder.

    Maps the input (optionally together with covariates) to the mean and
    variance of a normal distribution over the output.

    Parameters
    ----------
    n_input
        Dimensionality of the main input.
    n_output
        Dimensionality of the output.
    n_cov
        Dimensionality of the covariates. Set to ``None`` when there are no
        covariates - they are then not used.
    n_hidden
        Number of nodes per hidden layer.
    n_layers
        Number of hidden layers.
    var_eps
        See :class:`~scvi.external.sysvi.VarEncoder`.
    var_mode
        See :class:`~scvi.external.sysvi.VarEncoder`.
    sample
        Whether to additionally return a sample from the predicted distribution.
    kwargs
        Passed to :class:`~scvi.external.sysvi.Layers`.
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cov: int,
        n_hidden: int = 256,
        n_layers: int = 3,
        var_eps: float = 1e-4,
        var_mode: Literal["sample_feature", "feature", "linear"] = "feature",
        sample: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.var_eps = var_eps
        self.sample = sample

        # Shared trunk: its output width is ``n_hidden`` so that both heads
        # below read the same hidden representation.
        self.decoder_y = Layers(
            n_in=n_input,
            n_cov=n_cov,
            n_out=n_hidden,
            n_hidden=n_hidden,
            n_layers=n_layers,
            **kwargs,
        )

        # Separate heads for the mean and the variance of the output.
        self.mean_encoder = Linear(n_hidden, n_output)
        self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode, eps=var_eps)

    def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None):
        """Predict the output distribution for ``x``.

        Returns
        -------
        Dictionary with the predicted mean (``"y_m"``) and variance (``"y_v"``)
        and, if ``sample`` was enabled, a reparameterized sample (``"y"``).
        """
        hidden = self.decoder_y(x=x, cov=cov)
        # TODO better handling of inappropriate edge-case values than nan_to_num or at least warn
        mean = torch.nan_to_num(self.mean_encoder(hidden))
        var = self.var_encoder(hidden, x_m=mean)

        result = {"y_m": mean, "y_v": var}
        if not self.sample:
            return result

        # ``Normal`` is parameterized by the standard deviation, hence the sqrt.
        result["y"] = Normal(mean, var.sqrt()).rsample()
        return result
|
||
|
||
class Layers(Module): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since this class is similar to the existing implementation of |
||
"""A helper class to build fully-connected layers for a neural network. | ||
|
||
Adapted from scVI FCLayers to use covariates more flexibly | ||
|
||
Parameters | ||
---------- | ||
n_in | ||
The dimensionality of the main input | ||
n_out | ||
The dimensionality of the output | ||
n_cov | ||
Dimensionality of covariates. | ||
If there are no cov this should be set to None - | ||
in this case cov will not be used. | ||
n_layers | ||
The number of fully-connected hidden layers | ||
n_hidden | ||
The number of nodes per hidden layer | ||
dropout_rate | ||
Dropout rate to apply to each of the hidden layers | ||
use_batch_norm | ||
Whether to have `BatchNorm` layers or not | ||
use_layer_norm | ||
Whether to have `LayerNorm` layers or not | ||
use_activation | ||
Whether to have layer activation or not | ||
bias | ||
Whether to learn bias in linear layers or not | ||
inject_covariates | ||
Whether to inject covariates in each layer, or just the first. | ||
activation_fn | ||
Which activation function to use | ||
""" | ||
|
||
def __init__( | ||
self, | ||
n_in: int, | ||
n_out: int, | ||
n_cov: int | None = None, | ||
n_layers: int = 1, | ||
n_hidden: int = 128, | ||
dropout_rate: float = 0.1, | ||
use_batch_norm: bool = True, | ||
use_layer_norm: bool = False, | ||
use_activation: bool = True, | ||
bias: bool = True, | ||
inject_covariates: bool = True, | ||
activation_fn: Module = ReLU, | ||
): | ||
super().__init__() | ||
|
||
self.inject_covariates = inject_covariates | ||
self.n_cov = n_cov if n_cov is not None else 0 | ||
|
||
layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [n_out] | ||
|
||
self.fc_layers = Sequential( | ||
OrderedDict( | ||
[ | ||
( | ||
f"Layer {i}", | ||
Sequential( | ||
Linear( | ||
n_in + self.n_cov * self.inject_into_layer(i), | ||
n_out, | ||
bias=bias, | ||
), | ||
# non-default params come from defaults in original Tensorflow implementation | ||
BatchNorm1d(n_out, momentum=0.01, eps=0.001) | ||
if use_batch_norm | ||
else None, | ||
LayerNorm(n_out, elementwise_affine=False) if use_layer_norm else None, | ||
activation_fn() if use_activation else None, | ||
Dropout(p=dropout_rate) if dropout_rate > 0 else None, | ||
), | ||
) | ||
for i, (n_in, n_out) in enumerate(zip(layers_dim[:-1], layers_dim[1:])) | ||
] | ||
) | ||
) | ||
|
||
def inject_into_layer(self, layer_num) -> bool: | ||
"""Helper to determine if covariates should be injected.""" | ||
user_cond = layer_num == 0 or (layer_num > 0 and self.inject_covariates) | ||
return user_cond | ||
|
||
def set_online_update_hooks(self, hook_first_layer=True): | ||
self.hooks = [] | ||
|
||
def _hook_fn_weight(grad): | ||
new_grad = torch.zeros_like(grad) | ||
if self.n_cov > 0: | ||
new_grad[:, -self.n_cov :] = grad[:, -self.n_cov :] | ||
return new_grad | ||
|
||
def _hook_fn_zero_out(grad): | ||
return grad * 0 | ||
|
||
for i, layers in enumerate(self.fc_layers): | ||
for layer in layers: | ||
if i == 0 and not hook_first_layer: | ||
continue | ||
if isinstance(layer, Linear): | ||
if self.inject_into_layer(i): | ||
w = layer.weight.register_hook(_hook_fn_weight) | ||
else: | ||
w = layer.weight.register_hook(_hook_fn_zero_out) | ||
self.hooks.append(w) | ||
b = layer.bias.register_hook(_hook_fn_zero_out) | ||
self.hooks.append(b) | ||
|
||
def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None): | ||
""" | ||
Forward computation on ``x``. | ||
|
||
Parameters | ||
---------- | ||
x | ||
tensor of values with shape ``(n_in,)`` | ||
cov | ||
tensor of covariate values with shape ``(n_cov,)`` or None | ||
|
||
Returns | ||
------- | ||
py:class:`torch.Tensor` | ||
tensor of shape ``(n_out,)`` | ||
|
||
""" | ||
for i, layers in enumerate(self.fc_layers): | ||
for layer in layers: | ||
if layer is not None: | ||
if isinstance(layer, BatchNorm1d): | ||
if x.dim() == 3: | ||
x = torch.cat([(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0) | ||
else: | ||
x = layer(x) | ||
else: | ||
# Injection of covariates | ||
if ( | ||
self.n_cov > 0 | ||
and isinstance(layer, Linear) | ||
and self.inject_into_layer(i) | ||
): | ||
x = torch.cat((x, cov), dim=-1) | ||
x = layer(x) | ||
return x | ||
|
||
|
||
class VarEncoder(Module): | ||
"""Encode variance (strictly positive). | ||
|
||
Parameters | ||
---------- | ||
n_input | ||
Number of input dimensions, used if mode is sample_feature | ||
n_output | ||
Number of variances to predict | ||
mode | ||
How to compute var | ||
'sample_feature' - learn per sample and feature | ||
'feature' - learn per feature, constant across samples | ||
'linear' - linear with respect to input mean, var = a1 * mean + a0; | ||
not suggested to be used due to bad implementation for positive constraining | ||
Comment on lines
+284
to
+289
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you format this similar to how it is done here? |
||
eps | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Needs documentation |
||
""" | ||
|
||
def __init__( | ||
self, | ||
n_input: int, | ||
n_output: int, | ||
mode: Literal["sample_feature", "feature", "linear"], | ||
eps: float = 1e-4, | ||
): | ||
super().__init__() | ||
|
||
self.eps = eps | ||
self.mode = mode | ||
if self.mode == "sample_feature": | ||
self.encoder = Linear(n_input, n_output) | ||
elif self.mode == "feature": | ||
self.var_param = Parameter(torch.zeros(1, n_output)) | ||
elif self.mode == "linear": | ||
self.var_param_a1 = Parameter(torch.tensor([1.0])) | ||
self.var_param_a0 = Parameter(torch.tensor([self.eps])) | ||
else: | ||
raise ValueError("Mode not recognised.") | ||
self.activation = torch.exp | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Probably a good idea to make this an adjustable parameter as we've experienced numerical stability issues with the exponential. |
||
|
||
def forward(self, x: torch.Tensor, x_m: torch.Tensor): | ||
"""Forward pass through model | ||
|
||
Parameters | ||
---------- | ||
x | ||
Used to encode var if mode is sample_feature; dim = n_samples x n_input | ||
x_m | ||
Used to predict var instead of x if mode is linear; dim = n_samples x 1 | ||
|
||
Returns | ||
------- | ||
Predicted var | ||
""" | ||
# Force to be non nan - TODO come up with better way to do so | ||
if self.mode == "sample_feature": | ||
v = self.encoder(x) | ||
v = ( | ||
torch.nan_to_num(self.activation(v)) + self.eps | ||
) # Ensure that var is strictly positive | ||
elif self.mode == "feature": | ||
v = self.var_param.expand(x.shape[0], -1) # Broadcast to input size | ||
v = ( | ||
torch.nan_to_num(self.activation(v)) + self.eps | ||
) # Ensure that var is strictly positive | ||
elif self.mode == "linear": | ||
v = self.var_param_a1 * x_m.detach().clone() + self.var_param_a0 | ||
# TODO come up with a better way to constrain this to positive while having lin relationship | ||
# Could activation be used for log-lin relationship? | ||
v = torch.clamp(torch.nan_to_num(v), min=self.eps) | ||
return v |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have there been edge case values other than NaNs?