Calculating acceptance rate #409

Open
scheidan opened this issue Mar 6, 2023 · 5 comments

scheidan commented Mar 6, 2023

I often find myself missing a function that computes the acceptance rate of a chain.

I'm happy to put a PR together if you feel that would be a useful addition.

Some points we should think about:

  • How to deal with multiple chains? I'd say that we should compute the acceptance rate for every chain separately.
  • How to do this fast? Usually I use something like size(unique(X, dims=1), 1), but that's probably slow. Maybe StatsBase.countmap is better?
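The counting idea from the second bullet can be sketched in Base Julia as below, assuming one chain is an `n × d` matrix with one row per iteration. The names `unique_rows_rate` and `jump_rate` are hypothetical helpers, not MCMCChains API. The small example shows a weakness of counting distinct rows: when the chain returns to a state it visited before, that return move is not counted.

```julia
# Hypothetical helper: estimate the acceptance rate from the number of
# distinct rows (draws) in an n × d matrix, as suggested in the bullet above.
function unique_rows_rate(x::AbstractMatrix)
    n_distinct = size(unique(x, dims=1), 1)
    return (n_distinct - 1) / (size(x, 1) - 1)
end

# Reference: fraction of iterations in which at least one parameter changed.
function jump_rate(x::AbstractMatrix)
    n = size(x, 1)
    n_jumps = count(i -> any(x[i, :] .!= x[i-1, :]), 2:n)
    return n_jumps / (n - 1)
end

# A 2-parameter chain that leaves the state (1, 2) and later returns to it.
x = [1.0 2.0;
     1.0 2.0;
     3.0 4.0;
     1.0 2.0]

unique_rows_rate(x)  # (2 - 1) / 3 ≈ 0.333: the return to (1, 2) is missed
jump_rate(x)         # 2 / 3 ≈ 0.667
```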

cpfiffer (Member) commented Mar 6, 2023

I suppose I could be okay with trying this out. Could you try benchmarking these on some big-ass matrices? I.e., show us compute times for countmap and size(unique(...)) for different matrix sizes.
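A minimal benchmark sketch for this request, using Base Julia only: it stands in for `StatsBase.countmap` with a hypothetical Dict-based `row_counts`, so it approximates rather than reproduces the countmap timing. The matrix sizes are arbitrary, and the printed times depend on the machine, so no numbers are claimed here. For real measurements, BenchmarkTools' `@btime` (as used later in this thread) is more reliable than a single `@elapsed`.

```julia
# Hypothetical stand-in for StatsBase.countmap over rows: count how often
# each distinct row occurs, using a Dict keyed by copies of the rows.
function row_counts(x::AbstractMatrix{T}) where {T}
    counts = Dict{Vector{T},Int}()
    for row in eachrow(x)
        key = collect(row)
        counts[key] = get(counts, key, 0) + 1
    end
    return counts
end

# Warm up once so that compilation time is not included in the timings.
unique(randn(10, 2), dims=1); row_counts(randn(10, 2))

# Time both ways of counting distinct rows for a few matrix sizes.
for (n, d) in ((1_000, 10), (10_000, 10), (10_000, 100))
    x = randn(n, d)
    t_unique = @elapsed unique(x, dims=1)
    t_dict = @elapsed row_counts(x)
    @assert size(unique(x, dims=1), 1) == length(row_counts(x))  # same count of distinct rows
    println("n = $n, d = $d: unique $(round(t_unique; digits=4)) s, Dict $(round(t_dict; digits=4)) s")
end
```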

I will say, though, that in general the real way to do this is to compute acceptance rates from an isaccept flag (or similar) produced by the sampler, which IIRC isn't available for a lot of the samplers that talk to MCMCChains. What you proposed is probably slow, but maybe a good first approximation of an acceptance rate.
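When the sampler itself records whether each proposal was accepted (the isaccept-style flag mentioned above), the acceptance rate is simply the fraction of `true` flags. A minimal sketch, assuming a hypothetical 1-D random-walk Metropolis sampler `rw_metropolis` (not part of any TuringLang package) that returns such flags alongside the draws, and comparing the exact rate with a rate inferred from the draws alone:

```julia
using Random

# Hypothetical random-walk Metropolis sampler for a 1-D standard normal target.
# Besides the draws it records, per iteration, whether the proposal was accepted.
function rw_metropolis(rng::AbstractRNG, n::Int; step::Float64=2.5)
    logp(z) = -z^2 / 2                  # unnormalized log density
    draws = zeros(n)
    is_accept = falses(n)               # iteration 1 is the initial state
    x = 0.0
    draws[1] = x
    for i in 2:n
        y = x + step * randn(rng)       # symmetric Gaussian proposal
        if log(rand(rng)) < logp(y) - logp(x)
            x = y
            is_accept[i] = true
        end
        draws[i] = x
    end
    return draws, is_accept
end

draws, is_accept = rw_metropolis(MersenneTwister(1), 10_000)
n_steps = length(draws) - 1

# Exact rate from the sampler's own flags.
exact_rate = count(is_accept) / n_steps

# Heuristic rate from the draws alone: fraction of steps where the value changed.
heuristic_rate = count(i -> draws[i] != draws[i-1], 2:length(draws)) / n_steps
```

With a continuous target and proposal the two numbers agree almost surely, because an accepted proposal essentially never equals the current value; with discrete parameters they can differ.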

scheidan (Author) commented Mar 7, 2023

Thanks! Thinking about it, we do not need (or want) to use unique or countmap. Besides being much slower, they would also give wrong results whenever a chain revisits an earlier state, e.g. with discrete parameters.

The approach below should be correct: for every iteration, it checks whether at least one parameter has changed, i.e., whether a jump has happened. It should also be reasonably efficient and does not allocate (we could use threading, but I'm not sure it is worth it here).

If that looks reasonable, I will work on a PR in the next few days.

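using MCMCChains, BenchmarkTools

# acceptance rate of a single chain (x: iterations × parameters),
# i.e. the fraction of iterations in which at least one parameter changed
function _acceptance_rate(x)
    n_jumps = 0
    for i in axes(x, 1)[2:end]
        for j in axes(x, 2)
            if x[i,j] != x[i-1,j]
                n_jumps += 1
                break   # one changed parameter is enough for this iteration
            end
        end
    end
    n_jumps / (size(x, 1)-1)
end

# one acceptance rate per chain
function acceptance_rate(chn::Chains)
    nchains = size(chn, 3)
    ac = [_acceptance_rate(@view chn.value[:,:,i]) for i in 1:nchains]
    ac
end

## -------------
## test

n = 10000                       # chain length
d = 200                         # dimension
k = 100                         # number of chains

# first half: i.i.d. draws (every step is a jump); second half: constant (no jumps)
X = cat(randn(n÷2, d, k), ones(n÷2, d, k), dims=1);
chn = Chains(X);

acceptance_rate(chn)             # ≈ 0.5 for every chain (5000 / 9999)
@btime acceptance_rate($(chn))   # ~ 200ms, 1 allocation (900 bytes)

## just for comparison
@btime ess($chn)   # ~ 10 sec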
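On the threading question: chains are independent, so the per-chain loop parallelizes without any synchronization. A sketch on a plain `n × d × k` array (the function names are hypothetical; a `Chains` object would first need its values extracted as a plain array). It only helps when Julia is started with several threads (e.g. `julia --threads=4`), and at ~200 ms for 100 chains it may well not be worth the extra complexity.

```julia
# Fraction of iterations in which at least one parameter changed, for one chain
# stored as an n × d matrix (the same logic as `_acceptance_rate` above).
function jump_fraction(x::AbstractMatrix)
    rows = axes(x, 1)
    n_jumps = 0
    for i in (first(rows) + 1):last(rows)
        for j in axes(x, 2)
            if x[i, j] != x[i-1, j]
                n_jumps += 1
                break                   # one changed parameter is enough
            end
        end
    end
    return n_jumps / (length(rows) - 1)
end

# Hypothetical multi-threaded variant for an n × d × k array (chains along
# dimension 3). Each loop iteration writes only its own entry of `rates`,
# so no locking is needed.
function jump_fractions_threaded(X::AbstractArray{<:Real,3})
    k = size(X, 3)
    rates = Vector{Float64}(undef, k)
    Threads.@threads for c in 1:k
        rates[c] = jump_fraction(view(X, :, :, c))
    end
    return rates
end

X = cat(randn(1_000, 5, 4), ones(1_000, 5, 4), dims=1)
jump_fractions_threaded(X)              # each entry is 1000 / 1999 ≈ 0.5
```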

scheidan (Author) commented Mar 7, 2023

Would a PR to MCMCDiagnosticTools.jl make more sense?

devmotion (Member) commented

To me this heuristic seems to be, well, only a heuristic and a bit too brittle to be added to MCMCDiagnosticTools or MCMCChains. Even if all parameters are unchanged, that does not guarantee that a proposal was rejected, in particular not when working with distributions with finite discrete support. Also, similar to what @cpfiffer said above, I think acceptance rates are a rather algorithm-specific thing and not a concept that can be applied to an arbitrary Markov chain (e.g., elliptical slice sampling does not accept or reject any samples; it just returns a sequence of samples). So I think this should be addressed in e.g. AdvancedMH (TuringLang/AdvancedMH.jl#40, TuringLang/AdvancedMH.jl#38).
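The point about finite discrete support can be made concrete with a small sketch (a hypothetical sampler, not taken from AdvancedMH): an independence Metropolis sampler on {1, 2, 3} whose uniform proposal sometimes proposes the current state. Such a proposal is always accepted, yet the chain does not move, so the jump-counting heuristic treats it as a rejection. Here `exact_rate` comes from the sampler's own accept flags, i.e. the sampler-side approach suggested above.

```julia
using Random

# Hypothetical independence Metropolis sampler for a target p on the finite
# support 1:length(p). The proposal is uniform on that support, so it can
# propose the current state itself.
function discrete_metropolis(rng::AbstractRNG, n::Int; p=[0.2, 0.3, 0.5])
    draws = zeros(Int, n)
    is_accept = falses(n)
    x = 1
    draws[1] = x
    for i in 2:n
        y = rand(rng, eachindex(p))
        # The uniform proposal cancels in the Metropolis-Hastings ratio.
        if rand(rng) < p[y] / p[x]
            x = y
            is_accept[i] = true
        end
        draws[i] = x
    end
    return draws, is_accept
end

draws, is_accept = discrete_metropolis(MersenneTwister(1), 10_000)
n_steps = length(draws) - 1

exact_rate = count(is_accept) / n_steps
heuristic_rate = count(i -> draws[i] != draws[i-1], 2:length(draws)) / n_steps
# Every proposal with y == x is accepted (ratio 1) but looks like a rejection,
# so heuristic_rate < exact_rate.
```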

scheidan (Author) commented Mar 8, 2023

Fair enough, but then, many things are a bit heuristic when working with MCMC :)

I agree the clean solution is for the sampler to return the acceptance rate.
However, MCMCChains also turns out to be very useful for comparing results from various samplers outside the TuringLang domain: the results (all in slightly different formats) are converted into Chains for uniform plotting and diagnostics. In this context, a heuristic acceptance rate (maybe with a warning) would be much better than none.

Of course, it is up to you to define the scope of MCMCChains. I just want to point out that it may be useful in a wider range of applications than you thought.
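The "heuristic with a warning" option could look roughly like this, sketched on a plain `n × d × k` array. The function name `heuristic_acceptance_rate`, its `warn` keyword and the message text are hypothetical, not an existing MCMCChains or MCMCDiagnosticTools function.

```julia
# Hypothetical opt-in helper: per-chain heuristic acceptance rates for an
# n × d × k array (chains along dimension 3), with a warning that states
# the limitations discussed in this thread.
function heuristic_acceptance_rate(X::AbstractArray{<:Real,3}; warn::Bool=true)
    if warn
        msg = "Acceptance rate estimated from repeated draws. This is a heuristic: " *
              "it misses accepted proposals of the current state (e.g. with discrete " *
              "parameters) and has no meaning for samplers without an accept/reject step."
        @warn msg
    end
    n = size(X, 1)
    # A step counts as a jump if any parameter differs from the previous draw.
    return [count(i -> any(view(X, i, :, c) .!= view(X, i - 1, :, c)), 2:n) / (n - 1)
            for c in 1:size(X, 3)]
end

X = cat(randn(100, 3, 2), ones(100, 3, 2), dims=1)
heuristic_acceptance_rate(X)            # warns, then returns two rates of 100 / 199
```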
