Jax version of VAEMixin #2779

Open
justjhong opened this issue May 6, 2024 · 0 comments
@justjhong
Contributor

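It should be straightforward to reimplement `VAEMixin` for JAX models. This will require a separate class, because the forward pass is invoked completely differently: a JAX module is called through `module.apply`, with its parameters, state, and RNG keys passed in explicitly.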

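Example implementation of `get_reconstruction_error`:

```python
from anndata import AnnData


def get_reconstruction_error(
    self,
    adata: AnnData | None = None,
    indices: list[int] | None = None,
    batch_size: int | None = None,
    **kwargs,
) -> float:
    adata = self._validate_anndata(adata)
    # iter_ndarray=True yields NumPy arrays instead of torch tensors.
    dataloader = self._make_data_loader(
        adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True
    )

    # Parameters and other variable collections are passed explicitly to
    # `apply`; they do not change across batches, so build the dict once.
    vars_in = {"params": self.module.params, **self.module.state}

    reconstruction_loss_sum = 0.0
    for batch in dataloader:
        outputs = self.module.apply(vars_in, batch, rngs=self.module.rngs, **kwargs)
        # outputs[2] is the loss output, which carries reconstruction_loss_sum.
        reconstruction_loss_sum += outputs[2].reconstruction_loss_sum.item()

    # Normalize by dataset size (not number of batches) and negate,
    # so that higher values mean better reconstruction.
    return -(reconstruction_loss_sum / len(dataloader.dataset))
```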
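The final step divides the summed loss by the dataset length rather than by the number of batches. A minimal pure-Python sketch of why that matters when the last batch is smaller; the helper `negative_mean_reconstruction_loss` and the per-cell loss values are hypothetical.

```python
def negative_mean_reconstruction_loss(batch_loss_sums: list[float], n_cells: int) -> float:
    """Hypothetical helper mirroring the loop: sum per-batch loss sums, divide by cell count."""
    return -(sum(batch_loss_sums) / n_cells)


# Hypothetical per-cell reconstruction losses for 10 cells.
per_cell_loss = [float(i) for i in range(1, 11)]
batch_size = 4

# Split into batches of sizes 4, 4 and 2, as a data loader would.
batches = [per_cell_loss[i : i + batch_size] for i in range(0, len(per_cell_loss), batch_size)]
batch_loss_sums = [sum(b) for b in batches]  # [10.0, 26.0, 19.0]

score = negative_mean_reconstruction_loss(batch_loss_sums, len(per_cell_loss))
print(score)  # -5.5, the negated mean over all cells

# Averaging per-batch means instead would overweight the short final batch.
naive = -sum(sum(b) / len(b) for b in batches) / len(batches)
print(round(naive, 4))  # -6.1667
```

Summing per-batch totals and dividing once by the cell count gives the same result for any batch size, which is why the loop accumulates `reconstruction_loss_sum` rather than a per-batch mean.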