DestVI changed l1 loss. Expected proportions #1591

Draft
wants to merge 17 commits into
base: main
Changes from 16 commits
20 changes: 17 additions & 3 deletions docs/user_guide/models/destvi.md
@@ -121,8 +121,10 @@ as a dummy cell type to represent gene specific noise. The dummy cell type's exp
as $\epsilon_g := \mathrm{Softplus}(\eta_g)$ where $\eta_g \sim \mathrm{Normal}(0, 1)$.
Like the other cell types, there is an associated cell type abundance parameter $\beta_{sc}$ associated with $\eta$.
We suspect each spot to only contain a fraction of the different cell types. To increase sparsity of the cell type
proportions, the stLVM supports L1 regularization on the cell types proportions $\beta_{sc}$. By default this loss is
not used.
proportions, the stLVM supports L1 and entropy regularization on the cell type proportions $\beta_{sc}$. By default, neither
regularizer is used. For amortized cell type proportions, we found entropy regularization to work better than L1 regularization. In some
experiments, certain cell types were predicted to be absent from every spot. In these cases, it helps to set the entropy regularization
weight to a small negative value, which increases the diversity of predicted cell types.
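
To make the two regularizers concrete, here is a minimal pure-Python sketch. It assumes the L1 penalty is the weighted sum of non-negative abundances and the entropy penalty is the weighted Shannon entropy of the normalized proportions. The function names, the `eps` constant and the sign convention are illustrative assumptions and are not taken from the module's implementation.

```python
import math


def l1_penalty(abundances, weight):
    """Weighted L1 norm of per-spot abundances (illustrative helper)."""
    return weight * sum(abs(b) for b in abundances)


def entropy_penalty(abundances, weight, eps=1e-12):
    """Weighted Shannon entropy of the normalized proportions (illustrative helper).

    With a positive weight, spread-out proportions cost more, which favors sparsity.
    With a negative weight, spread-out proportions cost less, which favors diversity.
    """
    total = sum(abundances)
    proportions = [b / total for b in abundances]
    entropy = -sum(p * math.log(p + eps) for p in proportions)
    return weight * entropy


sparse_spot = [0.98, 0.01, 0.01]    # dominated by one cell type
mixed_spot = [1 / 3, 1 / 3, 1 / 3]  # evenly mixed

sparse_cost = entropy_penalty(sparse_spot, weight=1.0)
mixed_cost = entropy_penalty(mixed_spot, weight=1.0)
```

Under these assumptions, a positive weight makes the evenly mixed spot more expensive than the sparse one, and flipping the sign of the weight reverses that ordering, which matches the advice about small negative values.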

This generative process is also summarized in the following graphical model:

@@ -190,7 +192,8 @@ The loss is defined as:

where $\mathrm{Var}(\alpha)$ refers to the empirical variance of the parameters alpha across all genes. We used this as a practical form of regularization (a similar regularizer is used in the ZINB-WaVE model [^ref3]).

$\lambda_{\beta}$ (`l1_reg` in code), $\lambda_{\eta}$ (`eta_reg` in code) and $\lambda_{\alpha}$ (`beta_reg` in code) are hyperparameters used to scale the loss term. Increasing $\lambda_{\beta}$ leads to increased sparsity of cell type proportions. Increasing $\lambda_{\alpha}$ leads to less model flexibility for technical variation between single cell and spatial sequencing dataset. Increasing $\lambda_{\eta}$ leads to more genes being explained by the dummy cell type (we recommend to not change the default value).
$\lambda_{\beta}$ (`celltype_reg` in code, with key `l1`; entropy regularization is supported as an alternative via key `entropy`), $\lambda_{\eta}$ (`eta_reg` in code) and $\lambda_{\alpha}$ (`beta_reg` in code) are hyperparameters used to scale the loss terms. Increasing $\lambda_{\beta}$ leads to increased sparsity of cell type proportions. Increasing $\lambda_{\alpha}$ leads to less model flexibility for technical variation between the single-cell and spatial sequencing datasets. Increasing $\lambda_{\eta}$ leads to more genes being explained by the dummy cell type (we recommend not changing the default value). We also support passing `expected_proportions` to `scvi.model.DestVI.setup_anndata`; this is an `obsm` field filled with the expected cell type proportions in each spot. It can be used to estimate the cell type activation state $\gamma^c$ from the output of another deconvolution algorithm.

To avoid overfitting, DestVI amortizes inference using a neural network to parametrize the latent variables.
Via the `amortization` parameter of {class}`scvi.module.MRDeconv`, the user can specify which of
$\beta$ and $\gamma^c$ will be parametrized by the neural network.
@@ -238,6 +241,17 @@ impute the spatial pattern of the cell-type-specific gene expression with:
>>> imputed_counts = st_model.get_scale_for_ct("Monocyte", indices=indices)[["Cxcl9", "Cxcl10", "Fcgr1"]]
```

### Cell-type-specific gene expression assignment

Assuming the user has identified key gene modules that vary within a cell type of interest, they can
assign the measured expression to the spatial pattern of the cell-type-specific gene expression with:

```
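>>> # Filter spots with low abundance of the cell type of interest.
>>> ct_name = "Monocyte"
>>> indices = np.where(st_adata.obsm["proportions"][ct_name].values > 0.03)[0]
>>> # `get_expression_for_ct` is the method added to scvi/model/_destvi.py in this PR.
>>> expression = st_model.get_expression_for_ct(ct_name, indices=indices)[["Cxcl9", "Cxcl10", "Fcgr1"]]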
```

### Comparative analysis between samples

To perform differential expression across samples, one can apply a frequentist test by taking samples
10 changes: 8 additions & 2 deletions scvi/data/fields/_arraylike_field.py
@@ -78,8 +78,13 @@ def __init__(
colnames_uns_key: Optional[str] = None,
is_count_data: bool = False,
correct_data_format: bool = True,
required: bool = True,
) -> None:
super().__init__(registry_key)
if required and attr_key is None:
raise ValueError(
"`attr_key` cannot be `None` if `required=True`. Please provide an `attr_key`."
)
if field_type == "obsm":
self._attr_name = _constants._ADATA_ATTRS.OBSM
elif field_type == "varm":
@@ -88,6 +93,7 @@ def __init__(
raise ValueError("`field_type` must be either 'obsm' or 'varm'.")

self._attr_key = attr_key
self._is_empty = attr_key is None
self.colnames_uns_key = colnames_uns_key
self.is_count_data = is_count_data
self.correct_data_format = correct_data_format
@@ -99,7 +105,7 @@ def attr_key(self) -> str:

@property
def is_empty(self) -> bool:
return False
return self._is_empty

def validate_field(self, adata: AnnData) -> None:
"""Validate the field."""
@@ -175,7 +181,7 @@ def transfer_field(

def get_summary_stats(self, state_registry: dict) -> dict:
"""Get summary stats."""
n_array_cols = len(state_registry[self.COLUMN_NAMES_KEY])
n_array_cols = len(state_registry.get(self.COLUMN_NAMES_KEY, []))
return {self.count_stat_key: n_array_cols}
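
Taken together, the `required` flag, the `is_empty` property and the `.get(..., [])` fallback in this file make an array-like field optional. A minimal standalone sketch of that pattern follows. `OptionalObsmField` and its summary-stat key are hypothetical stand-ins for illustration, not the real `ObsmField`.

```python
from typing import Optional


class OptionalObsmField:
    """Toy stand-in for an obsm-backed field whose key may be omitted."""

    COLUMN_NAMES_KEY = "column_names"

    def __init__(self, registry_key: str, attr_key: Optional[str], required: bool = True):
        # A required field must name the obsm key it reads from.
        if required and attr_key is None:
            raise ValueError("`attr_key` cannot be `None` if `required=True`.")
        self.registry_key = registry_key
        self.attr_key = attr_key
        # An optional field without a key is registered but carries no data.
        self.is_empty = attr_key is None

    def get_summary_stats(self, state_registry: dict) -> dict:
        # An empty field registers no column names, so fall back to an empty list.
        n_cols = len(state_registry.get(self.COLUMN_NAMES_KEY, []))
        return {f"n_{self.registry_key}": n_cols}


field = OptionalObsmField("expected_proportions", None, required=False)
stats = field.get_summary_stats({})
```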

def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]:
76 changes: 66 additions & 10 deletions scvi/model/_destvi.py
@@ -11,7 +11,7 @@

from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import LayerField, NumericalObsField
from scvi.data.fields import LayerField, NumericalObsField, ObsmField
from scvi.model import CondSCVI
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin
from scvi.module import MRDeconv
@@ -78,7 +78,7 @@ def __init__(
n_latent: int,
n_layers: int,
dropout_decoder: float,
l1_reg: float,
celltype_reg: dict,
Contributor

I feel like this argument would be clearer if we separate it out into two arguments: `celltype_regularization: Literal["l1", "entropy"] | None = None` and `celltype_regularization_weight: float | None = None`
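
For illustration, the two-argument form suggested in this comment could be mapped back onto the dict form used in the PR roughly as follows. The helper name `resolve_celltype_regularization` and its validation rules are hypothetical; only the two argument names and the `{"l1": 0}` default come from the PR and the comment.

```python
from typing import Literal, Optional


def resolve_celltype_regularization(
    celltype_regularization: Optional[Literal["l1", "entropy"]] = None,
    celltype_regularization_weight: Optional[float] = None,
) -> dict:
    """Map the two proposed arguments onto the `celltype_reg` dict (hypothetical helper)."""
    if celltype_regularization is None:
        if celltype_regularization_weight is not None:
            raise ValueError("A weight was given without a regularization type.")
        # Mirrors the default that `from_rna_model` falls back to in this PR.
        return {"l1": 0}
    if celltype_regularization not in ("l1", "entropy"):
        raise ValueError("`celltype_regularization` must be 'l1', 'entropy' or None.")
    if celltype_regularization_weight is None:
        raise ValueError("A regularization type was given without a weight.")
    return {celltype_regularization: celltype_regularization_weight}


default_reg = resolve_celltype_regularization()
entropy_reg = resolve_celltype_regularization("entropy", 200.0)
```

Separating the type from the weight lets a type checker reject unsupported regularizer names, which a free-form dict cannot do.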

**module_kwargs,
):
super().__init__(st_adata)
@@ -93,7 +93,7 @@ def __init__(
n_layers=n_layers,
n_hidden=n_hidden,
dropout_decoder=dropout_decoder,
l1_reg=l1_reg,
celltype_reg=celltype_reg,
**module_kwargs,
)
self.cell_type_mapping = cell_type_mapping
@@ -106,7 +106,7 @@ def from_rna_model(
st_adata: AnnData,
sc_model: CondSCVI,
vamp_prior_p: int = 15,
l1_reg: float = 0.0,
celltype_reg: dict | None = None,
**module_kwargs,
):
"""Alternate constructor for exploiting a pre-trained model on a RNA-seq dataset.
@@ -119,9 +119,10 @@ def from_rna_model(
trained CondSCVI model
vamp_prior_p
number of mixture parameter for VampPrior calculations
l1_reg
Scalar parameter indicating the strength of L1 regularization on cell type proportions.
A value of 50 leads to sparser results.
celltype_reg
Dictionary indicating the type ("l1" and "entropy" are supported) and strength of regularization on cell type proportions.
A value of 200 for the entropy loss leads to sparser results. If some cell types are predicted to be absent, setting
"entropy" to a negative value increases the chance of detecting all cell types.
**model_kwargs
Keyword args for :class:`~scvi.model.DestVI`
"""
@@ -139,6 +140,8 @@ def from_rna_model(
mean_vprior, var_vprior, mp_vprior = sc_model.get_vamp_prior(
sc_model.adata, p=vamp_prior_p
)
if celltype_reg is None:
celltype_reg = {"l1": 0}

return cls(
st_adata,
@@ -153,7 +156,7 @@ def from_rna_model(
var_vprior=var_vprior,
mp_vprior=mp_vprior,
dropout_decoder=dropout_decoder,
l1_reg=l1_reg,
celltype_reg=celltype_reg,
**module_kwargs,
)

@@ -162,6 +165,7 @@ def get_proportions(
keep_noise: bool = False,
indices: Sequence[int] | None = None,
batch_size: int | None = None,
normalize: bool = True,
) -> pd.DataFrame:
"""Returns the estimated cell type proportion for the spatial data.

@@ -171,6 +175,8 @@ def get_proportions(
----------
keep_noise
whether to account for the noise term as a standalone cell type in the proportion estimate.
normalize
Normalize outputs of proportions to have sum 1.
indices
Indices of cells in adata to use. Only used if amortization. If `None`, all cells are used.
batch_size
@@ -191,7 +197,7 @@ def get_proportions(
for tensors in stdl:
generative_inputs = self.module._get_generative_input(tensors, None)
prop_local = self.module.get_proportions(
x=generative_inputs["x"], keep_noise=keep_noise
x=generative_inputs["x"], keep_noise=keep_noise, normalize=normalize
)
prop_ += [prop_local.cpu()]
data = torch.cat(prop_).numpy()
@@ -202,7 +208,9 @@ def get_proportions(
logger.info(
"No amortization for proportions, ignoring indices and returning results for the full data"
)
data = self.module.get_proportions(keep_noise=keep_noise)
data = self.module.get_proportions(
keep_noise=keep_noise, normalize=normalize
)

return pd.DataFrame(
data=data,
@@ -289,6 +297,52 @@ def get_scale_for_ct(
raise ValueError("Unknown cell type")
y = np.where(label == self.cell_type_mapping)[0][0]

stdl = self._make_data_loader(self.adata, indices=indices, batch_size=batch_size)
scale = []
for tensors in stdl:
generative_inputs = self.module._get_generative_input(tensors, None)
x, ind_x = (
generative_inputs["x"],
generative_inputs["ind_x"],
)
px_scale = self.module.get_ct_specific_scale(x, ind_x, y)
scale += [px_scale.cpu()]

data = torch.cat(scale).numpy()
column_names = self.adata.var.index
index_names = self.adata.obs.index
if indices is not None:
index_names = index_names[indices]
return pd.DataFrame(data=data, columns=column_names, index=index_names)

def get_expression_for_ct(
self,
label: str,
indices: Sequence[int] | None = None,
batch_size: int | None = None,
) -> pd.DataFrame:
r"""
Return the per cell-type expression based on likelihood for every spot in queried cell types.

Parameters
----------
label
cell type of interest
indices
Indices of cells in self.adata to use. If `None`, all cells are used.
batch_size
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

Returns
-------
Pandas dataframe of gene_expression
"""
self._check_if_trained()

if label not in self.cell_type_mapping:
raise ValueError("Unknown cell type")
y = np.where(label == self.cell_type_mapping)[0][0]

stdl = self._make_data_loader(self.adata, indices=indices, batch_size=batch_size)
scale = []
for tensors in stdl:
@@ -380,6 +434,7 @@ def setup_anndata(
cls,
adata: AnnData,
layer: str | None = None,
expected_proportions: str | None = None,
**kwargs,
):
"""%(summary)s.
@@ -395,6 +450,7 @@ def setup_anndata(
anndata_fields = [
LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
ObsmField("expected_proportions", expected_proportions, required=False),
]
adata_manager = AnnDataManager(
fields=anndata_fields, setup_method_args=setup_method_args