Commit

Updated destVI for Curio.
canergen committed Dec 8, 2023
1 parent b066b95 commit 34e03d6
Showing 7 changed files with 44 additions and 34 deletions.
2 changes: 1 addition & 1 deletion docs/user_guide/models/destvi.md
@@ -192,7 +192,7 @@ The loss is defined as:

where $\mathrm{Var}(\alpha)$ refers to the empirical variance of the parameters alpha across all genes. We used this as a practical form of regularization (a similar regularizer is used in the ZINB-WaVE model [^ref3]).

$\lambda_{\beta}$ (`celltype_reg` in code, with key `l1`; alternatively, we support entropy regularization), $\lambda_{\eta}$ (`eta_reg` in code), and $\lambda_{\alpha}$ (`beta_reg` in code) are hyperparameters used to scale the loss terms. Increasing $\lambda_{\beta}$ leads to increased sparsity of cell-type proportions. Increasing $\lambda_{\alpha}$ leads to less model flexibility for technical variation between the single-cell and spatial sequencing datasets. Increasing $\lambda_{\eta}$ leads to more genes being explained by the dummy cell type (we recommend not changing the default value). We support defining `expected_proportions` as input to `scvi.model.DestVI.setup_anndata`; this is an obsm field filled with the expected cell-type proportions in each spot. It can be used to estimate the cell-type activation state $\gamma^c$ from the output of another deconvolution algorithm. As the definition of the library size differs for other deconvolution algorithms, we set it as learnable in this use case, without a prior.
$\lambda_{\beta}$ (`celltype_reg` in code, with key `l1`; alternatively, we support entropy regularization), $\lambda_{\eta}$ (`eta_reg` in code), and $\lambda_{\alpha}$ (`beta_reg` in code) are hyperparameters used to scale the loss terms. Increasing $\lambda_{\beta}$ leads to increased sparsity of cell-type proportions. Increasing $\lambda_{\alpha}$ leads to less model flexibility for technical variation between the single-cell and spatial sequencing datasets. Increasing $\lambda_{\eta}$ leads to more genes being explained by the dummy cell type (we recommend not changing the default value). We support defining `expected_proportions` as input to `scvi.model.DestVI.setup_anndata`; this is an obsm field filled with the expected cell-type proportions in each spot. It can be used to estimate the cell-type activation state $\gamma^c$ from the output of another deconvolution algorithm.

To avoid overfitting, DestVI amortizes inference using a neural network to parametrize the latent variables.
Via the `amortization` parameter of {class}`scvi.module.MRDeconv`, the user can specify which of
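To ground the hyperparameters described above, here is a minimal sketch of how the `celltype_reg` dictionary introduced in this commit could be passed to `DestVI.from_rna_model`, mirroring the pattern in `tests/model/test_destvi.py` further down this diff. The toy data helper `scvi.data.synthetic_iid`, the `CondSCVI` reference model, and the `labels_key` argument are assumptions drawn from the wider scvi-tools API, not from this commit.

```python
import scvi
from scvi.model import CondSCVI, DestVI

# Toy single-cell reference (synthetic_iid / CondSCVI usage is assumed, not shown in this diff).
sc_adata = scvi.data.synthetic_iid()
CondSCVI.setup_anndata(sc_adata, labels_key="labels")
sc_model = CondSCVI(sc_adata)
sc_model.train(max_epochs=1)

# Stand-in spatial dataset; setup_anndata / from_rna_model match the tests changed below.
st_adata = scvi.data.synthetic_iid()
DestVI.setup_anndata(st_adata, layer=None)
st_model = DestVI.from_rna_model(
    st_adata,
    sc_model,
    amortization="latent",
    celltype_reg={"l1": 50.0},  # or {"entropy": 50.0} for entropy regularization
)
st_model.train(max_epochs=1)
```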
10 changes: 8 additions & 2 deletions scvi/data/fields/_arraylike_field.py
@@ -78,8 +78,13 @@ def __init__(
colnames_uns_key: Optional[str] = None,
is_count_data: bool = False,
correct_data_format: bool = True,
required: bool = True,
) -> None:
super().__init__(registry_key)
if required and attr_key is None:
raise ValueError(
"`attr_key` cannot be `None` if `required=True`. Please provide an `attr_key`."
)
if field_type == "obsm":
self._attr_name = _constants._ADATA_ATTRS.OBSM
elif field_type == "varm":
@@ -88,6 +93,7 @@ def __init__(
raise ValueError("`field_type` must be either 'obsm' or 'varm'.")

self._attr_key = attr_key
self._is_empty = attr_key is None
self.colnames_uns_key = colnames_uns_key
self.is_count_data = is_count_data
self.correct_data_format = correct_data_format
@@ -99,7 +105,7 @@ def attr_key(self) -> str:

@property
def is_empty(self) -> bool:
return False
return self._is_empty

def validate_field(self, adata: AnnData) -> None:
"""Validate the field."""
@@ -175,7 +181,7 @@ def transfer_field(

def get_summary_stats(self, state_registry: dict) -> dict:
"""Get summary stats."""
n_array_cols = len(state_registry[self.COLUMN_NAMES_KEY])
n_array_cols = len(state_registry.get(self.COLUMN_NAMES_KEY, []))
return {self.count_stat_key: n_array_cols}

def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]:
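The change above makes an array-like field optional: with `required=False` and `attr_key=None`, the field now registers as empty instead of raising. A self-contained sketch of that pattern, using an illustrative class name rather than the actual scvi-tools one:

```python
from typing import Optional


class OptionalArrayField:
    """Illustrative stand-in for the required/is_empty pattern added above."""

    def __init__(
        self, registry_key: str, attr_key: Optional[str] = None, required: bool = True
    ) -> None:
        if required and attr_key is None:
            raise ValueError("`attr_key` cannot be `None` if `required=True`.")
        self.registry_key = registry_key
        self._attr_key = attr_key
        self._is_empty = attr_key is None  # a field without a key registers as empty

    @property
    def is_empty(self) -> bool:
        return self._is_empty


# An optional field without an attr_key is simply skipped downstream.
field = OptionalArrayField("expected_proportions", attr_key=None, required=False)
assert field.is_empty
```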
16 changes: 8 additions & 8 deletions scvi/data/fields/_obsm_field.py
@@ -1,6 +1,6 @@
import logging
import warnings
from typing import Dict, List, Optional, Union
from typing import Optional, Union

import numpy as np
import pandas as pd
@@ -193,7 +193,7 @@ class JointObsField(BaseObsmField):
Sequence of keys to combine to form the obsm field.
"""

def __init__(self, registry_key: str, obs_keys: Optional[List[str]]) -> None:
def __init__(self, registry_key: str, obs_keys: Optional[list[str]]) -> None:
super().__init__(registry_key)
self._attr_key = f"_scvi_{registry_key}"
self._obs_keys = obs_keys if obs_keys is not None else []
@@ -209,7 +209,7 @@ def _combine_obs_fields(self, adata: AnnData) -> None:
adata.obsm[self.attr_key] = adata.obs[self.obs_keys].copy()

@property
def obs_keys(self) -> List[str]:
def obs_keys(self) -> list[str]:
"""List of .obs keys that make up this joint field."""
return self._obs_keys

@@ -238,7 +238,7 @@ class NumericalJointObsField(JointObsField):

COLUMNS_KEY = "columns"

def __init__(self, registry_key: str, obs_keys: Optional[List[str]]) -> None:
def __init__(self, registry_key: str, obs_keys: Optional[list[str]]) -> None:
super().__init__(registry_key, obs_keys)

self.count_stat_key = f"n_{self.registry_key}"
@@ -274,7 +274,7 @@ def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table
overflow="fold",
)
for key in state_registry[self.COLUMNS_KEY]:
t.add_row("adata.obs['{}']".format(key))
t.add_row(f"adata.obs['{key}']")
return t


@@ -300,7 +300,7 @@ class CategoricalJointObsField(JointObsField):
FIELD_KEYS_KEY = "field_keys"
N_CATS_PER_KEY = "n_cats_per_key"

def __init__(self, registry_key: str, obs_keys: Optional[List[str]]) -> None:
def __init__(self, registry_key: str, obs_keys: Optional[list[str]]) -> None:
super().__init__(registry_key, obs_keys)
self.count_stat_key = f"n_{self.registry_key}"

@@ -312,7 +312,7 @@ def _default_mappings_dict(self) -> dict:
}

def _make_obsm_categorical(
self, adata: AnnData, category_dict: Optional[Dict[str, List[str]]] = None
self, adata: AnnData, category_dict: Optional[dict[str, list[str]]] = None
) -> dict:
if self.obs_keys != adata.obsm[self.attr_key].columns.tolist():
raise ValueError(
@@ -402,7 +402,7 @@ def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table
for key, mappings in state_registry[self.MAPPINGS_KEY].items():
for i, mapping in enumerate(mappings):
if i == 0:
t.add_row("adata.obs['{}']".format(key), str(mapping), str(i))
t.add_row(f"adata.obs['{key}']", str(mapping), str(i))
else:
t.add_row("", str(mapping), str(i))
t.add_row("", "")
8 changes: 3 additions & 5 deletions scvi/model/_destvi.py
@@ -316,8 +316,8 @@ def get_scale_for_ct(
def get_expression_for_ct(
self,
label: str,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = None,
indices: Sequence[int] | None = None,
batch_size: int | None = None,
) -> pd.DataFrame:
r"""
Return the per cell-type expression based on likelihood for every spot in queried cell types.
@@ -341,9 +341,7 @@ def get_expression_for_ct(
raise ValueError("Unknown cell type")
y = np.where(label == self.cell_type_mapping)[0][0]

stdl = self._make_data_loader(
self.adata, indices=indices, batch_size=batch_size
)
stdl = self._make_data_loader(self.adata, indices=indices, batch_size=batch_size)
scale = []
for tensors in stdl:
generative_inputs = self.module._get_generative_input(tensors, None)
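For reference, a hedged usage sketch of the method whose signature is modernized above; `st_model` (a trained DestVI model) and the cell-type label are placeholders, not values from this commit:

```python
# Per-spot expression attributed to one cell type, returned as a pandas DataFrame.
expression_df = st_model.get_expression_for_ct(
    "celltype_0",      # must be a label present in the single-cell reference
    indices=None,      # all spots
    batch_size=128,
)
print(expression_df.head())
```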
25 changes: 11 additions & 14 deletions scvi/module/_mrdeconv.py
@@ -291,21 +291,18 @@ def loss(
glo_neg_log_likelihood_prior += self.beta_reg * torch.var(self.beta)

if expected_proportion is not None:
v_sparsity_loss = self.celltype_reg.values[0] * torch.sum(
torch.abs(v[:, :-1] - expected_proportion), axis=1
)
elif "l1" in self.celltype_reg.keys():
v_sparsity_loss = self.celltype_reg["l1"] * torch.sum(v, axis=1)
elif "entropy" in self.celltype_reg.keys():
v_sparsity_loss = (
self.celltype_reg["entropy"]
* torch.distributions.Categorical(probs=v).entropy().mean()
v_sparsity_loss = self.celltype_reg.values[0] * torch.sqrt(
torch.sum(torch.square(v[:, :-1] - expected_proportion), axis=1)
)
else:
raise ValueError(
"celltype_reg must be one of ['l1', 'entropy'], but input was "
"{}.format(self.celltype_reg.keys[0])"
)
v_sparsity_loss = 0
if "l1" in self.celltype_reg.keys():
v_sparsity_loss += self.celltype_reg["l1"] * torch.sum(v, axis=1)
if "entropy" in self.celltype_reg.keys():
v_sparsity_loss += (
self.celltype_reg["entropy"]
* torch.distributions.Categorical(probs=v).entropy().mean()
)

# gamma prior likelihood
if self.mean_vprior is None:
@@ -480,7 +477,7 @@ def get_ct_specific_expression(
(-1, self.n_latent)
) # minibatch_size * n_labels, n_latent
enum_label = (
torch.arange(0, self.n_labels).repeat((x.shape[0])).view((-1, 1))
torch.arange(0, self.n_labels).repeat(x.shape[0]).view((-1, 1))
) # minibatch_size * n_labels, 1
h = self.decoder(gamma_reshape, enum_label)
px_scale = self.px_decoder(h) # (minibatch, n_genes)
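The rewritten block above replaces the mutually exclusive `elif`/`raise` branches with additive `l1` and `entropy` penalties, and changes the `expected_proportion` penalty from a summed absolute difference to a per-spot Euclidean distance. A standalone numerical sketch of those terms (shapes, seed, and the 50.0 weights are illustrative only):

```python
import torch

torch.manual_seed(0)
n_spots, n_celltypes = 4, 3
# Cell-type proportions per spot, including a trailing dummy cell-type column.
v = torch.softmax(torch.randn(n_spots, n_celltypes + 1), dim=-1)
celltype_reg = {"l1": 50.0, "entropy": 50.0}

# Additive penalties, as in the new code path.
v_sparsity_loss = torch.zeros(n_spots)
if "l1" in celltype_reg:
    v_sparsity_loss = v_sparsity_loss + celltype_reg["l1"] * torch.sum(v, dim=1)
if "entropy" in celltype_reg:
    v_sparsity_loss = v_sparsity_loss + (
        celltype_reg["entropy"]
        * torch.distributions.Categorical(probs=v).entropy().mean()
    )

# With expected proportions supplied, the penalty becomes a per-spot Euclidean distance.
expected_proportion = torch.full((n_spots, n_celltypes), 1.0 / n_celltypes)
distance_penalty = 50.0 * torch.sqrt(
    torch.sum(torch.square(v[:, :-1] - expected_proportion), dim=1)
)
print(v_sparsity_loss.shape, distance_penalty.shape)  # both torch.Size([4])
```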
6 changes: 5 additions & 1 deletion tests/model/test_destvi.py
@@ -34,8 +34,12 @@ def test_destvi():
for amor_scheme in ["both", "none", "proportion", "latent"]:
DestVI.setup_anndata(dataset, layer=None)
# add l1_regularization to cell type proportions
if amor_scheme == "proportion":
celltype_reg = {"l1": 50}
else:
celltype_reg = {"entropy": 50}
spatial_model = DestVI.from_rna_model(
dataset, sc_model, amortization=amor_scheme, l1_reg=50
dataset, sc_model, amortization=amor_scheme, celltype_reg=celltype_reg
)
spatial_model.view_anndata_setup()
spatial_model.train(max_epochs=1)
11 changes: 8 additions & 3 deletions tests/models/test_models.py
@@ -250,9 +250,7 @@ def test_scvi(save_path):
)
params = model.get_likelihood_parameters()
assert params["mean"].shape == adata.shape
assert (
params["mean"].shape == params["dispersions"].shape == params["dropout"].shape
)
assert params["mean"].shape == params["dispersions"].shape == params["dropout"].shape
params = model.get_likelihood_parameters(adata2, indices=[1, 2, 3])
assert params["mean"].shape == (3, adata.n_vars)
params = model.get_likelihood_parameters(
@@ -1425,6 +1423,13 @@ def test_destvi(save_path):
for amor_scheme in ["both", "none", "proportion", "latent"]:
DestVI.setup_anndata(dataset, layer=None)
# add l1_regularization to cell type proportions
if amor_scheme == "latent":
spatial_model = DestVI.from_rna_model(
dataset,
sc_model,
amortization=amor_scheme,
celltype_reg={"entropy": 50.0},
)
spatial_model = DestVI.from_rna_model(
dataset, sc_model, amortization=amor_scheme, celltype_reg={"l1": 50.0}
)
