LKJCholesky does not work with compiled ReverseDiff.jl #2091

Open
torfjelde opened this issue Oct 3, 2023 · 10 comments

@torfjelde
Member

MWE:

julia> using Turing, TuringBenchmarking

julia> @model demo() = L ~ LKJCholesky(3, 1.0)
demo (generic function with 2 methods)

julia> suite = TuringBenchmarking.make_turing_suite(demo(); adbackends=[:reversediff, :reversediff_compiled], check=true);
┌ Warning: There is disagreement in the log-density values!
└ @ TuringBenchmarking ~/.julia/packages/TuringBenchmarking/eRnfy/src/TuringBenchmarking.jl:246
┌──────────────────────────────────────┬─────────────┐
│                             Standard │ Log-density │
│                              backend │    distance │
├──────────────────────────────────────┼─────────────┤
│ ReverseDiff vs ReverseDiff[compiled] │         Inf │
└──────────────────────────────────────┴─────────────┘
┌ Warning: There is disagreement in the gradients!
└ @ TuringBenchmarking ~/.julia/packages/TuringBenchmarking/eRnfy/src/TuringBenchmarking.jl:253
┌──────────────────────────────────────┬──────────┐
│                             Standard │ Gradient │
│                              backend │ distance │
├──────────────────────────────────────┼──────────┤
│ ReverseDiff vs ReverseDiff[compiled] │     1.23 │
└──────────────────────────────────────┴──────────┘
┌ Warning: There is disagreement in the gradients!
└ @ TuringBenchmarking ~/.julia/packages/TuringBenchmarking/eRnfy/src/TuringBenchmarking.jl:253
┌──────────────────────────────────────┬──────────┐
│                               Linked │ Gradient │
│                              backend │ distance │
├──────────────────────────────────────┼──────────┤
│ ReverseDiff vs ReverseDiff[compiled] │     5.08 │
└──────────────────────────────────────┴──────────┘
Manifest.toml
(jl_3WWgcx) pkg> st --manifest
Status `/tmp/jl_3WWgcx/Manifest.toml`
⌃ [47edcb42] ADTypes v0.1.6
  [621f4979] AbstractFFTs v1.5.0
  [80f14c24] AbstractMCMC v4.4.2
  [7a57a42e] AbstractPPL v0.6.2
  [1520ce14] AbstractTrees v0.4.4
  [79e6a3ab] Adapt v3.6.2
  [0bf59076] AdvancedHMC v0.5.5
  [5b7e9947] AdvancedMH v0.7.5
  [576499cb] AdvancedPS v0.4.3
  [b5ca4192] AdvancedVI v0.2.4
  [dce04be8] ArgCheck v2.3.0
  [4fba245c] ArrayInterface v7.4.11
  [a9b6321e] Atomix v0.1.0
  [13072b0f] AxisAlgorithms v1.0.1
  [39de3d68] AxisArrays v0.4.7
  [198e06fe] BangBang v0.3.39
  [9718e550] Baselet v0.1.1
  [6e4b80f9] BenchmarkTools v1.3.2
  [76274a88] Bijectors v0.13.7
  [fa961155] CEnum v0.4.2
  [49dc2e85] Calculus v0.5.1
  [082447d4] ChainRules v1.54.0
  [d360d2e6] ChainRulesCore v1.16.0
  [9e997f8a] ChangesOfVariables v0.1.8
  [861a8166] Combinatorics v1.0.2
  [38540f10] CommonSolve v0.2.4
  [bbf7d656] CommonSubexpressions v0.3.0
  [34da2185] Compat v4.10.0
  [a33af91c] CompositionsBase v0.1.2
  [88cd18e8] ConsoleProgressMonitor v0.1.2
  [187b0558] ConstructionBase v1.5.4
  [a8cc5b0e] Crayons v4.1.1
  [9a962f9c] DataAPI v1.15.0
  [864edb3b] DataStructures v0.18.15
  [e2d170a0] DataValueInterfaces v1.0.0
  [244e2a9f] DefineSingletons v0.1.2
  [8bb1440f] DelimitedFiles v1.9.1
  [b429d917] DensityInterface v0.4.0
  [163ba53b] DiffResults v1.1.0
  [b552c78f] DiffRules v1.15.1
  [31c24e10] Distributions v0.25.102
  [ced4e74d] DistributionsAD v0.6.53
  [ffbed154] DocStringExtensions v0.9.3
  [fa6b7ba4] DualNumbers v0.6.8
  [366bfd00] DynamicPPL v0.23.18
  [cad2338a] EllipticalSliceSampling v1.1.0
  [4e289a0a] EnumX v1.0.4
  [e2ba6199] ExprTools v0.1.10
  [7a1cc6ca] FFTW v1.7.1
  [1a297f60] FillArrays v1.6.1
  [59287772] Formatting v0.4.2
  [f6369f11] ForwardDiff v0.10.36
  [069b7b12] FunctionWrappers v1.1.3
  [77dc65aa] FunctionWrappersWrappers v0.1.3
  [d9f16b24] Functors v0.4.5
  [0c68f7d7] GPUArrays v9.0.0
  [46192b85] GPUArraysCore v0.1.5
  [34004b35] HypergeometricFunctions v0.3.23
  [7869d1d1] IRTools v0.4.10
  [22cec73e] InitialValues v0.3.1
  [505f98c9] InplaceOps v0.3.0
  [a98d9a8b] Interpolations v0.14.7
  [8197267c] IntervalSets v0.7.7
  [3587e190] InverseFunctions v0.1.12
  [41ab1584] InvertedIndices v1.3.0
  [92d709cd] IrrationalConstants v0.2.2
  [c8e1da08] IterTools v1.8.0
  [82899510] IteratorInterfaceExtensions v1.0.0
  [692b3bcd] JLLWrappers v1.5.0
  [682c06a0] JSON v0.21.4
  [63c18a36] KernelAbstractions v0.9.8
  [5ab0869b] KernelDensity v0.6.7
  [929cbde3] LLVM v6.3.0
  [8ac3fa9e] LRUCache v1.5.0
  [b964fa9f] LaTeXStrings v1.3.0
  [50d2b5c4] Lazy v0.15.1
  [1d6d02ad] LeftChildRightSiblingTrees v0.2.0
  [6f1fad26] Libtask v0.8.6
  [6fdf6af0] LogDensityProblems v2.1.1
  [996a588d] LogDensityProblemsAD v1.6.1
  [2ab3a3ac] LogExpFunctions v0.3.26
  [e6f89c97] LoggingExtras v1.0.2
  [c7f686f2] MCMCChains v6.0.3
  [be115224] MCMCDiagnosticTools v0.3.5
  [e80e1ace] MLJModelInterface v1.9.2
  [1914dd2f] MacroTools v0.5.11
  [dbb5928d] MappedArrays v0.4.2
  [128add7d] MicroCollections v0.1.4
  [e1d29d7a] Missings v1.1.0
  [872c559c] NNlib v0.9.7
  [77ba4419] NaNMath v1.0.2
  [86f7a689] NamedArrays v0.10.0
  [c020b1a1] NaturalSort v1.0.0
  [6fe1bfb0] OffsetArrays v1.12.10
  [3bd65402] Optimisers v0.3.1
  [bac558e1] OrderedCollections v1.6.2
  [90014a1f] PDMats v0.11.19
  [69de0a69] Parsers v2.7.2
  [aea7be01] PrecompileTools v1.2.0
  [21216c6a] Preferences v1.4.1
  [08abe8d2] PrettyTables v2.2.7
  [33c8b6b6] ProgressLogging v0.1.4
  [92933f4c] ProgressMeter v1.9.0
  [1fd47b50] QuadGK v2.9.1
  [74087812] Random123 v1.6.1
  [e6cf234a] RandomNumbers v1.5.3
  [b3c3ace0] RangeArrays v0.3.2
  [c84ed2f1] Ratios v0.4.5
  [c1ae055f] RealDot v0.1.0
  [3cdcf5f2] RecipesBase v1.3.4
  [731186ca] RecursiveArrayTools v2.38.10
  [189a3867] Reexport v1.2.2
  [ae029012] Requires v1.3.0
  [37e2e3b7] ReverseDiff v1.15.1
  [79098fc4] Rmath v0.7.1
  [f2b01f46] Roots v2.0.20
  [7e49a35a] RuntimeGeneratedFunctions v0.5.12
⌅ [0bca4576] SciMLBase v1.98.1
  [c0aeaf25] SciMLOperators v0.3.6
  [30f210dd] ScientificTypesBase v3.0.0
  [efcf1570] Setfield v1.1.1
  [ce78b400] SimpleUnPack v1.1.0
  [a2af1166] SortingAlgorithms v1.1.1
  [dc90abb0] SparseInverseSubset v0.1.1
  [276daf66] SpecialFunctions v2.3.1
  [171d559e] SplittablesBase v0.1.15
  [90137ffa] StaticArrays v1.6.5
  [1e83bf80] StaticArraysCore v1.4.2
  [64bff920] StatisticalTraits v3.2.0
  [82ae8749] StatsAPI v1.7.0
  [2913bbd2] StatsBase v0.34.2
  [4c63d2b9] StatsFuns v1.3.0
  [892a3eda] StringManipulation v0.3.4
  [09ab397b] StructArrays v0.6.16
  [2efcf032] SymbolicIndexingInterface v0.2.2
  [3783bdb8] TableTraits v1.0.1
  [bd369af6] Tables v1.11.0
  [5d786b92] TerminalLoggers v0.1.7
  [9f7883ad] Tracker v0.2.27
  [28d57a85] Transducers v0.4.78
  [410a4b4d] Tricks v0.1.7
  [781d530d] TruncatedStacktraces v1.4.0
  [fce5fe82] Turing v0.29.1
  [0db1332d] TuringBenchmarking v0.3.2
  [013be700] UnsafeAtomics v0.2.1
  [d80eeb9a] UnsafeAtomicsLLVM v0.1.3
  [efce3f68] WoodburyMatrices v0.5.5
  [e88e6eb3] Zygote v0.6.65
  [700de1a5] ZygoteRules v0.2.3
  [f5851436] FFTW_jll v3.3.10+0
  [1d5cc7b8] IntelOpenMP_jll v2023.2.0+0
  [dad2f222] LLVMExtra_jll v0.0.26+0
  [856f044c] MKL_jll v2023.2.0+0
  [efe28fd5] OpenSpecFun_jll v0.5.5+0
  [f50d1b31] Rmath_jll v0.4.0+0
  [0dad84c5] ArgTools v1.1.1
  [56f22d72] Artifacts
  [2a0f44e3] Base64
  [ade2ca70] Dates
  [8ba89e20] Distributed
  [f43a241f] Downloads v1.6.0
  [7b1f6079] FileWatching
  [9fa8497b] Future
  [b77e0a4c] InteractiveUtils
  [4af54fe1] LazyArtifacts
  [b27032c2] LibCURL v0.6.3
  [76f85450] LibGit2
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging
  [d6f4376e] Markdown
  [a63ad114] Mmap
  [ca575930] NetworkOptions v1.2.0
  [44cfe95a] Pkg v1.9.2
  [de0858da] Printf
  [9abbd945] Profile
  [3fa0cd96] REPL
  [9a3f8284] Random
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization
  [1a1011a3] SharedArrays
  [6462fe0b] Sockets
  [2f01184e] SparseArrays
  [10745b16] Statistics v1.9.0
  [4607b0f0] SuiteSparse
  [fa267f1f] TOML v1.0.3
  [a4e569a6] Tar v1.10.0
  [8dfed614] Test
  [cf7118a7] UUIDs
  [4ec0a83e] Unicode
  [e66e0078] CompilerSupportLibraries_jll v1.0.5+0
  [deac9b47] LibCURL_jll v7.84.0+0
  [29816b5a] LibSSH2_jll v1.10.2+0
  [c8ffd9c3] MbedTLS_jll v2.28.2+0
  [14a3606d] MozillaCACerts_jll v2022.10.11
  [4536629a] OpenBLAS_jll v0.3.21+4
  [05823500] OpenLibm_jll v0.8.1+0
  [bea87d4a] SuiteSparse_jll v5.10.1+6
  [83775a58] Zlib_jll v1.2.13+0
  [8e850b90] libblastrampoline_jll v5.8.0+0
  [8e850ede] nghttp2_jll v1.48.0+0
  [3f19e933] p7zip_jll v17.4.0+0
Info Packages marked with ⌃ and ⌅ have new versions available, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated -m`
@sunxd3
Collaborator

sunxd3 commented Oct 5, 2023

Fixed by #2097? It runs without warnings on my end.

versioninfo
Julia Version 1.10.0-beta3
Commit 404750f8586 (2023-10-03 12:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 10 × Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
  Threads: 8 on 10 virtual cores
Environment:
  LD_LIBRARY_PATH = /usr/local/cuda/lib64:
  JULIA_NUM_THREADS = 6
  JULIA_EDITOR = code

@torfjelde
Member Author

Are you on master? Because we haven't made a release with that PR yet.

@sunxd3
Collaborator

sunxd3 commented Oct 6, 2023

No, on your branch

@torfjelde
Member Author

Oh sure, then this will be fixed by #2097 yes

@torfjelde
Member Author

Fixed by #2097, which is now being released

@tiemvanderdeure

I'm still getting some weird behaviour when using LKJCholesky with compiled ReverseDiff. benchmark_model doesn't identify any gradient differences, but sampling with NUTS is much less efficient in compiled mode.

I know I made a similar issue about this, but actually the MWE here is enough to reproduce some behaviour that looks weird to me.

That is, in the code below, I would expect chn_rd and chn_rd_compiled to be identical. Instead chn_rd_compiled has much lower effective sample sizes and the chain looks very different, even though benchmarking with check = true doesn't flag anything.

Am I missing anything obvious here?

using Turing, Random, StatsPlots, TuringBenchmarking, Memoization
@model demo() = L ~ LKJCholesky(3, 1.0)

Turing.setadbackend(:reversediff)
Turing.setrdcache(false)
Random.seed!(1234)
chn_rd = sample(demo(), NUTS(), 1000)

Turing.setrdcache(true)
Random.seed!(1234)
chn_rd_compiled = sample(demo(), NUTS(), 1000)

Turing.setadbackend(:forwarddiff)
Random.seed!(1234)
chn_fd = sample(demo(), NUTS(), 1000)

StatsPlots.plot(chn_rd) # looks healthy
StatsPlots.plot(chn_rd_compiled) # looks not great

using TuringBenchmarking
benchmark_model(demo(); check = true, adbackends=[:forwarddiff, :reversediff, :reversediff_compiled]) # no warnings

@torfjelde
Member Author

torfjelde commented Oct 25, 2023

Hmmm, yeah this is strange.

I get the following:

julia> using Random, Turing, ReverseDiff


julia> @model demo() = L ~ LKJCholesky(3, 1.0)
demo (generic function with 2 methods)

julia> Turing.setadbackend(:reversediff)
:reversediff

julia> Turing.setrdcache(false)
false

julia> Random.seed!(1234)
TaskLocalRNG()

julia> chn_rd = sample(demo(), NUTS(), 1000)
┌ Info: Found initial step size
└   ϵ = 0.8125
Sampling 100%|███████████████████████████████████████████████████████████| Time: 0:00:15
Chains MCMC chain (1000×18×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 16.14 seconds
Compute duration  = 16.14 seconds
parameters        = L.L[1,1], L.L[2,1], L.L[3,1], L.L[2,2], L.L[3,2], L.L[3,3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64     Float64    Float64   Float64       Float64 

    L.L[1,1]    1.0000    0.0000       NaN         NaN        NaN       NaN           NaN
    L.L[2,1]   -0.0220    0.5100    0.0149   1108.5225   588.9220    0.9999       68.6817
    L.L[3,1]   -0.0018    0.4925    0.0134   1301.7727   754.2240    0.9997       80.6551
    L.L[2,2]    0.8417    0.1770    0.0082    513.2173   506.4421    0.9993       31.7978
    L.L[3,2]    0.0038    0.5100    0.0127   1463.0641   699.8792    1.0006       90.6483
    L.L[3,3]    0.6646    0.2369    0.0111    446.5069   698.9634    1.0020       27.6646

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

    L.L[1,1]    1.0000    1.0000    1.0000    1.0000    1.0000
    L.L[2,1]   -0.8724   -0.4387   -0.0179    0.3828    0.8772
    L.L[3,1]   -0.8948   -0.3826   -0.0162    0.3757    0.8773
    L.L[2,2]    0.3821    0.7527    0.9118    0.9802    0.9997
    L.L[3,2]   -0.8838   -0.4094   -0.0122    0.4311    0.8668
    L.L[3,3]    0.1783    0.4838    0.7020    0.8775    0.9914


julia> Turing.setrdcache(true)
true

julia> Random.seed!(1234)
TaskLocalRNG()

julia> chn_rd_compiled = sample(demo(), NUTS(), 1000)
┌ Info: Found initial step size
└   ϵ = 0.8
Sampling 100%|███████████████████████████████████████████████████████████| Time: 0:00:05
Chains MCMC chain (1000×18×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 5.84 seconds
Compute duration  = 5.84 seconds
parameters        = L.L[1,1], L.L[2,1], L.L[3,1], L.L[2,2], L.L[3,2], L.L[3,3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64 

    L.L[1,1]    1.0000    0.0000       NaN        NaN        NaN       NaN           NaN
    L.L[2,1]   -0.0084    0.4925    0.0324   228.6450   338.8731    1.0121       39.1851
    L.L[3,1]   -0.0309    0.5025    0.0258   340.1116   282.3064    1.0081       58.2882
    L.L[2,2]    0.8551    0.1626    0.0089   370.7538   497.9914    1.0130       63.5396
    L.L[3,2]   -0.0147    0.4512    0.0227   392.0160   419.2180    1.0097       67.1836
    L.L[3,3]    0.7074    0.2068    0.0118   340.1865   475.6237    0.9994       58.3010

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

    L.L[1,1]    1.0000    1.0000    1.0000    1.0000    1.0000
    L.L[2,1]   -0.8647   -0.3878   -0.0117    0.3833    0.8483
    L.L[3,1]   -0.8775   -0.4332   -0.0412    0.3793    0.8522
    L.L[2,2]    0.4263    0.7698    0.9232    0.9847    0.9998
    L.L[3,2]   -0.8352   -0.3499   -0.0043    0.3450    0.7960
    L.L[3,3]    0.2748    0.5634    0.7456    0.8834    0.9902

So it seems the resulting parameter estimates are roughly the same but the ESS is different. Also note that the chosen step-size is slightly different.

Compiled ReverseDiff will produce different results if the computation contains if-statements that depend on the values of the random variables we're inferring, so I'm wondering if there's a conditional somewhere in the computation graph that is not captured correctly.
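
To illustrate this caveat outside of Turing, here is a minimal sketch using plain ReverseDiff.jl. The function `f` and the inputs are made up for illustration and do not appear anywhere in Turing or Bijectors; the only point is that a compiled tape replays whichever branch was taken when it was recorded.

```julia
using ReverseDiff

# Toy function (an assumption for illustration): the branch depends on the input value.
f(x) = x[1] > 0 ? sum(x .* x) : sum(x)

# Record and compile the tape on an input that takes the first branch.
x_recorded = [1.0, 2.0]
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, x_recorded))

# This input takes the second branch when `f` is run normally.
x_unseen = [-1.0, 2.0]

# The compiled tape replays the recorded branch, i.e. the gradient of sum(x .* x).
g_compiled = ReverseDiff.gradient!(similar(x_unseen), tape, x_unseen)

# Re-tracing `f` picks the branch that matches the new input, i.e. the gradient of sum(x).
g_fresh = ReverseDiff.gradient(f, x_unseen)

println("compiled: ", g_compiled)
println("fresh:    ", g_fresh)
```

On `x_recorded` itself the two gradients agree, which is why a check on a single input can pass.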

TuringBenchmarking.jl just checks a single value, hence it might work correctly for that particular value.

julia> using Test, LogDensityProblems, LogDensityProblemsAD

julia> model = demo();

julia> varinfo = DynamicPPL.link(DynamicPPL.VarInfo(model), model);

julia> f = DynamicPPL.LogDensityFunction(model, varinfo);

julia> f_rd = LogDensityProblemsAD.ADgradient(Turing.Essential.ReverseDiffAD{false}(), deepcopy(f));

julia> f_crd = LogDensityProblemsAD.ADgradient(Turing.Essential.ReverseDiffAD{true}(), deepcopy(f));

julia> # Let's check if they're the same on the input we compiled for.
       result_rd = LogDensityProblems.logdensity_and_gradient(f_rd, varinfo[:])
(-2.0957855104582483, [1.4082609718596935, -0.09191446506506586, 0.6847718985562207])

julia> result_crd = LogDensityProblems.logdensity_and_gradient(f_crd, varinfo[:])
(-2.0957855104582483, [1.4082609718596935, -0.09191446506506586, 0.6847718985562207])

julia> @test result_rd[1] ≈ result_crd[1]
Test Passed

julia> @test result_rd[2] ≈ result_crd[2]
Test Passed

julia> # Now with inputs that it was not compiled on.
       d = length(varinfo[:]);

julia> x = rand(d);

julia> result_unseen_rd = LogDensityProblems.logdensity_and_gradient(f_rd, x)
(-2.4459590133007008, [-1.870827529402606, -0.17639514549826948, -0.6335288221522503])

julia> result_unseen_crd = LogDensityProblems.logdensity_and_gradient(f_crd, x)
(-2.4459590133007008, [-0.8316199935369052, -0.1968884416775799, -0.6343288727116665])

julia> @test result_unseen_rd[1] ≈ result_unseen_crd[1]
Test Passed

julia> @test result_unseen_rd[2] ≈ result_unseen_crd[2]
Test Failed at REPL[282]:1
  Expression: result_unseen_rd[2] ≈ result_unseen_crd[2]
   Evaluated: [-1.870827529402606, -0.17639514549826948, -0.6335288221522503] ≈ [-0.8316199935369052, -0.1968884416775799, -0.6343288727116665]

ERROR: There was an error during testing

which immediately fails.

And so this is likely caused by some conditional somewhere.

My immediate question is whether this conditional is present in the linking or just in the log-prob computation:

julia> varinfo = DynamicPPL.VarInfo(model);  # without linking


julia> f = DynamicPPL.LogDensityFunction(model, varinfo);

julia> f_rd = LogDensityProblemsAD.ADgradient(Turing.Essential.ReverseDiffAD{false}(), deepcopy(f));

julia> f_crd = LogDensityProblemsAD.ADgradient(Turing.Essential.ReverseDiffAD{true}(), deepcopy(f));

julia> # Let's check if they're the same on the input we compiled for.
       result_rd = LogDensityProblems.logdensity_and_gradient(f_rd, varinfo[:])
(-1.6501987419955653, [0.0, 0.0, 0.0, 0.0, 1.0553644429775089, 0.0, 0.0, 0.0, 0.0])

julia> result_crd = LogDensityProblems.logdensity_and_gradient(f_crd, varinfo[:])
(-1.6501987419955653, [0.0, 0.0, 0.0, 0.0, 1.0553644429775089, 0.0, 0.0, 0.0, 0.0])

julia> @test result_rd[1] ≈ result_crd[1]
Test Passed

julia> @test result_rd[2] ≈ result_crd[2]
Test Passed

julia> # Unseen.
       x = DynamicPPL.vectorize(LKJCholesky(3, 1.0), model())
9-element Vector{Float64}:
 1.0
 0.3623020689503904
 0.2946552179906151
 0.0
 0.932060733447272
 0.7153149803851709
 0.0
 0.0
 0.6336424712307923

julia> result_unseen_rd = LogDensityProblems.logdensity_and_gradient(f_rd, x)
(-1.6666698929155281, [0.0, 0.0, 0.0, 0.0, 1.072891458801672, 0.0, 0.0, 0.0, 0.0])

julia> result_unseen_crd = LogDensityProblems.logdensity_and_gradient(f_crd, x)
(-1.6666698929155281, [0.0, 0.0, 0.0, 0.0, 1.072891458801672, 0.0, 0.0, 0.0, 0.0])

julia> @test result_unseen_rd[1] ≈ result_unseen_crd[1]
Test Passed

julia> @test result_unseen_rd[2] ≈ result_unseen_crd[2]
Test Passed

Okay, so this works nicely while the linked version fails, indicating that it's an issue with the linking itself. Will have a look at this.

I'll also make it so that TuringBenchmarking.jl runs the gradient checks on inputs it was not compiled on, so we catch these things too.
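
A rough sketch of what such a check could look like with plain ReverseDiff.jl is given here. The helper `compiled_matches` and the toy functions are hypothetical names for illustration only; this is not how TuringBenchmarking.jl implements its checks.

```julia
using ReverseDiff

# Hypothetical helper (not part of TuringBenchmarking.jl): compile a tape on `x0`,
# then compare its gradients against freshly traced gradients on the inputs in `xs`.
function compiled_matches(f, x0, xs; atol=1e-8)
    tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, x0))
    for x in xs
        g_compiled = ReverseDiff.gradient!(similar(x), tape, x)
        g_fresh = ReverseDiff.gradient(f, x)
        # Stop at the first input where the compiled tape disagrees.
        isapprox(g_compiled, g_fresh; atol=atol) || return false
    end
    return true
end

# Toy functions (assumptions for illustration).
smooth(x) = sum(x .* x)                              # no value-dependent branch
branchy(x) = x[1] > 0 ? sum(x .* x) : sum(x)         # value-dependent branch

unseen = [[-1.0, 2.0], [3.0, -0.5]]
println(compiled_matches(smooth, [1.0, 2.0], unseen))
println(compiled_matches(branchy, [1.0, 2.0], unseen))
```

Checking several inputs that differ from the one used for compilation is what exposes a value-dependent branch; a check on the compilation input alone cannot.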

Manifest
(jl_w7AqOd) pkg> st --manifest
Status `/tmp/jl_w7AqOd/Manifest.toml`
⌃ [47edcb42] ADTypes v0.1.6
  [621f4979] AbstractFFTs v1.5.0
  [80f14c24] AbstractMCMC v4.5.0
  [7a57a42e] AbstractPPL v0.6.2
  [1520ce14] AbstractTrees v0.4.4
  [79e6a3ab] Adapt v3.7.0
  [0bf59076] AdvancedHMC v0.5.5
  [5b7e9947] AdvancedMH v0.7.5
⌅ [576499cb] AdvancedPS v0.4.3
  [b5ca4192] AdvancedVI v0.2.4
  [dce04be8] ArgCheck v2.3.0
  [4fba245c] ArrayInterface v7.4.11
  [a9b6321e] Atomix v0.1.0
  [13072b0f] AxisAlgorithms v1.0.1
  [39de3d68] AxisArrays v0.4.7
  [198e06fe] BangBang v0.3.39
  [9718e550] Baselet v0.1.1
  [6e4b80f9] BenchmarkTools v1.3.2
  [76274a88] Bijectors v0.13.7
⌅ [fa961155] CEnum v0.4.2
  [49dc2e85] Calculus v0.5.1
  [082447d4] ChainRules v1.56.0
  [d360d2e6] ChainRulesCore v1.18.0
  [9e997f8a] ChangesOfVariables v0.1.8
  [861a8166] Combinatorics v1.0.2
  [38540f10] CommonSolve v0.2.4
  [bbf7d656] CommonSubexpressions v0.3.0
  [34da2185] Compat v4.10.0
  [a33af91c] CompositionsBase v0.1.2
  [88cd18e8] ConsoleProgressMonitor v0.1.2
  [187b0558] ConstructionBase v1.5.4
  [a8cc5b0e] Crayons v4.1.1
  [9a962f9c] DataAPI v1.15.0
  [864edb3b] DataStructures v0.18.15
  [e2d170a0] DataValueInterfaces v1.0.0
  [244e2a9f] DefineSingletons v0.1.2
  [8bb1440f] DelimitedFiles v1.9.1
  [b429d917] DensityInterface v0.4.0
  [163ba53b] DiffResults v1.1.0
  [b552c78f] DiffRules v1.15.1
  [31c24e10] Distributions v0.25.102
  [ced4e74d] DistributionsAD v0.6.53
  [ffbed154] DocStringExtensions v0.9.3
  [fa6b7ba4] DualNumbers v0.6.8
  [366bfd00] DynamicPPL v0.23.20
  [cad2338a] EllipticalSliceSampling v1.1.0
  [4e289a0a] EnumX v1.0.4
  [e2ba6199] ExprTools v0.1.10
  [7a1cc6ca] FFTW v1.7.1
  [1a297f60] FillArrays v1.7.0
  [59287772] Formatting v0.4.2
  [f6369f11] ForwardDiff v0.10.36
  [069b7b12] FunctionWrappers v1.1.3
  [77dc65aa] FunctionWrappersWrappers v0.1.3
  [d9f16b24] Functors v0.4.5
  [0c68f7d7] GPUArrays v9.0.0
  [46192b85] GPUArraysCore v0.1.5
  [34004b35] HypergeometricFunctions v0.3.23
  [7869d1d1] IRTools v0.4.11
  [22cec73e] InitialValues v0.3.1
  [505f98c9] InplaceOps v0.3.0
  [a98d9a8b] Interpolations v0.14.7
  [8197267c] IntervalSets v0.7.8
  [3587e190] InverseFunctions v0.1.12
  [41ab1584] InvertedIndices v1.3.0
  [92d709cd] IrrationalConstants v0.2.2
  [c8e1da08] IterTools v1.8.0
  [82899510] IteratorInterfaceExtensions v1.0.0
  [692b3bcd] JLLWrappers v1.5.0
  [682c06a0] JSON v0.21.4
  [63c18a36] KernelAbstractions v0.9.10
  [5ab0869b] KernelDensity v0.6.7
  [929cbde3] LLVM v6.3.0
  [8ac3fa9e] LRUCache v1.5.0
  [b964fa9f] LaTeXStrings v1.3.0
  [50d2b5c4] Lazy v0.15.1
  [1d6d02ad] LeftChildRightSiblingTrees v0.2.0
  [6f1fad26] Libtask v0.8.6
  [6fdf6af0] LogDensityProblems v2.1.1
  [996a588d] LogDensityProblemsAD v1.6.1
  [2ab3a3ac] LogExpFunctions v0.3.26
  [e6f89c97] LoggingExtras v1.0.3
  [c7f686f2] MCMCChains v6.0.3
  [be115224] MCMCDiagnosticTools v0.3.7
  [e80e1ace] MLJModelInterface v1.9.3
  [1914dd2f] MacroTools v0.5.11
  [dbb5928d] MappedArrays v0.4.2
  [128add7d] MicroCollections v0.1.4
  [e1d29d7a] Missings v1.1.0
  [872c559c] NNlib v0.9.7
  [77ba4419] NaNMath v1.0.2
  [86f7a689] NamedArrays v0.10.0
  [c020b1a1] NaturalSort v1.0.0
  [6fe1bfb0] OffsetArrays v1.12.10
  [3bd65402] Optimisers v0.3.1
  [bac558e1] OrderedCollections v1.6.2
  [90014a1f] PDMats v0.11.28
  [69de0a69] Parsers v2.7.2
  [aea7be01] PrecompileTools v1.2.0
  [21216c6a] Preferences v1.4.1
  [08abe8d2] PrettyTables v2.2.8
  [33c8b6b6] ProgressLogging v0.1.4
  [92933f4c] ProgressMeter v1.9.0
  [1fd47b50] QuadGK v2.9.1
  [74087812] Random123 v1.6.1
  [e6cf234a] RandomNumbers v1.5.3
  [b3c3ace0] RangeArrays v0.3.2
  [c84ed2f1] Ratios v0.4.5
  [c1ae055f] RealDot v0.1.0
  [3cdcf5f2] RecipesBase v1.3.4
  [731186ca] RecursiveArrayTools v2.38.10
  [189a3867] Reexport v1.2.2
  [ae029012] Requires v1.3.0
  [37e2e3b7] ReverseDiff v1.15.1
  [79098fc4] Rmath v0.7.1
  [f2b01f46] Roots v2.0.20
  [7e49a35a] RuntimeGeneratedFunctions v0.5.12
  [0bca4576] SciMLBase v2.4.3
  [c0aeaf25] SciMLOperators v0.3.6
  [30f210dd] ScientificTypesBase v3.0.0
  [efcf1570] Setfield v1.1.1
  [ce78b400] SimpleUnPack v1.1.0
  [a2af1166] SortingAlgorithms v1.2.0
  [dc90abb0] SparseInverseSubset v0.1.1
  [276daf66] SpecialFunctions v2.3.1
  [171d559e] SplittablesBase v0.1.15
  [90137ffa] StaticArrays v1.6.5
  [1e83bf80] StaticArraysCore v1.4.2
  [64bff920] StatisticalTraits v3.2.0
  [82ae8749] StatsAPI v1.7.0
  [2913bbd2] StatsBase v0.34.2
  [4c63d2b9] StatsFuns v1.3.0
  [892a3eda] StringManipulation v0.3.4
  [09ab397b] StructArrays v0.6.16
  [2efcf032] SymbolicIndexingInterface v0.2.2
  [3783bdb8] TableTraits v1.0.1
  [bd369af6] Tables v1.11.1
  [5d786b92] TerminalLoggers v0.1.7
  [9f7883ad] Tracker v0.2.27
  [28d57a85] Transducers v0.4.78
  [410a4b4d] Tricks v0.1.8
  [781d530d] TruncatedStacktraces v1.4.0
  [fce5fe82] Turing v0.29.3
  [0db1332d] TuringBenchmarking v0.3.3
  [013be700] UnsafeAtomics v0.2.1
  [d80eeb9a] UnsafeAtomicsLLVM v0.1.3
  [efce3f68] WoodburyMatrices v0.5.5
  [e88e6eb3] Zygote v0.6.66
  [700de1a5] ZygoteRules v0.2.3
  [f5851436] FFTW_jll v3.3.10+0
  [1d5cc7b8] IntelOpenMP_jll v2023.2.0+0
  [dad2f222] LLVMExtra_jll v0.0.26+0
  [856f044c] MKL_jll v2023.2.0+0
  [efe28fd5] OpenSpecFun_jll v0.5.5+0
  [f50d1b31] Rmath_jll v0.4.0+0
  [0dad84c5] ArgTools v1.1.1
  [56f22d72] Artifacts
  [2a0f44e3] Base64
  [ade2ca70] Dates
  [8ba89e20] Distributed
  [f43a241f] Downloads v1.6.0
  [7b1f6079] FileWatching
  [9fa8497b] Future
  [b77e0a4c] InteractiveUtils
  [4af54fe1] LazyArtifacts
  [b27032c2] LibCURL v0.6.3
  [76f85450] LibGit2
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging
  [d6f4376e] Markdown
  [a63ad114] Mmap
  [ca575930] NetworkOptions v1.2.0
  [44cfe95a] Pkg v1.9.2
  [de0858da] Printf
  [9abbd945] Profile
  [3fa0cd96] REPL
  [9a3f8284] Random
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization
  [1a1011a3] SharedArrays
  [6462fe0b] Sockets
  [2f01184e] SparseArrays
  [10745b16] Statistics v1.9.0
  [4607b0f0] SuiteSparse
  [fa267f1f] TOML v1.0.3
  [a4e569a6] Tar v1.10.0
  [8dfed614] Test
  [cf7118a7] UUIDs
  [4ec0a83e] Unicode
  [e66e0078] CompilerSupportLibraries_jll v1.0.5+0
  [deac9b47] LibCURL_jll v7.84.0+0
  [29816b5a] LibSSH2_jll v1.10.2+0
  [c8ffd9c3] MbedTLS_jll v2.28.2+0
  [14a3606d] MozillaCACerts_jll v2022.10.11
  [4536629a] OpenBLAS_jll v0.3.21+4
  [05823500] OpenLibm_jll v0.8.1+0
  [bea87d4a] SuiteSparse_jll v5.10.1+6
  [83775a58] Zlib_jll v1.2.13+0
  [8e850b90] libblastrampoline_jll v5.8.0+0
  [8e850ede] nghttp2_jll v1.48.0+0
  [3f19e933] p7zip_jll v17.4.0+0
Info Packages marked with ⌃ and ⌅ have new versions available, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated -m`

@torfjelde torfjelde reopened this Oct 25, 2023
@torfjelde
Member Author

torfjelde commented Oct 25, 2023

Ah, I thought we had ChainRules definitions for everything relating to the transformation of LKJCholesky, but it seems we don't have one for the log-abs-det Jacobian. That is, we need a rule for https://github.com/TuringLang/Bijectors.jl/blob/04b79dd46eca8cea2f988348c47bd5e720a2b9a4/src/bijectors/corr.jl#L410-L427

Though I'm a bit uncertain why this would cause issues.

@torfjelde
Copy link
Member Author

Hmm, this is very strange: I can't seem to reproduce the issue when looking just at the transformation.

julia> using Turing, ReverseDiff, DiffResults

julia> dist = LKJCholesky(3, 1.0)

LKJCholesky{Float64}(
d: 3
η: 1.0
uplo: L
)


julia> b = bijector(dist)
Bijectors.VecCholeskyBijector(:L)

julia> binv = inverse(b)
Inverse{Bijectors.VecCholeskyBijector}(Bijectors.VecCholeskyBijector(:L))

julia> x = rand(dist)
LinearAlgebra.Cholesky{Float64, Matrix{Float64}}
L factor:
3×3 LinearAlgebra.LowerTriangular{Float64, Matrix{Float64}}:
 1.0                  
 0.10674   0.994287    
 0.390942  0.661192  0.640304

julia> y = b(x)
3-element Vector{Float64}:
 0.10714823845189496
 0.4129118569882736
 0.9042537275936676

julia> function f(y)
           x, logjac = with_logabsdet_jacobian(binv, y)
           return logpdf(dist, x) + logjac
       end
f (generic function with 1 method)

julia> # ReverseDiff.
       f_tape = ReverseDiff.GradientTape(f, (y,))
typename(ReverseDiff.GradientTape)(f)

julia> f_tape_compiled = ReverseDiff.compile(f_tape)
typename(ReverseDiff.CompiledTape)(f)

julia> inputs = (y,);

julia> buffers = (DiffResults.GradientResult(similar(y)),);

julia> cfg = ReverseDiff.GradientConfig(inputs);

julia> ReverseDiff.gradient!(buffers, f_tape, inputs)
(MutableDiffResult(-2.588055647847676, ([-0.32022019679930697, -1.1728269159715612, -1.4367255598238464],)),)

julia> ReverseDiff.gradient!(buffers, f_tape_compiled, inputs)
(MutableDiffResult(-2.588055647847676, ([-0.32022019679930697, -1.1728269159715612, -1.4367255598238464],)),)

julia> # New inputs.
       inputs = (randn(length(y)),);

julia> buffers = (DiffResults.GradientResult(similar(y)),);

julia> cfg = ReverseDiff.GradientConfig(inputs);

julia> ReverseDiff.gradient!(buffers, f_tape, inputs)
(MutableDiffResult(-3.7413960226614016, ([-0.15691122445568645, 8.57759603199995, -1.7830917058959441],)),)

julia> ReverseDiff.gradient!(buffers, f_tape_compiled, inputs)
(MutableDiffResult(-3.7413960226614016, ([-0.15691122445568645, 8.57759603199995, -1.7830917058959441],)),)

@tiemvanderdeure

Any news on a fix for this issue?
