use ForwardDiff.jacobian in place of Zygote.forward_jacobian #1468

Open · wants to merge 1 commit into base: master

Conversation

@vpuri3 commented Oct 26, 2023

Pursuant to #1270
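
For orientation, a rough sketch (not the actual diff) of the forward-over-reverse composition this PR moves Zygote.hessian to internally, with ForwardDiff.jacobian taking the place of Zygote.forward_jacobian; the helper name below is illustrative, not Zygote API:

using ForwardDiff, Zygote

# Illustrative only: the Hessian as the forward-mode Jacobian (ForwardDiff.jacobian)
# of the reverse-mode gradient (Zygote.gradient).
fwd_over_rev_hessian(f, x) = ForwardDiff.jacobian(y -> Zygote.gradient(f, y)[1], x)

f(x) = sum(abs2, x) / 2
fwd_over_rev_hessian(f, rand(3))   # ≈ 3×3 identity matrix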

@vpuri3 (Author) commented Oct 26, 2023

There are some Enzyme-related errors in the NNlib integration tests, but they seem unrelated to this PR.

@vpuri3 (Author) commented Oct 26, 2023

@ToucheSir, LMK if I need to add more tests.

Here's a working MWE with Lux. This also resolves #1348.

With the change in this PR, the following code works:

using Random
using Lux, CUDA, LuxCUDA, ComponentArrays
using Zygote, ForwardDiff

CUDA.allowscalar(false)

#==========================#
function testhessian(
    NN::Lux.AbstractExplicitLayer,
    data::Tuple;
    device = cpu_device(),
)
    p, st = Lux.setup(Random.default_rng(), NN)

    st = Lux.testmode(st)
    p = ComponentArray(p)

    xdata, ydata = data |> device
    p, st = (p, st)     |> device

    function loss(optx)
        ypred, _ = NN(xdata, optx, st)

        sum(abs2, ydata - ypred)
    end

    # forward-over-reverse spelled out by hand:
    # ForwardDiff.jacobian of the Zygote.gradient
    g(p) = Zygote.gradient(loss, p)[1]
    H(p) = ForwardDiff.jacobian(g, p)

    # with this PR, Zygote.hessian uses the same composition internally
    Zygote.hessian(loss, p)
end
#==========================#
NN = Chain(Dense(1, 3), Dense(3, 1))

data = ntuple(_ -> rand(1, 10), 2)
device = Lux.gpu_device()

H = testhessian(NN, data; device)
julia> include("hess.jl")
10×10 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
  0.236781  -0.075257    -1.20583    0.31846   -0.101217    -1.62179   -0.713834    0.503548  -1.14138     1.98508
 -0.075257   0.0239192    0.383253  -0.101217   0.0321702    0.515458   0.0296168  -0.780695   0.362769   -0.630924
 -1.20583    0.383253     6.1408    -1.62179    0.515458     8.2591     0.474545   -2.56436    5.19194   -10.1092
  0.318461  -0.101217    -1.62179    0.514738  -0.163601    -2.62135   -2.09317     0.677249  -1.53511     3.20854
 -0.101217   0.0321702    0.515458  -0.163601   0.0519977    0.833151   0.0398333  -2.18309    0.487909   -1.01978
 -1.62179    0.515458     8.2591    -2.62135    0.833151    13.3494     0.638242   -3.44895    5.84984   -16.3398
 -0.713834   0.0296168    0.474545  -2.09317    0.0398333    0.638242   0.0366717  -0.198167   0.449183   -0.781213
  0.503548  -0.780695    -2.56436    0.677249  -2.18309     -3.44895   -0.198167    1.07086   -2.4273      4.22154
 -1.14138    0.362769     5.19194   -1.53511    0.487909     5.84984    0.449183   -2.4273     5.50193    -9.56889
  1.98508   -0.630924   -10.1092     3.20854   -1.01978    -16.3398    -0.781213    4.22154   -9.56889    20.0
(hess) pkg> st
Status `~/.julia/dev/GeometryLearning.jl/hess/Project.toml`
  [052768ef] CUDA v5.0.0
  [b0b7db55] ComponentArrays v0.15.4
  [f6369f11] ForwardDiff v0.10.36
  [b2108857] Lux v0.5.8
  [d0bbae9a] LuxCUDA v0.3.1
  [e88e6eb3] Zygote v0.6.67 `~/.julia/dev/Zygote`

@ToucheSir (Member) left a comment

I'm a little confused. This looks like the same change as #1270, just with no tests? My comment at #1270 (comment) and @mcabbott's at #1270 (comment) still very much apply, so those need to be addressed.

@aksuhton

Did this ever reach a conclusion? I need to take the Jacobian of a (Lux) model's output with respect to its inputs, and then optimize that object using gradient-descent updates on the (Lux) model parameters. Something like the following:

using Random
using Lux, CUDA, LuxCUDA, ComponentArrays
using Zygote #https://github.com/vpuri3/Zygote.jl/tree/fwd
using ForwardDiff
using LinearAlgebra

CUDA.allowscalar(false)

# mean-squared-error helper (assumed; `mse` is not defined in the original snippet)
mse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

## Setup
L = 5
bs = 3
m = Chain(Dense(L, L), relu, Dense(L, L))
ps, st = Lux.setup(Random.default_rng(), m)
dev = Lux.gpu_device()
ps = ComponentArray(ps) |> dev
x = randn(Float32, L, bs) |> dev
y = randn(Float32, L, bs) |> dev
## Forward
function getpred(x, m, ps, st)
    function getpotential(x)
        return first(m(x, ps, st))
    end
    # diagonal of the input-output Jacobian, reshaped back to the input shape
    pred = reshape(diag(ForwardDiff.jacobian(getpotential, x)), size(x)...)
    return pred
end
pred = getpred(x, m, ps, st)
## Backward
function getgrads(x, y, m, ps, st)
    gs = Zygote.gradient(p -> mse(getpred(x, m, p, st), y), ps)
    return gs
end
gs = getgrads(x, y, m, ps, st) # returns (nothing,)

Or should I be looking towards JAX for this sort of thing? The use case is thermodynamics.

@ToucheSir (Member)

That's a better question for the SciML/Lux help channels, not this issue tracker.

@mcabbott (Member)

This PR changes the implementation used internally for FwdDiff-over-Zygote. It didn't get much attention as it was a little unclear what this solves -- see requests above for tests which fail before the change.

Your example wants to do Zygote-over-ForwardDiff, which won't work, and would not be changed by this PR.

(Zygote has a rule for ForwardDiff.jacobian(f, x) which was probably a bad idea, and translates it to Fwd-over-Fwd. It should complain loudly when f closes over parameters, as it cannot work out the derivative with respect to f.)
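
For readers landing here from the example above, a minimal sketch contrasting the two compositions (plain CPU arrays; illustrative code, not Zygote API, and assuming Zygote's rule for ForwardDiff.jacobian behaves as described above):

using ForwardDiff, Zygote

x  = rand(3)
ps = rand(3)

# ForwardDiff-over-Zygote (what this PR touches): forward-mode Jacobian of a
# reverse-mode gradient, i.e. the Hessian with respect to x.
g(x) = Zygote.gradient(y -> sum(abs2, y), x)[1]
ForwardDiff.jacobian(g, x)                       # ≈ 2I

# Zygote-over-ForwardDiff with the inner closure capturing `p`: Zygote's rule
# for ForwardDiff.jacobian only differentiates with respect to the array
# argument, so the sensitivity to the captured `p` is dropped -- hence the
# `(nothing,)` in the example above.
Zygote.gradient(p -> sum(ForwardDiff.jacobian(y -> p .* y, x)), ps)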

@mcabbott added the "second order" label (zygote over zygote, or otherwise) on Feb 21, 2024
Labels: second order (zygote over zygote, or otherwise)
Projects: none yet
4 participants