Mass Matrix for NUTS and HMC #3208

Open
ConnorStoneAstro opened this issue May 10, 2023 · 2 comments

Comments

@ConnorStoneAstro

Issue Description

I would like to be able to interact with the mass matrix in the NUTS and HMC samplers. In many cases I have access to the covariance matrix at the MAP, so being able to set the (inverse) mass matrix directly would provide a large speedup. Also, it would be nice to be able to read the mass matrix back after adaptation. This could be used for other purposes, since the adapted inverse mass matrix approximates the posterior covariance of the parameters.
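
For context, one common way to end up with a covariance at the MAP is the Laplace approximation: invert the Hessian of the negative log posterior at the mode. The post does not say how the covariance was obtained, so treat this as a minimal sketch under that assumption. The toy Gaussian target and the names `neg_log_posterior`, `x_map`, and `inv_mass` are illustrative only.

```python
import torch
from torch.autograd.functional import hessian

# Hypothetical target: a correlated 2-D Gaussian, chosen so the exact
# covariance is known and the result can be checked.
true_cov = torch.tensor([[2.0, 0.6], [0.6, 1.0]], dtype=torch.float64)
true_precision = torch.linalg.inv(true_cov)


def neg_log_posterior(x):
    # Negative log density up to an additive constant.
    return 0.5 * x @ true_precision @ x


# The MAP of this toy target is the origin; in practice it would come
# from an optimizer.
x_map = torch.zeros(2, dtype=torch.float64)

# Laplace approximation: covariance ~= inverse Hessian at the MAP.
hess = hessian(neg_log_posterior, x_map)
inv_mass = torch.linalg.inv(hess)
# Symmetrize to remove round-off asymmetry before any Cholesky factorization.
inv_mass = 0.5 * (inv_mass + inv_mass.T)
```

The resulting `inv_mass` is the kind of dense matrix one would want to hand to the sampler as its inverse mass matrix.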

This isn't a bug, just a feature I would like to have access to.

I was able to "hack" a way to get access, but it is not a long-term solution since it involves overwriting library functions:

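import torch
from pyro.infer.mcmc.adaptation import BlockMassMatrix
from pyro.ops.welford import WelfordCovariance


def new_configure(self, mass_matrix_shape, adapt_mass_matrix=True, options={}):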
    """
    Sets up an initial mass matrix.
    
    :param dict mass_matrix_shape: a dict that maps tuples of site names to the shape of
        the corresponding mass matrix. Each tuple of site names corresponds to a block.
    :param bool adapt_mass_matrix: a flag to decide whether an adaptation scheme will be used.
    :param dict options: tensor options to construct the initial mass matrix.
    """
    inverse_mass_matrix = {}
    for site_names, shape in mass_matrix_shape.items():
        self._mass_matrix_size[site_names] = shape[0]
        diagonal = len(shape) == 1
        inverse_mass_matrix[site_names] = (
            torch.full(shape, self._init_scale, **options)
            if diagonal
            else torch.eye(*shape, **options) * self._init_scale
        )
        if adapt_mass_matrix:
            adapt_scheme = WelfordCovariance(diagonal=diagonal)
            self._adapt_scheme[site_names] = adapt_scheme

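    # Only install the default matrix if none has been set by the user already.
    if len(self.inverse_mass_matrix) == 0:
        self.inverse_mass_matrix = inverse_mass_matrix


# Monkey-patch the class so every kernel uses the modified configure.
BlockMassMatrix.configure = new_configure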

Then later, once I had the inverse mass matrix, I could call:

nuts_kernel.mass_matrix_adapter.inverse_mass_matrix = {("x",): inv_mass}

I could probably make the change myself, but I imagine there is a more elegant way to do this.

@fehiepsi
Member

fehiepsi commented May 10, 2023

I think making a subclass of BlockMassMatrix that returns the desired configuration makes sense. Then you can set:

nuts_kernel.mass_matrix_adapter = ThatSubclass()

I guess we could also expose inverse_mass_matrix as an argument to the kernel constructor, as NumPyro does.

@ConnorStoneAstro
Author

Subclassing BlockMassMatrix is more elegant than what I did, but it still means I have to maintain it myself, so any update to Pyro could break it. Having it exposed to the user at construction would be great!
