diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md
index 9158423ee5..aee31140bb 100644
--- a/docs/release_notes/index.md
+++ b/docs/release_notes/index.md
@@ -11,6 +11,8 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/
 
 #### Fixed
 
+- Correctly apply non-default user parameters in {class}`scvi.external.POISSONVI` {pr}`2522`.
+
 ### 1.1.0 (2024-02-13)
 
 #### Added
diff --git a/scvi/external/poissonvi/_model.py b/scvi/external/poissonvi/_model.py
index eef36a6a85..0acf9276fd 100644
--- a/scvi/external/poissonvi/_model.py
+++ b/scvi/external/poissonvi/_model.py
@@ -79,14 +79,13 @@ def __init__(
         adata: AnnData,
         n_hidden: int | None = None,
         n_latent: int | None = None,
-        n_layers: int | None = None,
-        dropout_rate: float | None = None,
+        n_layers: int = 2,
+        dropout_rate: float = 0.1,
         latent_distribution: Literal["normal", "ln"] = "normal",
         **model_kwargs,
     ):
-        super().__init__(
-            adata,
-        )
+        # need to pass these in to get the correct defaults for peakvi
+        super().__init__(adata, n_hidden=n_hidden, n_latent=n_latent)
 
         n_batch = self.summary_stats.n_batch
         use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
@@ -104,11 +103,11 @@ def __init__(
             n_cats_per_cov=self.module.n_cats_per_cov,
             n_hidden=self.module.n_hidden,
             n_latent=self.module.n_latent,
-            n_layers=self.module.n_layers_encoder,
-            dropout_rate=self.module.dropout_rate,
+            n_layers=n_layers,
+            dropout_rate=dropout_rate,
             dispersion="gene",  # not needed here
             gene_likelihood="poisson",  # fixed value for now, but we could think of also allowing nb
-            latent_distribution=self.module.latent_distribution,
+            latent_distribution=latent_distribution,
             use_size_factor_key=use_size_factor_key,
             library_log_means=library_log_means,
             library_log_vars=library_log_vars,
diff --git a/tests/external/test_poissonvi.py b/tests/external/test_poissonvi.py
index d75c7493e6..53de237668 100644
--- a/tests/external/test_poissonvi.py
+++ b/tests/external/test_poissonvi.py
@@ -1,13 +1,69 @@
+from torch.nn import Linear
+
 from scvi.data import synthetic_iid
 from scvi.external import POISSONVI
 
 
 def test_poissonvi():
     adata = synthetic_iid(batch_size=100)
-    POISSONVI.setup_anndata(
-        adata,
-    )
+    POISSONVI.setup_anndata(adata)
     model = POISSONVI(adata)
     model.train(max_epochs=1)
     model.get_latent_representation()
     model.get_accessibility_estimates()
+
+
+def test_poissonvi_default_params():
+    from scvi.model import PEAKVI
+
+    adata = synthetic_iid(batch_size=100)
+    POISSONVI.setup_anndata(adata)
+    PEAKVI.setup_anndata(adata)
+    poissonvi = POISSONVI(adata)
+    peakvi = PEAKVI(adata)
+
+    assert poissonvi.module.n_latent == peakvi.module.n_latent
+    assert poissonvi.module.latent_distribution == peakvi.module.latent_distribution
+    poisson_encoder = poissonvi.module.z_encoder.encoder
+    poisson_mean_encoder = poissonvi.module.z_encoder.mean_encoder
+    poisson_decoder = poissonvi.module.decoder.px_decoder
+    assert len(poisson_encoder.fc_layers) == peakvi.module.n_layers_encoder
+    assert len(poisson_decoder.fc_layers) == peakvi.module.n_layers_encoder
+    assert poisson_encoder.fc_layers[-1][0].in_features == peakvi.module.n_hidden
+    assert poisson_decoder.fc_layers[-1][0].in_features == peakvi.module.n_hidden
+    assert poisson_mean_encoder.out_features == peakvi.module.n_latent
+    assert poisson_decoder.fc_layers[0][0].in_features == peakvi.module.n_latent
+
+
+def test_poissonvi_non_default_params(
+    n_hidden: int = 50,
+    n_latent: int = 5,
+    n_layers: int = 2,
+    dropout_rate: float = 0.4,
+    latent_distribution="ln",
+):
+    adata = synthetic_iid(batch_size=100)
+    POISSONVI.setup_anndata(adata)
+    model = POISSONVI(
+        adata,
+        n_hidden=n_hidden,
+        n_latent=n_latent,
+        n_layers=n_layers,
+        dropout_rate=dropout_rate,
+        latent_distribution=latent_distribution,
+    )
+
+    assert model.module.n_latent == n_latent
+    assert model.module.latent_distribution == latent_distribution
+
+    encoder = model.module.z_encoder.encoder
+    assert len(encoder.fc_layers) == n_layers
+    linear = encoder.fc_layers[-1][0]
+    assert isinstance(linear, Linear)
+    assert linear.in_features == n_hidden
+    mean_encoder = model.module.z_encoder.mean_encoder
+    assert isinstance(mean_encoder, Linear)
+    assert mean_encoder.out_features == n_latent
+
+    model.train(max_epochs=1)
+    assert model.get_latent_representation().shape[1] == n_latent
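Note (not part of the diff): a minimal usage sketch of the behavior this PR fixes, mirroring the new `test_poissonvi_non_default_params` test above. It assumes scvi-tools with this change applied and uses `synthetic_iid` purely as a stand-in dataset.

```python
from scvi.data import synthetic_iid
from scvi.external import POISSONVI

# Stand-in dataset; any AnnData registered via setup_anndata behaves the same.
adata = synthetic_iid(batch_size=100)
POISSONVI.setup_anndata(adata)

# Previously n_layers, dropout_rate, and latent_distribution were overridden by
# the PeakVI module defaults; with this fix they reach the underlying module.
model = POISSONVI(adata, n_latent=5, n_layers=2, dropout_rate=0.4, latent_distribution="ln")
assert model.module.n_latent == 5
assert model.module.latent_distribution == "ln"
```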