Skip to content

Commit

Permalink
feat(external): add mrvi (#2756)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justin Hong <jjhong922@berkeley.edu>
Co-authored-by: Justin Hong <justin.hong@columbia.edu>
Co-authored-by: Pierre Boyeau <pierre.boyeau@gmail.com>
  • Loading branch information
5 people committed May 10, 2024
1 parent c7a9894 commit 2182546
Show file tree
Hide file tree
Showing 19 changed files with 3,035 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ to [Semantic Versioning]. Full commit history is available in the
{pr}`2692`.
- Add support for custom dataloaders in {class}`scvi.model.base.VAEMixin` methods by specifying
the `dataloader` argument {pr}`2748`.
- Add {class}`scvi.external.MRVI` for modeling sample-level heterogeneity in single-cell RNA-seq
data {pr}`2756`.

#### Changed

Expand Down
3 changes: 1 addition & 2 deletions docs/api/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ Module classes in the external API with respective generative and inference proc
external.scbasset.ScBassetModule
external.contrastivevi.ContrastiveVAE
external.velovi.VELOVAE
external.mrvi.MRVAE
```

Expand Down Expand Up @@ -275,5 +276,3 @@ Utility functions used by scvi-tools.
utils.attrdict
model.get_max_epochs_heuristic
```

[ray tune]: https://docs.ray.io/en/latest/tune/index.html
1 change: 1 addition & 0 deletions docs/api/user.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import scvi
external.ContrastiveVI
external.POISSONVI
external.VELOVI
external.MRVI
```

Expand Down
10 changes: 10 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ @article{Boyeau19
publisher = {Machine Learning in Computational Biology}
}

@article{Boyeau24,
title = {Deep generative modeling of sample-level heterogeneity in single-cell genomics},
author = {Pierre Boyeau and Justin Hong and Adam Gayoso and Martin Kim and Jose L. McFaline-Figueroa and Michael I. Jordan and Elham Azizi and Can Ergen and Nir Yosef},
doi = {10.1101/2022.10.04.510898},
year = {2024},
month = may,
journal = {bioRxiv},
publisher = {Cold Spring Harbor Laboratory}
}

@article{Clivio19,
title = {Detecting Zero-Inflated Genes in Single-Cell Transcriptomics Data},
author = {Oscar Clivio and Romain Lopez and Jeffrey Regier and Adam Gayoso and Michael I. Jordan and Nir Yosef},
Expand Down
1 change: 1 addition & 0 deletions docs/tutorials/index_scrna.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ notebooks/scrna/amortized_lda
notebooks/scrna/scVI_DE_worm
notebooks/scrna/contrastiveVI_tutorial
notebooks/scrna/scanvi_fix
notebooks/scrna/MrVI_tutorial
```
15 changes: 9 additions & 6 deletions docs/user_guide/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ scvi-tools is composed of models that can perform one or many analysis tasks. In
* - :doc:`/user_guide/models/contrastivevi`
- scVI tasks with contrastive analysis
- :cite:p:`Weinberger23`
* - :doc:`/user_guide/models/mrvi`
- Characterization of sample-level heterogeneity
- :cite:p:`Boyeau24`
```

Expand Down Expand Up @@ -141,11 +144,11 @@ scvi-tools is composed of models that can perform one or many analysis tasks. In

## Background

- {doc}`/user_guide/background/variational_inference`
- {doc}`/user_guide/background/differential_expression`
- {doc}`/user_guide/background/counterfactual_prediction`
- {doc}`/user_guide/background/transfer_learning`
- {doc}`/user_guide/background/codebase_overview`
- {doc}`/user_guide/background/variational_inference`
- {doc}`/user_guide/background/differential_expression`
- {doc}`/user_guide/background/counterfactual_prediction`
- {doc}`/user_guide/background/transfer_learning`
- {doc}`/user_guide/background/codebase_overview`

## Glossary

Expand All @@ -166,7 +169,7 @@ A module is the lower-level object that defines a generative model and inference
either inherit {class}`~scvi.module.base.BaseModuleClass` or {class}`~scvi.module.base.PyroBaseModuleClass`.
Consequently, a module can either be implemented with PyTorch alone, or Pyro. In the PyTorch only case, the
generative process and inference scheme are implemented respectively in the `generative` and `inference` methods,
while the `loss` method computes the loss, e.g, ELBO in the case of variational inference.
while the `loss` method computes the loss, e.g, ELBO in the case of variational inference.
:::
::::

Expand Down
31 changes: 31 additions & 0 deletions docs/user_guide/models/figures/mrvi_graphical_model.svg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
148 changes: 148 additions & 0 deletions docs/user_guide/models/mrvi.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# MrVI

**MrVI** [^ref1] (Multi-resolution Variational Inference; Python class
{class}`~scvi.external.MRVI`) is a deep generative model designed for the analysis of large-scale
single-cell transcriptomics data with multi-sample, multi-batch experimental designs.

MrVI conducts both **exploratory analyses** (locally dividing samples into groups based on molecular properties)
and **comparative analyses** (comparing pre-defined groups of samples in terms of differential expression and differential abundance) at single-cell resolution.
It can capture nonlinear and cell-type specific variation of sample-level covariates on gene expression.

```{topic} Tutorials:
- {doc}`/tutorials/notebooks/scrna/MrVI_tutorial`
```

## Preliminaries

MrVI takes as input a scRNA-seq gene expression matrix $X$ with $N$ cells and $G$ genes.
Additionally, it requires specification, for each cell $n$:
- a sample-level target covariate $s_n$, that typically corresponds to the sample ID,
which defines which sample entities will be compared in exploratory and comparative analyses.
- nuisance covariates $b_n$ (e.g. sequencing run, processing day).

Optionally, MrVI can also take as input
- Cell-type labels for guided integration across samples, via a mixture of Gaussians prior where each mixture component serves as the mode of a cell type.
- Additional sample-level covariates of interest $c_s$ for each sample $s$ (e.g.
disease status, age, treatment) for comparative analysis.

## Generative process

MrVI posits a two-level hierarchical model (Figure 1):

1. A cell-level latent variable $u_n$ capturing cell state in a batch-corrected manner:
$u_n \sim \mathrm{MixtureOfGaussians}(\mu_1, ..., \mu_K, \Sigma_1, ..., \Sigma_K, \pi_1, ..., \pi_K)$
2. A cell-level latent variable $z_n$ capturing both cell state and effects of sample $s_n$:
$z_n | u_n \sim \mathcal{N}(u_n, I_L)$
3. The normalized gene expression levels $h_n$ are generated from $z_n$ as:
$h_n = \mathrm{softmax}(A_{zh} \times [z_n + g_\theta(z_n, b_n)] + \gamma_{zh})$
4. Finally the gene expression counts are generated as:
$x_{ng} | h_{ng} \sim \mathrm{NegativeBinomial}(l_n h_{ng}, r_{ng})$

Here $l_n$ is the library size of cell $n$, $r_{ng}$ is the gene-specific inverse dispersion,
$A_{zh}$ is a linear matrix of dimension $G \times L$, $\gamma_{zh}$ is a bias vector of dimension
$G$, and $\theta$ are neural network parameters.
$u_n$ captures broad cell states invariant to sample and batch,
while $z_n$ augments $u_n$ with sample-specific effects while correcting for nuisance covariate effects.
Gene expression is obtained from $z_n$ using multi-head attention mechanisms to
flexibly model batch and sample effects.

:::{figure} figures/mrvi_graphical_model.svg
:align: center
:alt: MrVI graphical model
:class: img-fluid

MrVI graphical model. Shaded nodes represent observed data, unshaded nodes represent latent variables.
:::

The latent variables, along with their description are summarized in the following table:

```{eval-rst}
.. list-table::
:widths: 20 90 15
:header-rows: 1
* - Latent variable
- Description
- Code variable (if different)
* - :math:`u_n \in \mathbb{R}^L`
- "sample-unaware" representation of a cell, invariant to both sample and nuisance covariates.
- ``u``
* - :math:`z_n \in \mathbb{R}^L`
- "sample-aware" representation of a cell, invariant to nuisance covariates.
- ``z``
* - :math:`h_n \in \mathbb{R}^G`
- Cell-specific normalized gene expression.
- ``h``
* - :math:`l_n \in \mathbb{R}^+`
- Cell size factor.
- ``library``
* - :math:`r_{ng} \in \mathbb{R}^+`
- Gene and cell-specific inverse dispersion.
- ``px_r``
* - :math:`\mu_1, ..., \mu_K \in \mathbb{R}^L`
- Mixture of Gaussians means for prior on $u_n$.
- ``u_prior_means``
* - :math:`\Sigma_1, ..., \Sigma_K \in \mathbb{R}^{L \times L}`
- Mixture of Gaussians covariance matrices for prior on $u_n$.
- ``u_prior_scales``
* - :math:`\pi_1, ..., \pi_K \in \mathbb{R}^+`
- Mixture of Gaussians weights for prior on $u_n$.
- ``u_prior_logits``
```

## Inference

MrVI uses variational inference to approximate the posterior of $u_n$ and $z_n$. The variational
distributions are:

$q_{\phi}(u_n | x_n) := \mathcal{N}(\mu_{\phi}(x_n), \sigma^2_{\phi}(x_n)I)$

$z_n := u_n + f_{\phi}(u_n, s_n)$

Here $\mu_{\phi}, \sigma^2_{\phi}$ are encoder neural networks and $f_{\phi}$ is a deterministic
mapping based on multi-head attention between $u_n$ and a learned embedding for sample $s_n$.

## Tasks

### Exploratory analysis

MrVI enables unsupervised local sample stratification via the construction of cell-specific
sample-sample distance matrices, for every cell $n$:

1. For each cell state $u_n$, compute counterfactual cell states $z^{(s)}_n$ for all possible samples $s$.
2. Compute cell-specific sample-sample distance matrices $D^{(n)}$ based on the Euclidean distance between all pairs of $z^{(s)}_n$.
3. Cluster cells based on their $D^{(n)}$ to find cell populations with distinct sample stratifications.
4. Average $D^{(n)}$ within each cell cluster and hierarchically cluster samples
This automatically reveals distinct sample stratifications that are specific to particular cell
subsets.

### Comparative analysis
MrVI also enables supervised comparative analysis to detect cell-type specific DE and DA between sample groups.

#### Differential expression
At a high level, the DE procedure regresses, within each cell $n$, counterfactual cell states $z^{(s)}_n$ on sample-level covariates $c_s$ of interest for analysis as
$z^{(s)}_n = \beta_n c_s + \beta_0 + \epsilon_n$.
For instance, if $c_s$ is a binary covariate, then $\beta_n$ will capture the shift (in $z$-space) induced by samples for which $c_s = 1$ compared to samples for which $c_s = 0$.
This procedure, repeated for all cells, allows several downstream analyses.
First, comparing the norm of $\beta_n$ (using $\chi^2$ statistics) across cells can identify cell-states that vary the most for a given covariate, or conversely, identify sample covariates that strongly associate with specific cell states.
Second, by decoding the linear approximation of $z^{(s)}_n$ for different covariate vectors that we would like to compare, we can compute associated log fold-changes to identify DE genes at the cell level.

#### Differential abundance
To compare two sets of samples, MrVI computes the log-ratio between the aggregated posteriors of the two groups, $A_1 \subset [[1, S]]$ and $A_2 \subset [[1, S]]$, where $S$ is the total number of samples.
In particular, the aggregated posterior for any sample $s$ is defined as
$q_s := \frac{1}{|s|} \sum_{n, s_n=s} q^{u}_{n}$,
where $q_n$ is the posterior of cell $n$ in $u$-space.
This aggregated posterior $q_s$ characterizes the distribution of all cells in sample $s$.
To characterize the distribution of cells in a group of samples $A$, we can consider the mixture of aggregated posteriors $q_s$ for all $s \in A$, corresponding to
$q_A := \frac{1}{|A|} \sum_{s \in A} q_s$.
In particular, cell states $u$ that are abundant in a sample group $A$ will have a high probability mass in $q_A$, while rare states will have low probability mass.
More generally, we can consider the log-ratio of aggregated posteriors between two groups of samples $A_1$ and $A_2$ as a measure of differential abundance:
$r = \log \frac{q_{A_1}}{q_{A_2}}$.
We can evaluate these log-ratios for all cell states $u$ to identify DA cell-state regions.
In particular, large positive (resp. negative) values of $r$ indicate that cell states are more abundant in $A_1$ (resp. $A_2$).

[^ref1]:
Pierre Boyeau, Justin Hong, Adam Gayoso, Martin Kim, Jose L McFaline-Figueroa, Michael Jordan, Elham Azizi, Can Ergen, Nir Yosef (2024),
_Deep generative modeling of sample-level heterogeneity in single-cell genomics_,
[bioRxiv](https://doi.org/10.1101/2022.10.04.510898).
1 change: 1 addition & 0 deletions src/scvi/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
class _REGISTRY_KEYS_NT(NamedTuple):
X_KEY: str = "X"
BATCH_KEY: str = "batch"
SAMPLE_KEY: str = "sample"
LABELS_KEY: str = "labels"
PROTEIN_EXP_KEY: str = "proteins"
CAT_COVS_KEY: str = "extra_categorical_covs"
Expand Down
2 changes: 2 additions & 0 deletions src/scvi/external/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .cellassign import CellAssign
from .contrastivevi import ContrastiveVI
from .gimvi import GIMVI
from .mrvi import MRVI
from .poissonvi import POISSONVI
from .scar import SCAR
from .scbasset import SCBASSET
Expand All @@ -21,4 +22,5 @@
"POISSONVI",
"ContrastiveVI",
"VELOVI",
"MRVI",
]
4 changes: 4 additions & 0 deletions src/scvi/external/mrvi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from ._model import MRVI
from ._module import MRVAE

__all__ = ["MRVI", "MRVAE"]

0 comments on commit 2182546

Please sign in to comment.