Fixes error in get losses functions #2362

Open · wants to merge 28 commits into main from can-expose_mc_samples_de

Changes from 21 commits

Commits (28)
22828b6
HEAD
canergen Jul 3, 2022
821129a
Using expected_proportions in destVI.
canergen Jul 3, 2022
d8b5ce7
Pre-commit.
canergen Jul 4, 2022
5096aa1
Test not working. Needs update.
canergen Jul 4, 2022
f149d3b
COLUMN_NAMES in obsm optinal.
canergen Jul 4, 2022
b82c1b0
Added entropy regularization. Added get_expected_expression
canergen Jul 28, 2022
afa83e6
Celltype_reg is test.
canergen Jul 29, 2022
b0bca37
Precommit.
canergen Jul 29, 2022
0ea7972
Get specific expression changes
canergen Jul 29, 2022
4efe0e2
Fixed gamma in ct_specific_expression
canergen Jul 29, 2022
b066b95
Fixed ct_specific_expression.
canergen Jul 29, 2022
34e03d6
Updated destVI for Curio.
canergen Dec 8, 2023
ddfea2d
Fixed criticism when using subset of adata.
canergen Dec 8, 2023
b484700
Elbo,reconstruction loss per cell. Fixed error in
canergen Dec 8, 2023
e233575
Remove to have clear PR.
canergen Dec 8, 2023
ee1aeb5
Clean PR
canergen Dec 8, 2023
0a7b36e
Fix importance_weighting for scANVI
canergen Dec 8, 2023
24f34e0
Merge branch 'main' into can-expose_mc_samples_de
martinkim0 Dec 8, 2023
0ca691f
Changed shapes of disabled losses. Changelog.
canergen Dec 12, 2023
616c486
Merge branch 'main' into can-expose_mc_samples_de
martinkim0 Dec 13, 2023
7ec29c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2023
763c3c1
Merge branch 'main' into can-expose_mc_samples_de
martinkim0 Jan 10, 2024
8ea61f3
Merge branch 'main' into can-expose_mc_samples_de
martinkim0 Jan 17, 2024
4da32ba
Merge branch 'main' into can-expose_mc_samples_de
martinkim0 Jan 18, 2024
c16b69a
Merge branch 'main' into can-expose_mc_samples_de
martinkim0 Jan 18, 2024
bebea1e
Merge branch 'main' into can-expose_mc_samples_de
martinkim0 Jan 19, 2024
a3feb93
Merge branch 'main' into can-expose_mc_samples_de
martinkim0 Jan 25, 2024
37002f9
Merge branch 'main' into can-expose_mc_samples_de
martinkim0 Jan 29, 2024
6 changes: 6 additions & 0 deletions docs/release_notes/index.md
@@ -49,6 +49,11 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/
by initializing with {class}`scvi.train.SaveCheckpoint` {pr}`2317`.
- {attr}`scvi.settings.dl_num_workers` is now correctly applied as the default
`num_workers` in {class}`scvi.dataloaders.AnnDataLoader` {pr}`2322`.
- Add argument `return_mean` to {meth}`scvi.model.base.VAEMixin.get_reconstruction_error`
and {meth}`scvi.model.base.VAEMixin.get_elbo` to allow computation
without averaging across cells {pr}`2362`.
- Add support for setting `weights="importance"` in
{meth}`scvi.model.SCANVI.differential_expression` {pr}`2362`.

#### Fixed

@@ -57,6 +62,7 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/
- Fix bug in {class}`scvi.module.SCANVAE` where classifier probabilities
were interpreted as logits. This is backwards compatible as loading older
models will use the old code path {pr}`2301`.
- Fix {meth}`scvi.module.VAE.marginal_ll` when `n_mc_samples_per_pass=1` {pr}`2362`.
- Fix bug in {class}`scvi.external.GIMVI` where `batch_size` was not
properly used in inference methods {pr}`2366`.
- Fix error message formatting in {meth}`scvi.data.fields.LayerField.transfer_field` {pr}`2368`.
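As context for the two `{pr}2362` additions above, a minimal usage sketch mirroring the calls exercised in tests/model/test_scanvi.py further down; the synthetic dataset, training length, and index ranges are illustrative, not prescriptive:

```python
import numpy as np
import scvi
from scvi.data import synthetic_iid

# Illustrative setup: a small synthetic dataset, as in the scvi-tools tests.
adata = synthetic_iid()
scvi.model.SCANVI.setup_anndata(
    adata, batch_key="batch", labels_key="labels", unlabeled_category="label_0"
)
model = scvi.model.SCANVI(adata)
model.train(max_epochs=1)

# New: per-cell (unaveraged) ELBO and reconstruction error.
elbo_per_cell = model.get_elbo(return_mean=False)

# New: importance-weighted differential expression for SCANVI.
de = model.differential_expression(
    idx1=np.arange(50),
    idx2=51 + np.arange(50),
    mode="vanilla",
    weights="importance",
    importance_weighting_kwargs={"n_mc_samples": 10, "n_mc_samples_per_pass": 1},
)
```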
2 changes: 1 addition & 1 deletion scvi/external/scar/_module.py
@@ -343,7 +343,7 @@ def loss(
generative_outputs["pl"],
).sum(dim=1)
else:
kl_divergence_l = 0.0
kl_divergence_l = torch.zeros_like(kl_divergence_z)

# need to add the ambient rate and scale to the distribution for the loss
px = generative_outputs["px"]
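The same one-line change recurs in `_autozivae.py`, `_scanvae.py`, `_totalvae.py`, and `_vae.py` below: when the library-size KL term is disabled, the placeholder must be a per-cell tensor rather than a Python float, so every loss component keeps the `(n_cells,)` shape that the new per-observation reductions rely on. A shape-only sketch with made-up sizes:

```python
import torch

n_cells = 4
kl_divergence_z = torch.rand(n_cells)  # per-cell KL term, shape (n_cells,)

# Old: a float broadcasts fine in sums, but it is not a per-cell tensor, so
# code that stacks or concatenates loss components per observation breaks.
kl_divergence_l = 0.0

# New: a tensor matching the shape, dtype, and device of the other KL term.
kl_divergence_l = torch.zeros_like(kl_divergence_z)
assert kl_divergence_l.shape == kl_divergence_z.shape  # (n_cells,)
```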
52 changes: 38 additions & 14 deletions scvi/model/base/_log_likelihood.py
@@ -1,8 +1,9 @@
"""File for computing log likelihood of the data."""
import numpy as np
import torch


def compute_elbo(vae, data_loader, feed_labels=True, **kwargs):
def compute_elbo(vae, data_loader, feed_labels=True, return_mean=True, **kwargs):
"""Computes the ELBO.

The ELBO is the reconstruction error + the KL divergences
@@ -13,22 +14,34 @@ def compute_elbo(vae, data_loader, feed_labels=True, **kwargs):
It still gives good insights on the modeling of the data, and is fast to compute.
"""
# Iterate once over the data and compute the elbo
elbo = 0
if return_mean:
elbo = 0
else:
elbo = np.array([])
for tensors in data_loader:
_, _, scvi_loss = vae(tensors, **kwargs)

recon_loss = scvi_loss.reconstruction_loss_sum
kl_local = scvi_loss.kl_local_sum
elbo += (recon_loss + kl_local).item()
recon_loss = np.sum(
[np.array(i) for i in scvi_loss.reconstruction_loss.values()], axis=0
)
kl_local = np.sum([np.array(i) for i in scvi_loss.kl_local.values()], axis=0)

kl_global = scvi_loss.kl_global_sum
if return_mean:
elbo += (recon_loss + kl_local).sum(0).item()
else:
elbo = np.concatenate((elbo, recon_loss + kl_local), axis=0)

kl_global = np.sum([np.array(i) for i in scvi_loss.kl_global.values()], axis=0)
n_samples = len(data_loader.indices)
elbo += kl_global
return elbo / n_samples
if return_mean:
elbo += kl_global
return elbo / n_samples
else:
return elbo + kl_global / n_samples
Contributor commented:
Probably best to also use torch operations here since compute_reconstruction_error doesn't use numpy. Let me know what you think

@canergen (Contributor, Author) commented on Jan 5, 2024:
Fine with me. Can you push those changes?



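Following up on the exchange above, a hedged sketch of what a torch-only variant of this accumulation could look like; the `scvi_loss` attribute access follows the diff, but this is a suggestion, not the committed code:

```python
import torch

def compute_elbo_torch(vae, data_loader, return_mean=True, **kwargs):
    """Sketch: accumulate per-cell ELBO terms with torch instead of numpy."""
    per_cell = []
    for tensors in data_loader:
        _, _, scvi_loss = vae(tensors, **kwargs)
        recon_loss = sum(scvi_loss.reconstruction_loss.values())  # (batch,)
        kl_local = sum(scvi_loss.kl_local.values())  # (batch,)
        per_cell.append(recon_loss + kl_local)
    elbo = torch.cat(per_cell, dim=0)  # (n_cells,)
    kl_global = sum(scvi_loss.kl_global.values())
    n_samples = len(data_loader.indices)
    if return_mean:
        return (elbo.sum() + kl_global) / n_samples
    # Spread the global KL evenly across cells, as the diff above does.
    return elbo + kl_global / n_samples
```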
# do each one
def compute_reconstruction_error(vae, data_loader, **kwargs):
def compute_reconstruction_error(vae, data_loader, return_mean=True, **kwargs):
"""Computes log p(x/z), which is the reconstruction error.

Differs from the marginal log likelihood, but still gives good
@@ -44,13 +57,24 @@ def compute_reconstruction_error(vae, data_loader, **kwargs):
else:
rec_loss_dict = losses.reconstruction_loss
for key, value in rec_loss_dict.items():
if key in log_lkl:
log_lkl[key] += torch.sum(value).item()
else:
log_lkl[key] = torch.sum(value).item()
            if key in log_lkl:
                if return_mean:
                    log_lkl[key] += torch.sum(value).item()
                else:
                    log_lkl[key].append(value)
            else:
                if return_mean:
                    log_lkl[key] = torch.sum(value).item()
                else:
                    log_lkl[key] = [value]

n_samples = len(data_loader.indices)
for key, _ in log_lkl.items():
log_lkl[key] = log_lkl[key] / n_samples
if return_mean:
log_lkl[key] = log_lkl[key] / n_samples
else:
log_lkl[key] = torch.cat(log_lkl[key], dim=0)

log_lkl[key] = -log_lkl[key]
return log_lkl
18 changes: 14 additions & 4 deletions scvi/model/base/_vaemixin.py
@@ -23,6 +23,7 @@ def get_elbo(
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = None,
return_mean: bool = True,
) -> float:
"""Return the ELBO for the data.

@@ -38,12 +39,15 @@
Indices of cells in adata to use. If `None`, all cells are used.
batch_size
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
return_mean
If False, return the ELBO for each observation.
Otherwise, return the mean ELBO.
"""
adata = self._validate_anndata(adata)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)
elbo = compute_elbo(self.module, scdl)
elbo = compute_elbo(self.module, scdl, return_mean=return_mean)
return -elbo

@torch.inference_mode()
@@ -75,7 +79,7 @@ def get_marginal_ll(
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
return_mean
If False, return the marginal log likelihood for each observation.
Otherwise, return the mmean arginal log likelihood.
Otherwise, return the mean marginal log likelihood.
"""
adata = self._validate_anndata(adata)
if indices is None:
@@ -98,7 +102,7 @@
)
)
if not return_mean:
return torch.cat(log_lkl, 0)
return torch.cat(log_lkl, dim=0)
else:
return np.mean(log_lkl)
else:
@@ -114,6 +118,7 @@ def get_reconstruction_error(
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = None,
return_mean: Optional[bool] = True,
) -> float:
r"""Return the reconstruction error for the data.

@@ -129,12 +134,17 @@
Indices of cells in adata to use. If `None`, all cells are used.
batch_size
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
return_mean
If False, return the reconstruction loss for each observation.
Otherwise, return the mean reconstruction loss.
"""
adata = self._validate_anndata(adata)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)
reconstruction_error = compute_reconstruction_error(self.module, scdl)
reconstruction_error = compute_reconstruction_error(
self.module, scdl, return_mean=return_mean
)
return reconstruction_error

@torch.inference_mode()
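A brief usage sketch of the extended mixin API; `model` is hypothetical here (any trained scvi-tools model exposing `VAEMixin`, e.g. the SCANVI instance from the earlier sketch):

```python
# Defaults unchanged: means over cells.
mean_elbo = model.get_elbo()
mean_mll = model.get_marginal_ll(n_mc_samples=3)

# New in this PR: per-observation values, one entry per cell.
elbo_per_cell = model.get_elbo(return_mean=False)
mll_per_cell = model.get_marginal_ll(n_mc_samples=3, return_mean=False)

# get_reconstruction_error returns a dict keyed by loss name
# (e.g. "reconstruction_loss"), scalar or per-cell depending on return_mean.
recon = model.get_reconstruction_error(return_mean=False)
```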
2 changes: 1 addition & 1 deletion scvi/module/_autozivae.py
@@ -396,7 +396,7 @@ def loss(
Normal(local_library_log_means, torch.sqrt(local_library_log_vars)),
).sum(dim=1)
else:
kl_divergence_l = 0.0
kl_divergence_l = torch.zeros_like(kl_divergence_z)

# KL divergence wrt Bernoulli parameters
kl_divergence_bernoulli = self.compute_global_kl_divergence()
3 changes: 1 addition & 2 deletions scvi/module/_scanvae.py
@@ -332,8 +332,7 @@ def loss(
Normal(local_library_log_means, torch.sqrt(local_library_log_vars)),
).sum(dim=1)
else:
kl_divergence_l = 0.0

kl_divergence_l = torch.zeros_like(loss_z1_weight)
if is_labelled:
loss = reconst_loss + loss_z1_weight + loss_z1_unweight
kl_locals = {
2 changes: 1 addition & 1 deletion scvi/module/_totalvae.py
@@ -631,7 +631,7 @@ def loss(
Normal(local_library_log_means, torch.sqrt(local_library_log_vars)),
).sum(dim=1)
else:
kl_div_l_gene = 0.0
kl_div_l_gene = torch.zeros_like(kl_div_z)

kl_div_back_pro_full = kl(
Normal(py_["back_alpha"], py_["back_beta"]), self.back_mean_prior
4 changes: 3 additions & 1 deletion scvi/module/_vae.py
@@ -473,7 +473,7 @@ def loss(
generative_outputs["pl"],
).sum(dim=1)
else:
kl_divergence_l = torch.tensor(0.0, device=x.device)
kl_divergence_l = torch.zeros_like(kl_divergence_z)

reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)

@@ -609,6 +609,8 @@ def marginal_ll(
q_l_x = ql.log_prob(library).sum(dim=-1)

log_prob_sum += p_l - q_l_x
if n_mc_samples_per_pass == 1:
log_prob_sum = log_prob_sum.unsqueeze(0)
Contributor commented on lines +612 to +613:
Wait so does this mean that this method was not working properly before? Since the default is n_mc_samples_per_pass=1

@canergen (Contributor, Author) replied:

Yes, but we only use it for DE genes, where it's called with n_samples_per_mc of 100 (?). ScANVI wasn't supporting importance weighting in DE beforehand.
ScANVI doesn't work with multiple samples (major work to fix this), and this might be an issue for multi-GPU support (see bug report). If you want to fix this, I had a fix for broadcasting labels but dropped it, as the encoder doesn't support it either.


to_sum.append(log_prob_sum)
to_sum = torch.cat(to_sum, dim=0)
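The guard added on lines +612 to +613 fixes the single-sample case: with `n_mc_samples_per_pass=1` each pass yields a 1-D tensor over cells, and concatenating passes along `dim=0` would then mix cells with Monte Carlo samples. Adding a leading sample dimension keeps the later `logsumexp` over passes well formed. A shape-only illustration, with made-up sizes:

```python
import math

import torch

n_cells, n_passes = 3, 4
to_sum = []
for _ in range(n_passes):
    log_prob_sum = torch.randn(n_cells)  # n_mc_samples_per_pass == 1 -> 1-D
    log_prob_sum = log_prob_sum.unsqueeze(0)  # the fix: (n_cells,) -> (1, n_cells)
    to_sum.append(log_prob_sum)

stacked = torch.cat(to_sum, dim=0)  # (n_passes, n_cells), not (n_passes * n_cells,)
# Log-mean-exp over Monte Carlo samples, as the surrounding method does:
marginal_ll = torch.logsumexp(stacked, dim=0) - math.log(n_passes)
assert marginal_ll.shape == (n_cells,)
```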
11 changes: 11 additions & 0 deletions tests/model/test_scanvi.py
@@ -237,6 +237,17 @@ def test_multiple_covariates_scanvi():
m.get_latent_representation()
m.get_elbo()
m.get_marginal_ll(n_mc_samples=3)
# m.get_marginal_ll(adata, return_mean=True, n_mc_samples=6, n_mc_samples_per_pass=1)
canergen marked this conversation as resolved.
m.differential_expression(
idx1=np.arange(50), idx2=51 + np.arange(50), mode="vanilla", weights="uniform"
)
m.differential_expression(
idx1=np.arange(50),
idx2=51 + np.arange(50),
mode="vanilla",
weights="importance",
importance_weighting_kwargs={"n_mc_samples": 10, "n_mc_samples_per_pass": 1},
)
m.get_reconstruction_error()
m.get_normalized_expression(n_samples=1)
m.get_normalized_expression(n_samples=2)
3 changes: 3 additions & 0 deletions tests/model/test_scvi.py
@@ -211,8 +211,11 @@ def test_scvi(n_latent: int = 5):
assert z.shape == (adata.shape[0], n_latent)
assert len(model.history["elbo_train"]) == 2
model.get_elbo()
model.get_elbo(return_mean=False)
model.get_marginal_ll(n_mc_samples=3)
model.get_marginal_ll(n_mc_samples=3, return_mean=False)
model.get_reconstruction_error()
model.get_reconstruction_error(return_mean=False)
model.get_normalized_expression(transform_batch="batch_1")
model.get_normalized_expression(n_samples=2)
