
Add KL Divergence helper #7062

Status: Open · wants to merge 8 commits into main

Conversation

@ferrine (Member) commented Dec 11, 2023

Description

Related Issue

  • Closes issue: #
  • Related issue (not closed by this PR): #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7062.org.readthedocs.build/en/7062/

codecov bot commented Dec 11, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (94020c9) 90.17% compared to head (e3a6c3e) 90.18%.
Report is 7 commits behind head on main.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #7062      +/-   ##
==========================================
+ Coverage   90.17%   90.18%   +0.01%     
==========================================
  Files         101      103       +2     
  Lines       16932    16952      +20     
==========================================
+ Hits        15269    15289      +20     
  Misses       1663     1663              
| Files | Coverage | Δ |
|---|---|---|
| pymc/distributions/__init__.py | 100.00% <100.00%> | (ø) |
| pymc/distributions/stats/__init__.py | 100.00% <100.00%> | (ø) |
| pymc/distributions/stats/kl_divergence.py | 100.00% <100.00%> | (ø) |
| pymc/logprob/__init__.py | 100.00% <ø> | (ø) |
| pymc/logprob/abstract.py | 96.07% <100.00%> | (+0.62%) ⬆️ |
| pymc/logprob/basic.py | 94.48% <100.00%> | (+0.07%) ⬆️ |

@lucianopaz (Contributor) left a comment

This looks very good @ferrine! One thing that you definitely need to do before you can merge this is to add a page about this to the documentation. Maybe make a folder for logprob and include a subfolder for the KL divergence, since it will start to grow as more divergences get added.

pymc/logprob/basic.py
    q_inputs: List[TensorVariable],
    p_inputs: List[TensorVariable],
):
    _, _, _, q_mu, q_sigma = q_inputs

Should probably consider size like moment does?

@ferrine (Member, Author) commented Dec 12, 2023

> This looks very good @ferrine! One thing that you definitely need to do before you can merge this is to add a page about this to the documentation. Maybe make a folder for logprob and include a subfolder for the KL divergence, since it will start to grow as more divergences get added.

I found that our pm.Distribution classes like pm.Normal can pass isinstance(rv.owner.op, pm.Normal) checks. I wonder if I can rely on that functionality further.
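
For illustration, relying on the observation above rather than on any documented guarantee, the check in question looks like this:

```python
# Quick illustration of the isinstance check described above. Whether this is a
# supported, stable pattern is exactly what is being asked here.
import pymc as pm

rv = pm.Normal.dist(mu=0.0, sigma=1.0)
print(isinstance(rv.owner.op, pm.Normal))  # True, per the observation above
```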

@ferrine (Member, Author) commented Dec 12, 2023

I am also confused about the design choice, since the moment implementation is scattered across the codebase...

@ricardoV94 (Member) commented Dec 12, 2023

The base functionality should be in logprob, but the specific implementations should be in Distributions

> I am also confused about the design choice, since the moment implementation is scattered across the codebase...

moment was implemented at a different time for a different purpose. It's not even really moment but finite_logp_point. It's never used for logprob stuff, just samplers.

> I found that our pm.Distribution classes like pm.Normal can pass isinstance(rv.owner.op, pm.Normal) checks. I wonder if I can rely on that functionality further.

You can, but it won't work for things that look like Distributions but are just helpers to create distributions, like pm.OrderedLogistic, pm.ZeroInflatedBinomial, or LKJCholeskyCov. What should always be safe is the rv_op that's actually returned.
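
A hypothetical sketch (not the PR's actual code) of what dispatching on the op of the returned variable could look like; the names _KL_REGISTRY, register_kl, and kl_div are placeholders, and the Normal–Normal entry is left as a stub here:

```python
# Dispatch on the op type of the returned RV (the "rv_op that's actually
# returned"), rather than on the pm.Distribution class used to spell it.
from pytensor.tensor.random.basic import NormalRV

_KL_REGISTRY = {}  # (q op type, p op type) -> KL implementation


def register_kl(q_op_type, p_op_type):
    def decorator(func):
        _KL_REGISTRY[(q_op_type, p_op_type)] = func
        return func

    return decorator


def kl_div(q_rv, p_rv):
    # Helpers such as pm.OrderedLogistic or pm.ZeroInflatedBinomial are not ops
    # themselves, but the variables they return still carry a concrete op, so
    # dispatching on that op type is safe.
    key = (type(q_rv.owner.op), type(p_rv.owner.op))
    return _KL_REGISTRY[key](q_rv, p_rv)


@register_kl(NormalRV, NormalRV)
def _kl_normal_normal(q_rv, p_rv):
    ...  # analytic Normal-Normal KL; see the sketch further down the thread
```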

@ferrine (Member, Author) commented Dec 15, 2023

Moved kl into a private pymc.distributions._stats because these are functions that will never be used by anyone directly.

@ricardoV94 (Member) commented Dec 15, 2023

I don't like the underscore; why not just distributions/kl_div?

@ferrine (Member, Author) commented Jan 23, 2024

@ricardoV94 can you please revisit the review? Did I miss something?

@ricardoV94 (Member) commented

Tests are failing with an import issue.

@ricardoV94 (Member) commented

How many pairs of distributions do we expect to actually be able to support?

@ferrine (Member, Author) commented Feb 3, 2024

> How many pairs of distributions do we expect to actually be able to support?

Many of them; see https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/kl_divergence

_, _, _, q_mu, q_sigma = q_inputs
_, _, _, p_mu, p_sigma = p_inputs
diff_log_scale = pt.log(q_sigma) - pt.log(p_sigma)
return (
@ricardoV94 (Member) commented Feb 6, 2024

May want to broadcast to size, like we do with moment, if someone does kl_div(pm.Normal.dist(shape=5), pm.Normal.dist(mu=1))
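
A hedged sketch of what that could look like for the Normal–Normal case: the function name is a placeholder, the parameter unpacking mirrors the excerpt above, and broadcasting the result against both input RVs is one possible way (not necessarily the PR's) to honour an explicit shape/size:

```python
import pytensor.tensor as pt
from pytensor.tensor.extra_ops import broadcast_arrays


def _kl_normal_normal(q_rv, p_rv):
    # Inputs follow the excerpt above: [rng, size, dtype, mu, sigma].
    _, _, _, q_mu, q_sigma = q_rv.owner.inputs
    _, _, _, p_mu, p_sigma = p_rv.owner.inputs
    # Closed-form KL(q || p) for univariate normals:
    # log(sigma_p / sigma_q) + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2) - 1/2
    kl = (
        pt.log(p_sigma)
        - pt.log(q_sigma)
        + (q_sigma**2 + (q_mu - p_mu) ** 2) / (2 * p_sigma**2)
        - 0.5
    )
    # Broadcast against both RVs so an explicit shape/size on either side shows
    # up in the output: kl_div(pm.Normal.dist(shape=5), pm.Normal.dist(mu=1))
    # would then yield a length-5 result, similar to how moment broadcasts.
    return broadcast_arrays(kl, q_rv, p_rv)[0]
```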

@ricardoV94 changed the title from "add KL Divergence helper" to "Add KL Divergence helper" on Feb 6, 2024