Implement Fullrank vi #479

Open · wants to merge 4 commits into base: main
Conversation

@xidulu (Contributor) commented Jan 29, 2023

Fullrank VI

TODO 1: Better test cases that verify posterior covariance recovery

TODO 2: Currently the user doesn't have access to the _real_to_vector function and therefore cannot convert the unnormalized real-space vector to the covariance matrix. This is not sensible and should be fixed.

__all__ = ["FullrankVIState", "FullrankVIInfo", "sample", "generate_fullrank_logdensity", "step"]


def _real_vector_to_cholesky(X):
Member:

Is this jitable? I think we might need to make m and n kwargs, and create a closure below when we are setting up the parameters.
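As an illustrative aside (not the PR's implementation): one way to keep all shapes static under jit is to close over the dimension, roughly as below. The factory name and the exponentiated-diagonal parameterization are assumptions.

```python
import jax
import jax.numpy as jnp

def make_real_vector_to_cholesky(dim):
    """Return a function mapping a flat real vector of length
    dim * (dim + 1) // 2 to a lower-triangular Cholesky factor.
    `dim` is closed over, so all shapes are static when jitted."""
    rows, cols = jnp.tril_indices(dim)

    def real_vector_to_cholesky(x):
        chol = jnp.zeros((dim, dim)).at[rows, cols].set(x)
        # Assumed parameterization: exponentiate the diagonal so it stays positive.
        return chol - jnp.diag(jnp.diag(chol)) + jnp.diag(jnp.exp(jnp.diag(chol)))

    return real_vector_to_cholesky

chol_fn = jax.jit(make_real_vector_to_cholesky(3))
L = chol_fn(jnp.arange(6.0))
```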

Contributor Author (@xidulu):

I am not sure if this is jitable, but JAX seems to be happy with it; I didn't see any warning.
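A minimal, self-contained way to check this is to trace the function; the toy helper below is an assumption standing in for the PR's code, not the actual helper.

```python
import jax
import jax.numpy as jnp

def toy_real_vector_to_tril(x, dim=3):
    # Builds a lower-triangular matrix from a flat vector; `dim` is a
    # concrete Python int, so shapes stay static during tracing.
    rows, cols = jnp.tril_indices(dim)
    return jnp.zeros((dim, dim)).at[rows, cols].set(x)

# If make_jaxpr succeeds, the function traces cleanly and jit will work
# for this input shape.
print(jax.make_jaxpr(toy_real_vector_to_tril)(jnp.ones(6)))
print(jax.jit(toy_real_vector_to_tril)(jnp.arange(6.0)))
```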

**optimizer_kwargs
) -> FullrankVIState:
"""Initialize the fullrank VI state."""
mu = jax.tree_map(jnp.zeros_like, position) # Is this a good init strategy?
Member:

We should also allow random initialization for both mu and L, maybe by allowing the user to pass a callable that takes a random_key and a shape as input (for zeros and ones we can just ignore the random_key).
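A hedged sketch of what such a callable-based initializer could look like; the function names here are hypothetical and not part of the PR.

```python
import jax
import jax.numpy as jnp

def zeros_init(rng_key, shape):
    # Deterministic default: ignore the random key.
    return jnp.zeros(shape)

def normal_init(rng_key, shape, scale=0.1):
    # Random alternative: small Gaussian noise around zero.
    return scale * jax.random.normal(rng_key, shape)

def init_mu(position, init_fn, rng_key):
    """Build a mu pytree with the same structure as `position`,
    using `init_fn(key, shape)` for each leaf."""
    leaves, treedef = jax.tree_util.tree_flatten(position)
    keys = jax.random.split(rng_key, len(leaves))
    new_leaves = [init_fn(k, jnp.shape(leaf)) for k, leaf in zip(keys, leaves)]
    return jax.tree_util.tree_unflatten(treedef, new_leaves)

mu = init_mu({"w": jnp.ones((2, 3))}, normal_init, jax.random.PRNGKey(0))
```

Passing zeros_init would reproduce the current jnp.zeros_like behaviour, while normal_init gives random starts.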

Contributor Author (@xidulu):

Sounds good, shall we also make the changes to MFVI?

Member:

I think we need to make a choice and let users initialise manually if they want something different.

Contributor Author (@xidulu):

In fact, the init function here only serves as a default initialization option. If users want to specify their own strategy, they can replace the call to fullrank_vi.init() (https://github.com/blackjax-devs/blackjax/pull/479/files#diff-8923a0a4ea42b4d3c2e1756e182a25f482809bc5b1c23a601423f5908f0f03e2R36) with their own function.

tests/test_fullrank_vi.py (outdated review thread, resolved)
@xidulu (Contributor Author) commented Feb 1, 2023

@rlouf @junpenglao
Any suggestions on how to let the user access the covariance matrix? Currently, the user can only access the flattened, unnormalized version in the state (https://github.com/blackjax-devs/blackjax/pull/479/files#diff-d7b1bbb7765064655362418341e99f42d8f9d15d120f946da70f32ee4295abe5R42), and the conversion from L to the covariance matrix is non-trivial.
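For illustration, a sketch of the conversion being discussed, assuming the state stores the lower-triangular entries of L as a flat vector with an unconstrained (log-scale) diagonal; the actual parameterization in the PR may differ.

```python
import jax.numpy as jnp

def flat_to_covariance(flat_l, dim):
    """Rebuild L from its flattened lower-triangular entries and
    return the covariance Sigma = L @ L.T."""
    rows, cols = jnp.tril_indices(dim)
    chol = jnp.zeros((dim, dim)).at[rows, cols].set(flat_l)
    # Assumption: the diagonal is stored on a log scale.
    chol = chol - jnp.diag(jnp.diag(chol)) + jnp.diag(jnp.exp(jnp.diag(chol)))
    return chol @ chol.T

sigma = flat_to_covariance(jnp.arange(6.0), dim=3)
```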

@rlouf (Member) commented Feb 27, 2023

Sorry for my late reaction. I think that the issue that you're facing with the covariance matrix is part of a more general discussion we're having at the library level (we have similar issues with mass matrix for the HMC algorithms). I suggest we leave it as is for now to get the ball rolling on this PR and get it in a mergeable state.

@rlouf changed the title from "Fullrank vi" to "Implement Fullrank vi" on Jun 1, 2023
Labels: none yet · Projects: none yet · Linked issues: none yet · 3 participants