Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import and extend PosteriorStats #431

Open
wants to merge 60 commits into
base: master
Choose a base branch
from
Open

Import and extend PosteriorStats #431

wants to merge 60 commits into from

Conversation

sethaxen
Copy link
Member

@sethaxen sethaxen commented Aug 20, 2023

This PR makes the following replacements everywhere:

  • hpd -> PosteriorStats.hdi (hpd retains its old behavior and is now deprecated)
  • summarize -> PosteriorStats.summarize
  • ChainDataFrame -> PosteriorStats.SummaryStats

The replacement of ChainDataFrame and the slight change in API and behavior of the methods makes this a breaking change.

Implements #430

e.g.

julia> val = rand(500, 2, 3);

julia> chn
Chains MCMC chain (500×2×3 Array{Float64, 3}):

Iterations        = 1:1:500
Number of chains  = 3
Samples per chain = 500
parameters        = param_1, param_2


Summary Statistics
          mean    std    hdi_3%  hdi_97%  mcse_mean  mcse_std  ess_tail  ess_bulk  rhat 
 param_1  0.50  0.283  0.00745     0.929     0.0073    0.0034      1480      1531  1.00
 param_2  0.49  0.285  0.000140    0.926     0.0074    0.0034      1499      1506  1.00

Quantiles
            2.5%  25.0%  50.0%  75.0%  97.5% 
 param_1  0.0252  0.254  0.510  0.745  0.969
 param_2  0.0180  0.242  0.483  0.727  0.969

julia> hdi(chn)
HDI
             lower  upper 
 param_1  0.00745   0.929
 param_2  0.000140  0.926

julia> describe(chn)
2-element Vector{SummaryStats}:
 Summary Statistics (2 rows, 10 cols)
 Quantiles (2 rows, 6 cols)

# Compute the change rates.
changerates, mvchangerate = changerate(chains)

# Summarize the results in a named tuple.
nt = (; zip(names_of_params, changerates)..., multivariate = mvchangerate)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lacking a parameter column meant the show method was broken. But since there is a changerate for every parameter, it makes more sense to do the same thing as gelmandiag_multivariate and return a SummaryStats for the marginal values and return the multivariate changerate separately.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Genuine question: what is the change-rate in this context?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From inspecting the code, it's, for each parameter and chain, the fraction of draws that are different from the previous draw. I suppose it's similar to "acceptance rate."

Copy link
Member

@torfjelde torfjelde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm very excited about this 👀

I have a few minor comments, but will approve as it looks good:)

src/plot.jl Outdated Show resolved Hide resolved
src/stats.jl Outdated Show resolved Hide resolved
@devmotion
Copy link
Member

It seemed you wanted to push a few additional updates @sethaxen? Or is this PR ready from your side?

@sethaxen
Copy link
Member Author

Yeah it's pending a breaking changes to PosteriorStats I need to finish up, to avoid another breaking release here.

@sethaxen
Copy link
Member Author

sethaxen commented Dec 24, 2023

@devmotion @torfjelde This is ready for final review.

@sethaxen
Copy link
Member Author

sethaxen commented Feb 1, 2024

Currently PosteriorStats has no try/catch mechanism for if a given statistic fails. That causes chains with less than 10 draws to error upon computation of summary stats (or display). For PosteriorStats that makes sense, but it's a major nuisance for MCMCChains. I actually think this has come up enough that it makes sense for MCMCDiagnosticTools to raise an informative warning and return NaNs in such cases. Will open an issue there.

@devmotion
Copy link
Member

Oh, I missed #431 (comment), probably due to the Christmas break. I added the PR to my todo list 🙂

src/MCMCChains.jl Outdated Show resolved Hide resolved
src/chains.jl Show resolved Hide resolved
@@ -64,7 +64,7 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
lags = 0:(maxlag === nothing ? round(Int, 10 * log10(length(range(c)))) : maxlag)
# Chains are already appended in `c` if desired, hence we use `append_chains=false`
ac = autocor(c; sections = nothing, lags = lags, append_chains=false)
ac_mat = convert(Array, ac)
ac_mat = stack(map(stack ∘ Base.Fix2(Iterators.drop, 1), ac))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a bit inefficient (particular the use of stack(map(stack) - maybe we should preallocate the desired array and copy directly? Possibly a utility function could be added to PosteriorStats?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may not be the most efficient, but compared to the actual autocor call, it should not be the computational bottleneck, and not using stack requires substantially more code.

It might make sense to add to PosteriorStats a function that converts a SummaryStats to a matrix, but this would effectively do the same thing as the inner stack. What kind of utility function do you have in mind?

Copy link
Member

@devmotion devmotion Feb 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What kind of utility function do you have in mind?

Basically a convert, Array/Matrix constructor, or other function that performs the same thing (if the first two alternatives are not desirable). So that downstream packages (such as MCMCChains) don't have to deal with internals of SummaryStats.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is not using any internals of SummaryStats, whose docstring explains it is an OrderedDict-like Table that can be iterated/indexed by its column names.

I think the methods we're missing in PosteriorStats are

function Base.getindex(stats::SummaryStats, cols::Union{Colon,AbstractVector{Int},AbstractVector{Symbol}})
    cols isa Colon && return Tables.matrix(stats)
    return stack(Tables.getcolumn(stats, k) for k in cols)
end
Base.firstindex(s::SummaryStats) = 1
Base.lastindex(s::SummaryStats) = length(s)

Then this would be

stack(ac_i[2:end] for ac_i in ac)

but this is no more efficient.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is not using any internals of SummaryStats, whose docstring explains it is an OrderedDict-like Table that can be iterated/indexed by its column names.

I was referring to the 2:end which seems a bit special. I missed that the docstring mentions that the first column is reserved for parameter names.

src/plot.jl Outdated
ordered = false
)

chain_dic = Dict(zip(quantile(chains)[:,1], quantile(chains)[:,4]))
chain_dic = Dict(zip(quantile(chains)[2], quantile(chains)[5]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a good opportunity to make the code a bit more efficient:

Suggested change
chain_dic = Dict(zip(quantile(chains)[2], quantile(chains)[5]))
quantile_chains = quantile(chains)
chain_dic = Dict(zip(quantile_chains[2], quantile_chains[5]))

(it would also be nice if it would be possible to use something more descriptive than 2 and 5)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, yeah this and the code immediately below are a bit of a mess. Will rework to use the new methods.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it was able to be simplified a lot. But also, for n parameters, it was computing the median n^2 times. This has now been fixed.

Copy link
Member

@torfjelde torfjelde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a few comments!

Am very excited about this though; bloody great stuff @sethaxen ❤️

src/MCMCChains.jl Outdated Show resolved Hide resolved
src/discretediag.jl Outdated Show resolved Hide resolved

Return the highest posterior density interval representing `1-alpha` probability mass.
Return the unimodal highest density interval (HDI) representing `prob` probability mass.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the meaning of "unimodal" here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default estimator used by hdi assumes that the distribution is unimodal. The alternative (not yet in PosteriorStats but doable with HighestDensityRegions.jl) for multimodal distributions first fits a KDE and then partitions into one or more intervals by density. For fast summary statistics, the unimodal version is more useful.

# Compute the change rates.
changerates, mvchangerate = changerate(chains)

# Summarize the results in a named tuple.
nt = (; zip(names_of_params, changerates)..., multivariate = mvchangerate)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Genuine question: what is the change-rate in this context?

src/stats.jl Outdated Show resolved Hide resolved
src/stats.jl Outdated Show resolved Hide resolved
src/stats.jl Outdated Show resolved Hide resolved
@sethaxen
Copy link
Member Author

This is ready for another review. Also, it would be nice to get input on arviz-devs/PosteriorStats.jl#25, since that's relevant for this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants