Zygote gradients different from ForwardDiff/ReverseDiff on Julia 1.10-rc2 #1478

Open
SaremS opened this issue Dec 12, 2023 · 3 comments

@SaremS

SaremS commented Dec 12, 2023

Hi, I hope this suffices as an MWE:

```julia
using Pkg
packages = [
    Pkg.PackageSpec(; name="ForwardDiff", version="0.10.36"),
    Pkg.PackageSpec(; name="ReverseDiff", version="1.15.1"),
    Pkg.PackageSpec(; name="Zygote", version="0.6.67"),
    Pkg.PackageSpec(; name="KernelFunctions", version="0.10.60"),
    Pkg.PackageSpec(; name="Distributions", version="0.25.104"),
]
Pkg.add(packages)

using ForwardDiff, ReverseDiff, Zygote, Distributions, KernelFunctions

# Define kernel function (periodic + white noise)
kernel(l, s) = with_lengthscale(SqExponentialKernel(), l^2) ∘ PeriodicTransform(1/365) + ScaledKernel(WhiteKernel(), s^2)

# Create data deterministically: 20 observations and 20 input points
m = collect(-0.9:0.1:1)
x_obs = collect(1:18:360)

# Log-likelihood of the data sample as a function of the kernel hyperparameters x = [l, s]
loglik(x) = logpdf(MvNormal(zeros(20), kernelmatrix(kernel(x[1], x[2]), x_obs)), m)

# Differentiate the likelihood with respect to the kernel hyperparameters
# ForwardDiff
println(ForwardDiff.gradient(loglik, [1., 0.1]))

# ReverseDiff
println(ReverseDiff.gradient(loglik, [1., 0.1]))

# Zygote (returns a tuple holding one gradient per argument)
println(Zygote.gradient(loglik, [1., 0.1]))
```
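For reference, here is the quantity all three packages are differentiating. This assumes KernelFunctions' documented definitions: `with_lengthscale(k, ℓ)` rescales inputs by 1/ℓ (so ℓ = l² here), `PeriodicTransform(f)` maps t to (sin 2πft, cos 2πft), `SqExponentialKernel` is exp(−d²/2) on the squared Euclidean distance, and `ScaledKernel(WhiteKernel(), s^2)` adds s² on the diagonal. With θ = (l, s), inputs t_i from `collect(1:18:360)` and [K_θ]_ij = k_θ(t_i, t_j), the kernel, the log-likelihood of the data m, and its gradient (the standard Gaussian-process marginal-likelihood derivative) are:

$$
k_\theta(t, t') = \exp\!\left(-\frac{2\sin^2\!\big(\pi (t - t')/365\big)}{l^4}\right) + s^2\,\delta_{t t'}
$$

$$
\log p(m \mid \theta) = -\tfrac{1}{2}\, m^\top K_\theta^{-1} m - \tfrac{1}{2}\log\det K_\theta - \tfrac{n}{2}\log 2\pi, \qquad n = 20
$$

$$
\frac{\partial \log p(m \mid \theta)}{\partial \theta_j} = \tfrac{1}{2}\, \alpha^\top \frac{\partial K_\theta}{\partial \theta_j}\, \alpha - \tfrac{1}{2}\operatorname{tr}\!\left(K_\theta^{-1} \frac{\partial K_\theta}{\partial \theta_j}\right), \qquad \alpha = K_\theta^{-1} m
$$

Here ∂K_θ/∂s = 2s I, and the off-diagonal entries of ∂K_θ/∂l are 8 sin²(π(t_i − t_j)/365) l⁻⁵ times the corresponding periodic-kernel entry.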

Outputs are as follows on my machine, using Julia 1.10-rc2:

```
ForwardDiff: [-52.862449903127434, 403.6043237529404]
ReverseDiff: [-52.86244990312515, 403.6043237529402]
Zygote: ([23.812852743170346, -114.3874772277085],)
```

Let me know if you need anything else.
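One AD-free way to decide which result is right is to compute the gradient by hand. The sketch below is a hypothetical re-implementation, not KernelFunctions' code: it assumes the MWE's `kernel(l, s)` expands to exp(−2 sin²(π(t − t′)/365)/l⁴) + s²δ(t, t′) (which follows if `with_lengthscale(k, ℓ)` rescales inputs by 1/ℓ and `PeriodicTransform(f)` maps t to (sin 2πft, cos 2πft)), evaluates the analytic gradient with the trace formula, and checks it against central finite differences. The names `periodic_white_kernel`, `gp_loglik`, `gp_loglik_grad` and `fd_grad` are illustrative only.

```julia
using LinearAlgebra

# Hypothetical hand-built kernel matrix K(θ) and its partials (∂K/∂l, ∂K/∂s) for θ = [l, s].
# Assumes k(t, t') = exp(-2 sin²(π (t - t') / 365) / l^4) + s^2 δ(t, t').
function periodic_white_kernel(θ, ts)
    l, s = θ
    n = length(ts)
    K = zeros(n, n)
    dK_dl = zeros(n, n)
    for j in 1:n, i in 1:n
        S = sin(π * (ts[i] - ts[j]) / 365)^2
        kij = exp(-2S / l^4)
        K[i, j] = kij
        dK_dl[i, j] = kij * 8S / l^5          # d/dl of exp(-2S / l^4)
    end
    K += s^2 * I                              # white-noise term s^2 on the diagonal
    dK_ds = 2s * Matrix{Float64}(I, n, n)     # d/ds of s^2 I
    return Symmetric(K), (dK_dl, dK_ds)
end

# Zero-mean Gaussian log-density log N(ys | 0, K(θ)), via a Cholesky factorization
function gp_loglik(θ, ts, ys)
    K, _ = periodic_white_kernel(θ, ts)
    C = cholesky(K)
    α = C \ ys
    return -0.5 * dot(ys, α) - 0.5 * logdet(C) - 0.5 * length(ys) * log(2π)
end

# Analytic gradient: ∂ℓ/∂θ_j = ½ αᵀ (∂K/∂θ_j) α - ½ tr(K⁻¹ ∂K/∂θ_j), with α = K⁻¹ ys
function gp_loglik_grad(θ, ts, ys)
    K, dKs = periodic_white_kernel(θ, ts)
    C = cholesky(K)
    α = C \ ys
    Kinv = inv(C)
    return [0.5 * dot(α, dK * α) - 0.5 * tr(Kinv * dK) for dK in dKs]
end

# Central finite differences, used as an AD-free reference gradient
function fd_grad(f, θ; h = 1e-6)
    g = similar(θ)
    for j in eachindex(θ)
        δ = zero(θ)
        δ[j] = h
        g[j] = (f(θ + δ) - f(θ - δ)) / (2h)
    end
    return g
end

ts = collect(1:18:360)      # same 20 inputs as the MWE
ys = collect(-0.9:0.1:1)    # same 20 observations as the MWE
θ0 = [1.0, 0.1]
println(gp_loglik_grad(θ0, ts, ys))
println(fd_grad(θ -> gp_loglik(θ, ts, ys), θ0))
```

If the hand-written kernel matches KernelFunctions' definitions, these two printed vectors give a reference to compare against the ForwardDiff, ReverseDiff and Zygote outputs above.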

@ToucheSir
Member

My understanding is that KernelFunctions defines its own ChainRules rules, so much of the code in question lives on their side. Have you raised this issue with them? They may be able to offer a more informed opinion on what's going on.

@SaremS
Author

SaremS commented Dec 15, 2023

Thank you, that would probably explain it. I had assumed that a ChainRules rule is applied the same way by every autodiff package, but apparently that is not the case.

Do you recommend closing this issue then, or shall I leave it open?

@ToucheSir
Member

Of the three AD packages you tested, Zygote is likely the only one using any ChainRules rules here. Feel free to leave this open and link to it when you create issues in other repos, so we have a trail of breadcrumbs to follow.
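To see concretely why a ChainRules rule can change Zygote's answer without affecting ForwardDiff, here is a toy sketch with a hypothetical function `square` and a deliberately wrong `rrule`. Zygote consults `ChainRulesCore.rrule` when a rule exists for the call, whereas ForwardDiff pushes dual numbers through the primal code and never looks at it:

```julia
using ChainRulesCore, ForwardDiff, Zygote

# Hypothetical toy function, used only to show which AD consults ChainRules
square(x) = x^2

# A deliberately WRONG reverse rule: it claims d(square)/dx = 3x instead of 2x
function ChainRulesCore.rrule(::typeof(square), x::Real)
    y = square(x)
    square_pullback(ȳ) = (NoTangent(), ȳ * 3x)
    return y, square_pullback
end

# Zygote picks up the rrule, so it reports the rule's (wrong) answer: returns (6.0,)
println(Zygote.gradient(square, 2.0))

# ForwardDiff differentiates x^2 itself and ignores the rrule: returns 4.0
println(ForwardDiff.derivative(square, 2.0))
```

ReverseDiff likewise ignores rrules by default; it only uses one when it is explicitly imported with `ReverseDiff.@grad_from_chainrules`. So when Zygote disagrees with ForwardDiff and ReverseDiff, as in the MWE, a rule somewhere in the call chain (possibly one of KernelFunctions') is a reasonable first place to look.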
