[WIP] Autonormal encoder #2849

Open
wants to merge 6 commits into
base: dev

Conversation

vitkl
Contributor

@vitkl vitkl commented May 17, 2021

@martinjankowiak @fritzo @eb8680 following our conversation in scverse/scvi-tools#930 (review), I am creating this PR to discuss adding an Autonormal encoder class.

This class requires users to specify an encoder network class, a data transformation, and an amortised_plate_sites dictionary, which states which variables are amortised, which model args/kwargs need to be passed to the encoder, and which plate the variables belong to.

One of the main assumptions at the moment is that encoded variables are 2D tensors. I guess the shape could be inferred automatically, but I have not thought that through and do not have applications where variables are more than 2D.
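
For illustration, a dictionary of this kind might look like the sketch below. Only the `"name"` and `"sites"` keys appear in the code discussed in this PR; the `"input"` and `"input_transform"` keys, the site names, and the helper `check_amortised_plate_sites` are assumptions made for this example.

```python
import math

# Hypothetical example of the dictionary described above.
amortised_plate_sites = {
    "name": "obs_plate",              # plate the amortised variables belong to
    "input": [0],                     # assumed key: positions of model args fed to the encoder
    "input_transform": [math.log1p],  # assumed key: one transform per encoder input
    "sites": {                        # site name -> size of the non-plate dimension
        "site1": 10,
        "site2": 1,
    },
}


def check_amortised_plate_sites(spec):
    """Validate the assumed layout; raise ValueError on anything malformed."""
    missing = {"name", "sites"} - set(spec)
    if missing:
        raise ValueError(f"missing required keys: {sorted(missing)}")
    for site, dim in spec["sites"].items():
        if not isinstance(dim, int) or dim < 1:
            raise ValueError(f"site {site!r} needs a positive integer dimension, got {dim!r}")
    if len(spec.get("input", [])) != len(spec.get("input_transform", [])):
        raise ValueError("expected one transform per encoder input")
    return True
```

Calling `check_amortised_plate_sites(amortised_plate_sites)` returns `True` for the dictionary above and raises `ValueError` if, say, `"sites"` is missing.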

self._cond_indep_stacks[name] = site["cond_indep_stack"]

# add linear layer for locs and scales
param_dim = (self.n_hidden, self.amortised_plate_sites["sites"][name])
Contributor Author

One of the main assumptions at the moment is that encoded variables are 2D tensors of shape (plate subsample_size, i.e. batch size, self.amortised_plate_sites["sites"][name]). I guess the shape could be inferred automatically, but I have not thought that through and do not have applications where variables are more than 2D.

Member

@fritzo fritzo May 17, 2021

I think it's more important to get an initial version merged quickly than to make something fully general right away. So WDYT about simply adding assertions or NotImplementedError("not yet supported") checks for your current assumptions?

Also feel free to start the class docstring with EXPERIMENTAL and add a `.. warning:: Interface may change` to give you/us room to slightly change the interface later in case that fully-general version needs slight changes.
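
A minimal sketch of what that could look like, using a stand-in class rather than the real guide. The class name, the `_check_site_shape` helper, and its arguments are illustrative assumptions; only the EXPERIMENTAL prefix, the `.. warning::` directive, and the `NotImplementedError` check come from the suggestion above.

```python
class AutoNormalEncoderSketch:
    """
    EXPERIMENTAL Amortised AutoNormal-style guide (illustrative stand-in).

    .. warning:: Interface may change
    """

    def __init__(self, amortised_plate_sites):
        self.amortised_plate_sites = amortised_plate_sites

    def _check_site_shape(self, name, shape):
        """Enforce the current assumption: (batch size, site dimension)."""
        expected_dim = self.amortised_plate_sites["sites"][name]
        if len(shape) != 2:
            raise NotImplementedError(
                f"not yet supported: site {name!r} has {len(shape)} dimensions, expected 2"
            )
        if shape[1] != expected_dim:
            raise ValueError(
                f"site {name!r} has trailing dimension {shape[1]}, expected {expected_dim}"
            )
```

With this in place, a 3D site fails early with an explicit message instead of producing a shape error deep inside the encoder.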

@vitkl
Contributor Author

vitkl commented May 17, 2021

Looks like I accidentally included changes to AutoGuideList from PR #2837 (still learning how to use git correctly).

@fritzo
Member

fritzo commented May 17, 2021

Looks like accidentally included changes from ...

No worries, #2837 should merge soon. We often add "Blocked by #xxx" in the PR description to denote merge order dependencies.

Comment on lines +1649 to +1656
init_param = torch.normal(
torch.full(size=param_dim, fill_value=0.0, device=site["value"].device),
torch.full(
size=param_dim,
fill_value=(1 * self.init_param_scale) / np.sqrt(self.n_hidden),
device=site["value"].device,
),
)
Contributor Author

@vitkl vitkl May 17, 2021

If I use torch.normal rather than numpy.random.normal, I get this warning:

/scvi-tools/scvi/external/cell2location/autoguide.py:218: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  init_param, device=site["value"].device, requires_grad=True

I also get different results after training the model.

Contributor Author

Numpy alternative

init_param = np.random.normal(
                np.zeros(param_dim),
                (np.ones(param_dim) * self.init_param_scale) / np.sqrt(self.n_hidden),
            ).astype("float32")

@vitkl
Contributor Author

vitkl commented May 17, 2021

What is missing at the moment is a simple encoder NN class. @fritzo @martinjankowiak is there anything already defined in pyro or a good example?

self.hidden2locs,
name,
PyroParam(
torch.tensor(
Member

I believe that UserWarning: To copy construct ... is actually due to this line. You should be able to fix that by using as_tensor:

  PyroParam(
-     torch.tensor(
+     torch.as_tensor(
          init_param, ...
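
A small sketch of the difference, outside of Pyro: `torch.tensor` copy-constructs from an existing tensor and emits the warning, while `torch.as_tensor` reuses it. The shapes and variable names below are placeholders. Note that `torch.as_tensor` has no `requires_grad` argument, so gradients would have to be enabled separately, for example with `requires_grad_()`.

```python
import warnings

import torch

n_hidden, site_dim, init_param_scale = 8, 3, 0.02  # placeholder values

# Draw the initial weights directly as a tensor: N(0, (scale / sqrt(n_hidden))^2).
init_param = torch.randn(n_hidden, site_dim) * init_param_scale / n_hidden ** 0.5

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    copied = torch.tensor(init_param)      # copy-constructs, triggers the UserWarning
n_warnings_tensor = len(caught)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    shared = torch.as_tensor(init_param)   # no copy when dtype and device already match
n_warnings_as_tensor = len(caught)

# as_tensor returned the same storage; enable gradients explicitly if needed.
shared.requires_grad_(True)
```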

self.hidden2scales,
name,
PyroParam(
torch.tensor(
Member

ditto: torch.tensor -> torch.as_tensor

@fritzo
Member

fritzo commented May 17, 2021

What is missing at the moment is a simple encoder NN class. ...
is there anything already defined in pyro or a good example?

Existing code includes:

Also feel free to add something to a new file in pyro/nn/

@vitkl
Contributor Author

vitkl commented Jun 1, 2021

I started thinking about tests and realised that testing this class also needs a model with local variables and some data. @fritzo do you have any good example in mind (which is ideally already implemented in tests)?

One alternative would be to use the scVI Pyro test regression model and write simple training and posterior sampling code to test this class. Actually, for posterior sampling and computing medians/quantiles, concatenating encoded local variables along the plate dimension is non-trivial and is the subject of this scVI PR (PyroSampleMixin class): scverse/scvi-tools#1059

It could be good if the AutoNormalEncoder class provided a method to merge quantiles, medians and posterior samples along the plate dimension. WDYT?

First guess plate dimension:

    def _guess_obs_plate_sites(self, args, kwargs):
        """
        Automatically guess which model sites belong to observation/minibatch plate.

        This function requires minibatch plate name specified in `self.amortised_plate_sites["name"]`.

        Parameters
        ----------
        args
            Arguments to the model.
        kwargs
            Keyword arguments to the model.

        Returns
        -------
        Dictionary with keys corresponding to site names and values to plate dimension.
        """

        plate_name = self.amortised_plate_sites["name"]

        # find plate dimension
        trace = poutine.trace(self.model).get_trace(*args, **kwargs)
        obs_plate = {
            name: site["cond_indep_stack"][0].dim
            for name, site in trace.nodes.items()
            if site["type"] == "sample"
            if any(f.name == plate_name for f in site["cond_indep_stack"])
        }

        return obs_plate
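
One caveat with indexing `cond_indep_stack[0]`: when a site sits inside more than one plate, the first frame is not necessarily the plate named in `amortised_plate_sites["name"]`. The sketch below picks the frame by name instead. It uses a `namedtuple` stand-in with `name` and `dim` fields for the stack frames and a hand-written dictionary in place of `trace.nodes`, so the function name, the site names, and the plate layout are all assumptions.

```python
from collections import namedtuple

# Stand-in for the frames stored in site["cond_indep_stack"].
Frame = namedtuple("Frame", ["name", "dim"])


def guess_obs_plate_sites(nodes, plate_name):
    """Map each sample site inside `plate_name` to that plate's dimension."""
    obs_plate = {}
    for name, site in nodes.items():
        if site["type"] != "sample":
            continue
        for frame in site["cond_indep_stack"]:
            if frame.name == plate_name:
                obs_plate[name] = frame.dim  # dim of the matching plate, not of frame 0
                break
    return obs_plate


# Hand-written stand-in for trace.nodes.
nodes = {
    "global_scale": {"type": "sample", "cond_indep_stack": ()},
    "z": {
        "type": "sample",
        "cond_indep_stack": (Frame("var_plate", -1), Frame("obs_plate", -2)),
    },
    "_RETURN": {"type": "return"},
}
obs_plate_sites = guess_obs_plate_sites(nodes, "obs_plate")
```

Here `obs_plate_sites` maps `"z"` to `-2`, whereas indexing the first frame would have given `-1`.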

Then concatenate samples in that dimension:

import numpy as np

for i, (args, kwargs) in enumerate(dataloader):
    if i == 0:
        samples = guide.quantiles(0.5, *args, **kwargs)
        obs_plate_sites = guide._guess_obs_plate_sites(args, kwargs)
        obs_plate_dim = list(obs_plate_sites.values())[0]
    else:
        samples_ = guide.quantiles(0.5, *args, **kwargs)
        # append the new minibatch to the accumulated results
        samples = {
            k: np.array(
                [
                    np.concatenate(
                        [samples[k][j], samples_[k][j]],
                        axis=obs_plate_dim,
                    )
                    for j in range(len(samples[k]))  # for each sample (in dimension 0)
                ]
            )
            for k in samples.keys()  # for each variable
        }
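
An alternative sketch that collects the per-batch results first and concatenates once per variable, which avoids re-copying the accumulated arrays on every iteration. The `fake_quantiles` function stands in for `guide.quantiles`, and the site names, shapes, and `obs_plate_dim = -2` are made-up assumptions for the example; the arrays here carry no leading sample dimension.

```python
import numpy as np


def fake_quantiles(batch_size):
    """Stand-in for guide.quantiles: one array per site, shaped (batch, site dim)."""
    return {"z": np.zeros((batch_size, 5)), "l": np.ones((batch_size, 1))}


obs_plate_dim = -2       # assumed plate dimension, counted from the right
batch_sizes = [4, 4, 2]  # the last minibatch may be smaller

# Collect every minibatch first ...
per_batch = [fake_quantiles(n) for n in batch_sizes]

# ... then concatenate each variable once along the plate dimension.
merged = {
    k: np.concatenate([batch[k] for batch in per_batch], axis=obs_plate_dim)
    for k in per_batch[0]
}
```

A merge method on the guide could wrap the last step, taking the list of per-batch dictionaries and the plate dimension as arguments.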

@vitkl
Contributor Author

vitkl commented Jul 16, 2021

I extended this class further to enable more complex architectures (see below) and a different number of hidden nodes for each model site.

  1. Single encoder NN for all pyro model sites (encoder_mode='single') where means and scales linearly depend on the last NN layer
    A -> site1, site2, site3 ... siteN;

  2. Separate NN for each pyro model site (encoder_mode='multiple') where means and scales linearly depend on the last NN layer
    B -> site1;
    B -> site2
    ...
    B -> siteN;

  3. Single encoder NN followed by another layer of separate NN for each pyro model site (encoder_mode='single-multiple') where means and scales linearly depend on the last NN layer. Aka branching network:
    A -> B;
    B -> site1;
    B -> site2;
    ...
    B -> siteN;
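
The wiring of the three modes can be sketched in plain PyTorch as follows. This is not the code from the linked file: the class name `EncoderSketch`, the single hidden layer, the `Softplus` activation, and all sizes are assumptions chosen to keep the example short. Only the locs are produced here; scales would need a second set of linear heads plus a positivity transform.

```python
import torch
from torch import nn


def mlp(n_in, n_hidden):
    """One hidden layer; stands in for an arbitrary encoder network."""
    return nn.Sequential(nn.Linear(n_in, n_hidden), nn.Softplus())


class EncoderSketch(nn.Module):
    def __init__(self, n_in, n_hidden, sites, encoder_mode="single"):
        super().__init__()
        if encoder_mode not in ("single", "multiple", "single-multiple"):
            raise ValueError(f"unknown encoder_mode: {encoder_mode!r}")
        self.encoder_mode = encoder_mode
        # Shared network A (modes 'single' and 'single-multiple').
        self.shared = mlp(n_in, n_hidden) if "single" in encoder_mode else None
        # Per-site networks B (modes 'multiple' and 'single-multiple').
        n_site_in = n_hidden if encoder_mode == "single-multiple" else n_in
        self.per_site = nn.ModuleDict(
            {name: mlp(n_site_in, n_hidden) for name in sites}
            if "multiple" in encoder_mode
            else {}
        )
        # Linear map from the last hidden layer to each site's locs.
        self.hidden2locs = nn.ModuleDict(
            {name: nn.Linear(n_hidden, dim) for name, dim in sites.items()}
        )

    def forward(self, x):
        h = self.shared(x) if self.shared is not None else x
        locs = {}
        for name, head in self.hidden2locs.items():
            h_site = self.per_site[name](h) if name in self.per_site else h
            locs[name] = head(h_site)
        return locs
```

For example, `EncoderSketch(12, 8, {"site1": 3, "site2": 1}, encoder_mode="single-multiple")` applied to a `(6, 12)` input returns a dictionary with one `(6, 3)` and one `(6, 1)` tensor.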

Code is here for now: https://github.com/vitkl/scvi-tools/blob/pyro-cell2location/scvi/external/cell2location/autoguide.py

Still looking for good example data for tests.
I will start working on this when we resubmit the paper revision (hopefully in August).


2 participants