TDist - cdf and quantile function - Auto-Differentiability #152

Open
paschermayr opened this issue Jan 25, 2023 · 4 comments

Comments

@paschermayr

Hi there,

Thank you for all your work! I have seen that recently, pull request #147 was closed in favor of #149.

I believe the former pull request tried to make the cdf and quantile functions of TDist auto-differentiable, but it was closed in favor of the latter, which did not resolve the problem. MWE from a fresh project with all up-to-date libraries:

using Distributions, DistributionsAD, StatsBase
using ForwardDiff, ReverseDiff
using StatsFuns

function mytargetfunction(data::AbstractVector)
    function obtaingradient(θ::AbstractVector{R}) where {R<:Real}
        nu = θ[1]
        distr = TDist(nu)
        data_uniform = [cdf(distr, data[iter]) for iter in eachindex(data)]
        data_real = [quantile(distr, data_uniform[iter]) for iter in eachindex(data_uniform)]
        return sum( logpdf(distr, data_real[iter]) for iter in eachindex(data_real) )
    end
end

#working
ν = [3.0]
data = randn(1000)
target = mytargetfunction(data)
target(ν)
#not working
ForwardDiff.gradient(target, ν) #MethodError: no method matching _beta_inc(::ForwardDiff.Dual
ReverseDiff.gradient(target, ν) #MethodError: no method matching _beta_inc(::ReverseDiff.TrackedReal

It seems like the beta_inc function comes from the SpecialFunctions.jl package and requires Float64 arguments instead of general Real types. Is there a reason for that? I believe I should probably open an issue there as well?

@andreasnoack
Member

We should probably have a general derivative rule for cdf defined somewhere. @devmotion any thoughts?

@devmotion
Member

Hmm, for ChainRules-compatible AD systems we can add the missing rules just in StatsFuns (or, of course, in SpecialFunctions directly if the rule is missing there). I think we might want to make ChainRulesCore a weak dependency in the future anyway on Julia >= 1.9, and then the number of definitions should not matter for loading and compilation times if users do not use ChainRules.

We could also add definitions for ForwardDiff and ReverseDiff by making them weak dependencies. That could fix the issue at least on Julia >= 1.9. Maybe even better would be to make DiffRules a weak dependency, since ForwardDiff and ReverseDiff use it, rather than ChainRules, to define rules automatically. There are approaches to bridge them with ChainRules (https://github.com/ThummeTo/ForwardDiffChainRules.jl and https://juliadiff.org/ReverseDiff.jl/dev/api/#ChainRules-integration), but they would be type piracy here. However, the current design of DiffRules does not allow new rules to be added reliably in other packages: they might not be picked up by e.g. ForwardDiff and ReverseDiff, since those packages define their differentiation rules only once, when they are loaded, based on the rules available at that point. @KristofferC was looking into some of the issues with the current design of DiffRules: JuliaDiff/DiffRules.jl#90

So at least on Julia >= 1.9, maybe the best short-term solution would be to add weak dependencies on ReverseDiff and ForwardDiff, and define rules for them explicitly. And to add missing ChainRules definitions.

(As a side remark, #147 also used beta_inc and beta_inc_inv, so - without testing it - I would assume the same AD issues as reported above would show up there as well.)

@andreasnoack
Member

My point is that the derivative of x -> cdf(..., x) is readily available. However, supporting all the partial derivatives of betacdf will require more work.
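
A quick numerical check of the first point, as a sketch only: the derivative of x -> cdf(d, x) is pdf(d, x), so a central finite difference of the cdf should match the pdf. The evaluation point x, the step size h, and the tolerance are arbitrary choices for illustration and do not come from this thread.

```julia
using Distributions

# TDist with the same degrees of freedom as in the MWE above
d = TDist(3.0)

# Illustrative evaluation point and step size (assumptions, not from the thread)
x = 0.7
h = 1e-5

# Central finite-difference approximation of d/dx cdf(d, x)
fd = (cdf(d, x + h) - cdf(d, x - h)) / (2h)

# The analytic derivative with respect to x is the density
analytic = pdf(d, x)

println("finite difference: ", fd)
println("pdf:               ", analytic)
```

This only covers the derivative with respect to the argument x. It says nothing about the derivative with respect to the degrees of freedom, which is the harder part.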

@devmotion
Member

Yes, that's why I assumed you might want to add rules to StatsFuns (or, e.g., Distributions) instead of SpecialFunctions (even though in the example above you would need rules for theta -> cdf(dist(theta), x) as well). But to me it seemed that in cases where the derivatives are readily available the question is still how to deal with AD systems such as ForwardDiff and ReverseDiff that only support DiffRules but not ChainRules.
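
For reference, the partial derivatives that rules for the example above would need can be written with standard identities. The notation is introduced here for illustration only: F(x; ν) and f(x; ν) denote the cdf and pdf of TDist(ν), and Q(u; ν) denotes the quantile function.

$$
\frac{\partial F}{\partial x}(x;\nu) = f(x;\nu), \qquad
\frac{\partial Q}{\partial u}(u;\nu) = \frac{1}{f\bigl(Q(u;\nu);\nu\bigr)}, \qquad
\frac{\partial Q}{\partial \nu}(u;\nu) = -\left.\frac{\partial_\nu F(q;\nu)}{f(q;\nu)}\right|_{q = Q(u;\nu)}
$$

The last two follow from differentiating the identity F(Q(u; ν); ν) = u with respect to u and ν. Only the derivative of F with respect to ν has no such shortcut, and that is the term that would need parameter derivatives of the underlying incomplete beta function.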
