How to interface with Turing.jl @model definitions? #2138

Open
arthurmloureiro opened this issue Nov 27, 2023 · 2 comments

@arthurmloureiro

Hi!

We have been trying to build an AdvancedHMC.jl sampler for which we can supply the analytical gradient of a model built in Turing.jl with the @model macro. It is unclear from the documentation how to go from a @model to a log-probability function that can be interfaced with AdvancedHMC.jl.

What I mean is: if I have a model like

@model function gdemo(x, y)
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))
    x ~ Normal(m, sqrt(s²))
    y ~ Normal(m, sqrt(s²))
end

for example, and I want to construct an AdvancedHMC.jl sampler, which requires something like ℓπ = LogTargetDensity(D): is there a function that interfaces between the generative model of Turing.jl and a log target density?

Thanks!

@torfjelde
Member

This is more related to Turing.jl than AdvancedHMC.jl, I think, so I'll transfer the issue.

@torfjelde transferred this issue from TuringLang/AdvancedHMC.jl on Dec 3, 2023
@torfjelde
Member

But regarding the question: this is possible, but a bit involved.

What you're looking for is: https://turinglang.org/DynamicPPL.jl/dev/api/#LogDensityProblems.jl-interface

(DynamicPPL is automatically available if you've done using Turing)

Taking the example from that docstring:

julia> model = demo(1.0);

julia> f = LogDensityFunction(model);

julia> # It implements the interface of LogDensityProblems.jl.
       using LogDensityProblems

julia> LogDensityProblems.logdensity(f, [0.0])
-2.3378770664093453

julia> LogDensityProblems.dimension(f)
1

You'll then have to overload LogDensityProblems.logdensity_and_gradient for the type of this particular f, so that it returns the log density together with your analytical gradient:

function LogDensityProblems.logdensity_and_gradient(::typeof(f), x::AbstractVector{<:Real})
    # ...
end

# Indicates that 1st order information is available.
LogDensityProblems.capabilities(::typeof(f)) = LogDensityProblems.LogDensityOrder{1}()

then you can just pass f to AdvancedHMC.jl and it should use your implementation of logdensity_and_gradient 👍
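To make the `# ...` body above more concrete, here is a minimal sketch of what an analytical `logdensity_and_gradient` could look like for the gdemo model from the question. It is not the Turing-generated `f`: `GDemoTarget` and `sumsq` are hypothetical names made up for illustration, and the log joint is written out by hand. It assumes the parameter vector is ordered `[s², m]` and lives in the original (constrained) space, with `InverseGamma(2, 3)` read as shape 2 and scale 3.

```julia
using LogDensityProblems

# Hypothetical hand-written target for the gdemo model in the question.
# x and y are the observed data; θ = [s², m] in the constrained space (s² > 0).
struct GDemoTarget
    x::Float64
    y::Float64
end

# Sum of squared residuals of the three Normal terms (m, x and y).
sumsq(t::GDemoTarget, m) = m^2 + (t.x - m)^2 + (t.y - m)^2

LogDensityProblems.dimension(::GDemoTarget) = 2

# Declare that first-order (gradient) information is available.
LogDensityProblems.capabilities(::Type{GDemoTarget}) =
    LogDensityProblems.LogDensityOrder{1}()

# Log joint: InverseGamma(2, 3) prior on s² plus three Normal(·, sqrt(s²)) terms.
function LogDensityProblems.logdensity(t::GDemoTarget, θ::AbstractVector{<:Real})
    s², m = θ
    return 2 * log(3) - 1.5 * log(2π) - 4.5 * log(s²) - 3 / s² - sumsq(t, m) / (2 * s²)
end

# Returns the log density together with its hand-derived gradient.
function LogDensityProblems.logdensity_and_gradient(t::GDemoTarget, θ::AbstractVector{<:Real})
    s², m = θ
    grad_s² = -4.5 / s² + (3 + sumsq(t, m) / 2) / s²^2
    grad_m = (t.x + t.y - 3 * m) / s²
    return LogDensityProblems.logdensity(t, θ), [grad_s², grad_m]
end
```

If you overload the method for `typeof(f)` instead, the gradient has to match whatever parameter ordering and parameterisation `f` actually uses. HMC is usually run on an unconstrained space (for instance `log(s²)` with the matching Jacobian term), and in that case the formulas above would need adjusting accordingly.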
