Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor jax internals to support dense_mass kwarg for numpyro #7050

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

ferrine
Copy link
Member

@ferrine ferrine commented Dec 6, 2023

What is this PR about?
Enables Block Dense mass matrix adaptation for numpyro

Checklist

Major / Breaking Changes

  • ...

New features

  • Block mass matrix for numpyro
  • get_jaxified_logp now accepts point_fn argument
with pm.Model(
        coords=dict(level=["Basement", "Floor"], county=[1, 2]),
) as model:
    # multilevel modelling
    a = pm.Normal("a")
    s = pm.HalfNormal("s")
    a_g = pm.Normal("a_g", a, s, dims="level")
    s_g = pm.HalfNormal("s_g")
    a_ig = pm.Normal("a_ig", a_g, s_g, dims=("county", "level"))
    trace = sample_numpyro_nuts(
        nuts_kwargs=dict(
            dense_mass=[
                ("a", "a_g"),
            ]
        )
    )

Bugfixes

  • ...

Documentation

  • ...

Maintenance

  • ...

馃摎 Documentation preview 馃摎: https://pymc--7050.org.readthedocs.build/en/7050/

Copy link

codecov bot commented Dec 6, 2023

Codecov Report

Merging #7050 (8eb4284) into main (2e05854) will decrease coverage by 12.23%.
The diff coverage is 0.00%.

Additional details and impacted files

Impacted file tree graph

@@             Coverage Diff             @@
##             main    #7050       +/-   ##
===========================================
- Coverage   92.19%   79.97%   -12.23%     
===========================================
  Files         101      101               
  Lines       16893    16911       +18     
===========================================
- Hits        15575    13524     -2051     
- Misses       1318     3387     +2069     
Files Coverage 螖
pymc/sampling/jax.py 0.00% <0.00%> (-93.08%) 猬囷笍

... and 31 files with indirect coverage changes

pymc/sampling/jax.py Outdated Show resolved Hide resolved
@ricardoV94 ricardoV94 marked this pull request as draft December 10, 2023 11:36
@ricardoV94 ricardoV94 marked this pull request as ready for review December 11, 2023 13:55
@ricardoV94
Copy link
Member

ricardoV94 commented Dec 11, 2023

These failing tests are definitely a latest PyTensor issue, I'll patch it

@ricardoV94 ricardoV94 changed the title refactor jax internals to support dense_mass kwarg for numpyro Refactor jax internals to support dense_mass kwarg for numpyro Dec 11, 2023
@ricardoV94
Copy link
Member

Failing tests due to PyTensor should be fixed by pymc-devs/pytensor#546

@ricardoV94
Copy link
Member

@ferrine can you rebase?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants