Skip to content

Commit

Permalink
Backport PR #2522 on branch 1.1.x (Allow for non-default user params in `POISSONVI`) (#2524)
Browse files Browse the repository at this point in the history

Backport PR #2522: Allow for non-default user params in `POISSONVI`

Co-authored-by: Martin Kim <46072231+martinkim0@users.noreply.github.com>
  • Loading branch information
meeseeksmachine and martinkim0 committed Feb 19, 2024
1 parent 6b4d351 commit efdceef
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 11 deletions.
2 changes: 2 additions & 0 deletions docs/release_notes/index.md
Expand Up @@ -11,6 +11,8 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/

#### Fixed

- Correctly apply non-default user parameters in {class}`scvi.external.POISSONVI` {pr}`2522`.

### 1.1.0 (2024-02-13)

#### Added
Expand Down
15 changes: 7 additions & 8 deletions scvi/external/poissonvi/_model.py
Expand Up @@ -79,14 +79,13 @@ def __init__(
adata: AnnData,
n_hidden: int | None = None,
n_latent: int | None = None,
n_layers: int | None = None,
dropout_rate: float | None = None,
n_layers: int = 2,
dropout_rate: float = 0.1,
latent_distribution: Literal["normal", "ln"] = "normal",
**model_kwargs,
):
super().__init__(
adata,
)
# need to pass these in to get the correct defaults for peakvi
super().__init__(adata, n_hidden=n_hidden, n_latent=n_latent)

n_batch = self.summary_stats.n_batch
use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
Expand All @@ -104,11 +103,11 @@ def __init__(
n_cats_per_cov=self.module.n_cats_per_cov,
n_hidden=self.module.n_hidden,
n_latent=self.module.n_latent,
n_layers=self.module.n_layers_encoder,
dropout_rate=self.module.dropout_rate,
n_layers=n_layers,
dropout_rate=dropout_rate,
dispersion="gene", # not needed here
gene_likelihood="poisson", # fixed value for now, but we could think of also allowing nb
latent_distribution=self.module.latent_distribution,
latent_distribution=latent_distribution,
use_size_factor_key=use_size_factor_key,
library_log_means=library_log_means,
library_log_vars=library_log_vars,
Expand Down
62 changes: 59 additions & 3 deletions tests/external/test_poissonvi.py
@@ -1,13 +1,69 @@
from torch.nn import Linear

from scvi.data import synthetic_iid
from scvi.external import POISSONVI


def test_poissonvi():
    """Smoke test: POISSONVI sets up, trains one epoch, and exposes outputs."""
    adata = synthetic_iid(batch_size=100)
    # Register the AnnData exactly once — the original had a duplicated
    # setup_anndata call (stale deleted diff line left alongside the new one).
    POISSONVI.setup_anndata(adata)
    model = POISSONVI(adata)
    # A single epoch keeps the smoke test fast; we only check the API runs.
    model.train(max_epochs=1)
    model.get_latent_representation()
    model.get_accessibility_estimates()


def test_poissonvi_default_params():
    """POISSONVI built with all defaults must mirror PEAKVI's default architecture."""
    from scvi.model import PEAKVI

    adata = synthetic_iid(batch_size=100)
    POISSONVI.setup_anndata(adata)
    PEAKVI.setup_anndata(adata)
    poisson_mod = POISSONVI(adata).module
    peak_mod = PEAKVI(adata).module

    # Scalar hyperparameters agree between the two models.
    assert poisson_mod.n_latent == peak_mod.n_latent
    assert poisson_mod.latent_distribution == peak_mod.latent_distribution

    # Layer shapes of POISSONVI's encoder/decoder match PEAKVI's defaults.
    encoder = poisson_mod.z_encoder.encoder
    mean_head = poisson_mod.z_encoder.mean_encoder
    decoder = poisson_mod.decoder.px_decoder
    assert len(encoder.fc_layers) == peak_mod.n_layers_encoder
    assert len(decoder.fc_layers) == peak_mod.n_layers_encoder
    assert encoder.fc_layers[-1][0].in_features == peak_mod.n_hidden
    assert decoder.fc_layers[-1][0].in_features == peak_mod.n_hidden
    assert mean_head.out_features == peak_mod.n_latent
    assert decoder.fc_layers[0][0].in_features == peak_mod.n_latent


def test_poissonvi_non_default_params(
    n_hidden: int = 50,
    n_latent: int = 5,
    n_layers: int = 2,
    dropout_rate: float = 0.4,
    latent_distribution="ln",
):
    """Non-default constructor arguments must propagate into the POISSONVI module."""
    adata = synthetic_iid(batch_size=100)
    POISSONVI.setup_anndata(adata)
    model = POISSONVI(
        adata,
        n_hidden=n_hidden,
        n_latent=n_latent,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
        latent_distribution=latent_distribution,
    )
    module = model.module

    # Scalar hyperparameters stored directly on the module.
    assert module.n_latent == n_latent
    assert module.latent_distribution == latent_distribution

    # Encoder architecture reflects the requested sizes.
    encoder = module.z_encoder.encoder
    assert len(encoder.fc_layers) == n_layers
    last_linear = encoder.fc_layers[-1][0]
    assert isinstance(last_linear, Linear)
    assert last_linear.in_features == n_hidden
    mean_head = module.z_encoder.mean_encoder
    assert isinstance(mean_head, Linear)
    assert mean_head.out_features == n_latent

    # One epoch of training, then confirm the latent dimensionality.
    model.train(max_epochs=1)
    assert model.get_latent_representation().shape[1] == n_latent

0 comments on commit efdceef

Please sign in to comment.