Make VI compatible with JAX backend #7103

Open — wants to merge 6 commits into main

Conversation

ferrine
Member

@ferrine ferrine commented Jan 15, 2024

Description

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7103.org.readthedocs.build/en/7103/

@ferrine ferrine changed the title from "add dispatch for identity Op, use static shapes for parameters" to "VI: add dispatch for identity Op, use static shapes for parameters" on Jan 15, 2024
@ferrine ferrine added jax VI Variational Inference labels Jan 15, 2024

codecov bot commented Jan 15, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 91.11%. Comparing base (a06081e) to head (30a2d73).

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7103      +/-   ##
==========================================
- Coverage   91.87%   91.11%   -0.76%     
==========================================
  Files         100      100              
  Lines       16874    16858      -16     
==========================================
- Hits        15503    15361     -142     
- Misses       1371     1497     +126     
Files Coverage Δ
pymc/pytensorf.py 91.46% <100.00%> (+0.16%) ⬆️
pymc/variational/approximations.py 80.09% <100.00%> (-10.41%) ⬇️

... and 12 files with indirect coverage changes

pymc/sampling/jax.py (outdated review thread, resolved)
@@ -47,6 +46,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.scalar.basic import Cast
from pytensor.scalar.basic import identity as scalar_identity
Member

You don't need to create a new Elemwise, there's already one defined in tensor.math (or basic), just called tensor_copy
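For context, PyTensor's JAX transpilation works by registering a per-Op conversion function via single dispatch, which is the mechanism this PR extends for the identity Op. A self-contained toy sketch of that idea (the class and function names here are illustrative stand-ins, not the real PyTensor API):

```python
from functools import singledispatch

# Toy model of PyTensor's per-Op JAX dispatch mechanism.  Class and
# function names are illustrative stand-ins, not the real PyTensor API.
class Op:
    pass

class Identity(Op):
    pass

@singledispatch
def jax_funcify(op):
    # Fallback: no JAX conversion registered for this Op type.
    raise NotImplementedError(f"No JAX conversion for {type(op).__name__}")

@jax_funcify.register(Identity)
def _(op):
    # Identity (like tensor_copy) lowers to a function returning its input.
    return lambda x: x

fn = jax_funcify(Identity())
print(fn(42))  # -> 42
```

Reusing the existing `tensor_copy` Op (as suggested) means the conversion already registered for it applies, instead of a new `Elemwise` instance needing its own registration.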

@ricardoV94 ricardoV94 changed the title from "VI: add dispatch for identity Op, use static shapes for parameters" to "Make VI compatible with JAX backend" on Jan 16, 2024
@@ -387,7 +386,7 @@ def hessian_diag(f, vars=None):
return empty_gradient


identity = Elemwise(scalar_identity, name="identity")
identity = tensor_copy
Member

Nitpick: just import it directly in the VI module, no need to define it in pytensorf?

Member Author

It might be used by someone else I assume

Member

I don't think so, but even if we keep it we should add a deprecation warning

@ferrine
Member Author

ferrine commented Jan 22, 2024

The Windows tests seem very weird and I can't reproduce them on a Linux machine — is shape inference platform dependent?

@ricardoV94
Member

Windows behaves differently with regard to integers. The default integer type is int32, which sometimes causes problems when a rewrite or check doesn't expect that (shapes in PyTensor are supposed to be int64).

Just a guess from previous experiences. I can have a look on my windows machine next week
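The platform difference described above can be sidestepped by casting shape values explicitly. A minimal NumPy illustration (`as_shape` is a hypothetical helper for demonstration, not PyMC code):

```python
import numpy as np

# Illustrative guard, not PyMC code: on Windows the default NumPy
# integer type has historically been int32 (the C "long"), while Linux
# and macOS default to int64.  PyTensor expects shapes as int64, so
# casting explicitly makes shape handling platform independent.
def as_shape(dims):
    return np.asarray(dims, dtype=np.int64)

print(as_shape([3, 4]).dtype)  # int64 on every platform
```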

@ferrine
Member Author

ferrine commented Mar 7, 2024

I see one of the issues got resolved with the sort Op recently. Any updates for Windows?

@ricardoV94
Member

> Any updates for Windows?

I don't think anyone investigated the problem yet

@ferrine
Member Author

ferrine commented Mar 9, 2024

How about marking these tests as xfail then?

@ricardoV94
Member

> How about marking these tests as xfail then?

Let me or someone investigate on a Windows machine. Seems like an important failure on Windows. In the meantime you can rebase and pin PyMC to the next PyTensor version to see if the current xfail can be removed?

@ferrine
Member Author

ferrine commented Mar 17, 2024

@ricardoV94 I updated the dependency on PyTensor and commented on one of the xfails in the tests. Hope the Windows tests get resolved with the newer PyTensor.

@ferrine
Member Author

ferrine commented Mar 17, 2024

In addition, mypy started to complain about pytensor:

[pymc/sampling/forward.py]
pymc/sampling/forward.py:201: error: No overload variant of "general_toposort" matches argument types "list[Variable[Any, Any]]", "Callable[[Any], Any]"
pymc/sampling/forward.py:201: note: Possible overload variants:
pymc/sampling/forward.py:201: note:     def [T <: Node] general_toposort(outputs: Iterable[T], deps: None, compute_deps_cache: Callable[[T], Union[OrderedSet, list[T], None]], deps_cache: Optional[dict[T, list[T]]], clients: Optional[dict[T, list[T]]]) -> list[T]
pymc/sampling/forward.py:201: note:     def [T <: Node] general_toposort(outputs: Iterable[T], deps: Callable[[T], Union[OrderedSet, list[T]]], compute_deps_cache: None, deps_cache: None, clients: Optional[dict[T, list[T]]]) -> list[T]

@ricardoV94
Member

@ferrine feel free to rebase, we have already bumped the dependency on main

def test_vi_sampling_jax(method):
    with pm.Model() as model:
        x = pm.Normal("x")
        pm.fit(10, method=method, fn_kwargs=dict(mode="JAX"))
Member

To be consistent with pm.sample and the nuts_sampler= arg, should we have a dedicated argument for the VI backend instead of kwargs?

Member

I vote yes, this API looks super weird.

Member

What looks weird? This is the compilation mode, would be exactly the same if you wanted to use Numba or JAX for the PyMC nuts sampler or for prior/posterior predictive.

The only thing I would change is the name of fn_kwargs, which is called compile_kwargs I think in those other functions

Member

Wouldn't this be what the user would have to do if they wanted to run VI on JAX?

Member

I don't understand the question, this PR is just doing minor tweaks so the PyMC VI module can compile to JAX. It's not linking to specific JAX VI libraries.

Member

> We used this for sample_posterior_predictive for projects just last week, as we were sampling new variables that had heavy matmuls, went down from hours to minutes.

Great idea, should definitely add it there too.

> pm.sample is still useful as you can sample discrete variables with JAX this way.

That makes sense, I'm not opposed to adding it there. Maybe we can add a warning that the sampler is still running Python and they likely will want to use nuts_sampler.
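The warning suggested above might look something like this sketch (the function name and sampler strings are illustrative, not PyMC's actual API):

```python
import warnings

def maybe_warn_python_loop(compile_mode, nuts_sampler="pymc"):
    # Sketch of the warning proposed in this thread; NOT PyMC source
    # code, and the argument names/values are illustrative only.
    if compile_mode == "JAX" and nuts_sampler == "pymc":
        warnings.warn(
            "The model is compiled to JAX, but the sampling loop still "
            "runs in Python; a dedicated JAX sampler may be much faster.",
            UserWarning,
        )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    maybe_warn_python_loop("JAX")
print(len(caught))  # -> 1 warning emitted
```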

Member

@ricardoV94 ricardoV94 Apr 1, 2024

This is still doing python loops, it's exactly the same argument you need for pm.sample.

It's different than linking to a JAX VI library, which is what would be equivalent to the nuts_sampler kwarg that Chris mentioned in the first comment

Member

> This is still doing python loops, it's exactly the same argument you need for pm.sample.

Oh, I somehow assumed that VI was implemented mostly in PyTensor?

Member Author

As for this, I'd prefer to keep this PR focused on backend compatibility and address possible API changes later in a new issue + PR. Agreed that there is an inconsistency we need to resolve, but doing it here would only delay merging to main a working solution that has already gone through many issues.

Member

Agreed @ferrine. My only suggestion is to switch fn_kwargs to compile_kwargs which we use in the other sample methods

@codecov-commenter

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 91.11%. Comparing base (0ad689c) to head (30a2d73).

❗ Current head 30a2d73 differs from pull request most recent head 994da6c. Consider uploading reports for the commit 994da6c to get more accurate results

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7103      +/-   ##
==========================================
- Coverage   92.34%   91.11%   -1.23%     
==========================================
  Files         102      100       -2     
  Lines       17032    16858     -174     
==========================================
- Hits        15728    15361     -367     
- Misses       1304     1497     +193     
Files Coverage Δ
pymc/pytensorf.py 91.46% <100.00%> (+0.23%) ⬆️
pymc/variational/approximations.py 80.09% <100.00%> (-10.78%) ⬇️

... and 60 files with indirect coverage changes

@ferrine
Member Author

ferrine commented May 1, 2024

> @ferrine feel free to rebase, we have already bumped the dependency on main

Just rebased, let's see how it goes

Labels
jax VI Variational Inference
Development

Successfully merging this pull request may close these issues.

BUG: VI can't be used with Jax
5 participants