gradient broken for (*)(::Diagonal{Real}, ::Matrix{Complex}, ::Diagonal{Real}) when updating Julia 1.8 -> 1.9 #1483

Open
kylebeggs opened this issue Dec 21, 2023 · 6 comments
Labels
bug (Something isn't working), ChainRules (adjoint -> rrule, and further integration)

Comments

@kylebeggs

gradient breaks on the triple product of a Diagonal{<:Real}, a Matrix{<:Complex}, and a Diagonal{<:Real}. It broke when going from Julia 1.8 to 1.9.

MWE:

using LinearAlgebra
using Zygote

D = Diagonal(rand(3))
Ac = rand(ComplexF64, 3, 3)
Ar = rand(Float64, 3, 3)

f_real(x) = abs(sum(Diagonal(x) * Ar * D))
f_complex(x) = abs(sum(Diagonal(x) * Ac * D))

g_real = gradient(f_real, rand(3)) # works
g_complex = gradient(f_complex, rand(3)) # breaks, error message below

Error message:

ERROR: MethodError: no method matching _mul_partials(::ForwardDiff.Partials{3, Float64}, ::ForwardDiff.Partials{6, Float64}, ::Float64, ::Float64)

Closest candidates are:
  _mul_partials(::ForwardDiff.Partials{N, A}, ::ForwardDiff.Partials{0, B}, ::Any, ::Any) where {N, A, B}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/partials.jl:142
  _mul_partials(::ForwardDiff.Partials{0, A}, ::ForwardDiff.Partials{N, B}, ::Any, ::Any) where {N, A, B}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/partials.jl:141
  _mul_partials(::ForwardDiff.Partials{N}, ::ForwardDiff.Partials{N}, ::Any, ::Any) where N
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/partials.jl:118
  ...

Stacktrace:
  [1] dual_definition_retval(::Val{…}, val::Float64, deriv1::Float64, partial1::ForwardDiff.Partials{…}, deriv2::Float64, partial2::ForwardDiff.Partials{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:203
  [2] *
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:271 [inlined]
  [3] *(x::ForwardDiff.Dual{Nothing, Float64, 3}, z::Complex{ForwardDiff.Dual{Nothing, Float64, 6}})
    @ Base ./complex.jl:339
  [4] *
    @ ./operators.jl:587 [inlined]
  [5] (::Zygote.var"#1388#1389"{typeof(*)})(::Float64, ::ComplexF64, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:276
  [6] _broadcast_getindex_evalf
    @ ./broadcast.jl:709 [inlined]
  [7] _broadcast_getindex
    @ ./broadcast.jl:682 [inlined]
  [8] getindex
    @ ./broadcast.jl:636 [inlined]
  [9] copy
    @ ./broadcast.jl:942 [inlined]
 [10] materialize
    @ ./broadcast.jl:903 [inlined]
 [11] broadcast_forward
    @ ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:282 [inlined]
 [12] _broadcast_generic
    @ ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:212 [inlined]
 [13] adjoint
    @ ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:169 [inlined]
 [14] _pullback
    @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
 [15] adjoint
    @ ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:245 [inlined]
 [16] _pullback
    @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
 [17] *
    @ ~/.julia/juliaup/julia-1.10.0-rc3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:409 [inlined]
 [18] _pullback(::Zygote.Context{…}, ::typeof(*), ::Diagonal{…}, ::Matrix{…}, ::Diagonal{…})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
 [19] f_complex
    @ ~/.julia/dev/SMOL/dev/update-julia.jl:9 [inlined]
 [20] _pullback(ctx::Zygote.Context{false}, f::typeof(f_complex), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
 [21] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:44
 [22] pullback
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:42 [inlined]
 [23] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:96
 [24] top-level scope
    @ ~/.julia/dev/SMOL/dev/update-julia.jl:12
Some type information was truncated. Use `show(err)` to see complete types.

Versions:

julia> versioninfo()
Julia Version 1.10.0-rc3
Commit ed79752b939 (2023-12-18 09:57 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 20 × 12th Gen Intel(R) Core(TM) i7-12700H
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, alderlake)
  Threads: 11 on 20 virtual cores
Environment:
  LD_PRELOAD = /usr/lib/x86_64-linux-gnu/libstdc++.so.6
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 8

Package versions:

(jl_w76chw) pkg> st
Status `/tmp/jl_w76chw/Project.toml`
  [e88e6eb3] Zygote v0.6.68
@ToucheSir
Member

Do you mind testing the same function but with ForwardDiff instead of Zygote on 1.8/1.9? Zygote's broadcasting code isn't doing anything different between versions, so I wonder if changes in the stdlib or ForwardDiff are leading to a different path being taken depending on the version.

@kylebeggs
Author

Using ForwardDiff works. FWIW, I quickly checked the code in LinearAlgebra/src/diagonal.jl, and it appears to have changed between 1.8 and 1.9.
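Here is one way to run the ForwardDiff comparison and also confirm that its result is a correct gradient: check it against central finite differences. This sketch reuses the MWE's setup. The step size h, the tolerances, and the helper name basis are arbitrary illustrative choices, not anything from the thread.

```julia
using LinearAlgebra, Random
using ForwardDiff

Random.seed!(0)
D = Diagonal(rand(3))
Ac = rand(ComplexF64, 3, 3)
f_complex(x) = abs(sum(Diagonal(x) * Ac * D))   # same function as in the MWE

x0 = rand(3)
g_fd = ForwardDiff.gradient(f_complex, x0)      # forward mode, no Zygote involved

# Central finite differences along each unit direction (hypothetical helper `basis`).
h = 1e-6
basis(i) = [j == i ? 1.0 : 0.0 for j in 1:3]
g_num = [(f_complex(x0 + h * basis(i)) - f_complex(x0 - h * basis(i))) / (2h) for i in 1:3]

println(isapprox(g_fd, g_num; rtol=1e-5, atol=1e-8))
```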

@ToucheSir
Member

The blame points to JuliaLang/julia#46400. I imagine this was hitting the 2-arg * rule in ChainRules before and thus worked without any extra intervention required. What I don't know is whether this is better solved in ChainRules (with a new rrule) or in Zygote, mainly because I don't understand why this is causing a mismatched number of partials in Zygote's Dual broadcasting path. @mcabbott do you have any thoughts on this?
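As discussed below, the 3-arg method introduced in JuliaLang/julia#46400 computes the product with one broadcast instead of two pairwise products. The sketch below shows the equivalent elementwise computation, Y[i, j] = a[i] * A[i, j] * c[j]. The helper triple_mul_bcast is hypothetical. It is only assumed to mirror the new method, so check LinearAlgebra/src/diagonal.jl for the real code. It shows that the product becomes a single 3-argument broadcast of * mixing real vectors with a complex matrix, which is the pattern Zygote's generic broadcast path then has to differentiate.

```julia
using LinearAlgebra

# Hypothetical helper (not the actual LinearAlgebra source): Da * A * Db
# written as one 3-argument broadcast, Y[i, j] = da[i] * A[i, j] * db[j].
triple_mul_bcast(Da::Diagonal, A::AbstractMatrix, Db::Diagonal) =
    broadcast(*, Da.diag, A, permutedims(Db.diag))

Da = Diagonal([1.0, 2.0, 3.0])                      # real diagonal
Db = Diagonal([0.5, -1.0, 2.0])                     # real diagonal
Ac = [1+2im 0.5im 3.0; -1.0 2-1im 1im; 0.25 4.0 -2im]  # complex matrix

println(triple_mul_bcast(Da, Ac, Db) ≈ (Da * Ac) * Db)  # same values as pairwise *
```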

@kylebeggs
Author

kylebeggs commented Dec 26, 2023

I have the same thought process. Personally, I'd think we should write a new rrule, but I'm a novice with autodiff so take that with a grain of salt.

Edit: Please let me know if there is anything I can do to help speed up a fix, in whatever way you prefer. For now, I'm going to try writing an rrule on a fork to fix it for myself.

@mcabbott
Member

mcabbott commented Dec 26, 2023

I agree that JuliaLang/julia#46400 must be what's new; previously this would have fallen back to pairwise *. But the new broadcasting path looks like it ought to work. Here are some attempts to isolate the problem further:

julia> gradient(x -> sum(abs2, x .* Ac .* x), [0.1, 0.2, 0.3])  # fine?
([0.008354213288778839, 0.04584256955208588, 0.27367634356707043],)

julia> gradient(x -> sum(abs2, broadcast(*, x, Ac, x)), [0.1, 0.2, 0.3])  # same error
ERROR: MethodError: no method matching _mul_partials(::ForwardDiff.Partials{3, Float64}, ::ForwardDiff.Partials{6, Float64}, ::Float64, ::Float64)

julia> using ForwardDiff

julia> ForwardDiff.gradient(x -> sum(abs2, broadcast(*, x, Ac, x)), [0.1, 0.2, 0.3])  # fine, same result
3-element Vector{Float64}:
 0.008354213288778839
 0.04584256955208588
 0.27367634356707043

julia> gradient(x -> sum(abs2, broadcast(*, x, Ar, x)), [0.1, 0.2, 0.3])  # all real is fine, as above
([0.006549471072183613, 0.017877747316265458, 0.05009079087702957],)

julia> gradient(x -> sum(abs2, broadcast(*, x*im, Ac, x*im)), [0.1, 0.2, 0.3])  # all complex also fine?
([0.008354213288778839, 0.04584256955208588, 0.27367634356707043],)

I do think this probably points to Zygote's treatment of broadcasting with complex numbers.

There are special rules for broadcasting *; I believe only this one handles more than 2 arguments. But it is defined on broadcasted, not broadcast. That is certainly called by .*, but maybe it's not called here.

When there is no special broadcasting rule, the generic one here tries to use Dual numbers before giving up and eventually saving all the Zygote scalar pullbacks. The upgrade to use Complex{Dual} was, I think, #1324, and it's possible that this mismatch of 6 vs. 3 partials comes from a bug there. The code is reworked in #1441, which is not merged yet, but it might be worth trying it to see if anything changes.
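The widths in the error (Partials{3} times Partials{6}, both tagged Nothing) suggest a hypothesis. Real arguments may be seeded with N partials while complex arguments get 2N, one slot for the real part and one for the imaginary part. The sketch below reproduces that situation directly with ForwardDiff, without Zygote. The seeding helpers onehot and seed are hypothetical and are only an assumption about what the generic broadcast path does. They are not Zygote's actual code. The last lines show that giving every argument the same width avoids the failure, which is one plausible fix direction, not necessarily what #1441 does.

```julia
import ForwardDiff
using ForwardDiff: Dual

# Hypothetical seeding helpers (an assumption, not Zygote's code):
# real arguments get N partials; complex arguments get 2N partials,
# slot i for the real part and slot N + i for the imaginary part.
onehot(i, n) = ntuple(j -> j == i ? 1.0 : 0.0, n)
seed(x::Real, i, N) = Dual(x, onehot(i, N))
seed(z::Complex, i, N) = Complex(Dual(real(z), onehot(i, 2N)), Dual(imag(z), onehot(N + i, 2N)))

N = 3                        # three broadcast arguments, as in broadcast(*, x, Ac, x)
a = seed(0.1, 1, N)          # Dual{Nothing, Float64, 3}
b = seed(0.5 + 0.2im, 2, N)  # Complex{Dual{Nothing, Float64, 6}}

# Same `Nothing` tag but different widths: the product throws inside
# ForwardDiff, much like the reported MethodError in _mul_partials.
err = try
    a * b
    nothing
catch e
    e
end
println(typeof(err))

# Seeding the real argument with the same 2N width makes the product well defined.
a_wide = Dual(0.1, onehot(1, 2N))
c = a_wide * b
println(ForwardDiff.partials(real(c)))
```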

Attempting to trigger such a broadcasting bug, without *, I get a different error:

julia> Zygote.Numeric
Union{AbstractArray{<:T}, T} where T<:Number

julia> gradient(x -> sum(abs2, broadcast(+, x, Ac, x)), [0.1, 0.2, 0.3])  # 3-arg + is ok
([6.217191396508677, 11.013318686350626, 11.801770970288118],)

julia> gradient(x -> sum(abs2, broadcast((a,b,c) -> (a/b+c), x, Ac, x)), [0.1, 0.2, 0.3])
ERROR: Cannot determine ordering of Dual tags Nothing and Nothing
Stacktrace:
  [1] (a::Type, b::Type)
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:54
  [2] promote_rule(::Type{Dual{Nothing, Float64, 6}}, ::Type{Dual{Nothing, Float64, 3}})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:407

julia> gradient(x -> sum(abs2, broadcast((a,b,c) -> (a/b+c), x, Ar, x)), [0.1, 0.2, 0.3])  # all real
([25.238201775657366, 49.06406347939432, 39.42294666177911],)

julia> gradient(x -> sum(abs2, broadcast((a,b,c) -> (a/b+c), x*im, Ac, x*im)), [0.1, 0.2, 0.3])  # all complex
([2.799189365814652, 15.391840124749505, 6.102148153348823],)

Aside from looking for bugs in Zygote's broadcasting, it would not be crazy to have a rule for this 3-arg * method in ChainRules. Even if the above Dual broadcasting worked, a dedicated rule would probably be more efficient. (Earlier 3-arg * rules were added in JuliaDiff/ChainRules.jl#412.)
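For reference, here is a sketch of what such a rule would compute. It comes from applying the 2-arg product rule twice, in the conjugate-transpose convention ChainRules uses for *, where the cotangents of Y = AB are Ā = Ȳ B^H and B̄ = A^H Ȳ. For Y = ABC:

```math
\bar{A} = \bar{Y}\,(BC)^{\mathsf H}, \qquad
\bar{B} = A^{\mathsf H}\,\bar{Y}\,C^{\mathsf H}, \qquad
\bar{C} = (AB)^{\mathsf H}\,\bar{Y}
```

When A = Diagonal(a) and C = Diagonal(c) with real a and c, projecting onto those arguments keeps only the real part of each diagonal entry:

```math
\bar{a}_i = \operatorname{Re} \sum_j \bar{Y}_{ij}\,\overline{B_{ij}}\,c_j, \qquad
\bar{c}_j = \operatorname{Re} \sum_i \bar{Y}_{ij}\,a_i\,\overline{B_{ij}}
```

Only these diagonals need to be formed, not the full matrix products. This is one reason a dedicated rule could plausibly be cheaper than the generic Dual path.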

@mcabbott added the bug (Something isn't working) and ChainRules (adjoint -> rrule, and further integration) labels on Dec 26, 2023
@kylebeggs
Author

FWIW, this temporary workaround fixed my issue. I'm sure it is not ideal (I'm a beginner at writing rrules), but I can work on polishing it and open a PR in a week or two. I know I should add some @thunks etc. in there.

using LinearAlgebra
using ChainRulesCore

# Wrapper so Zygote hits this rrule instead of the broadcast-based 3-arg * method.
_3arg_mul(A::Diagonal, B::AbstractMatrix{<:Complex}, C::Diagonal) = A * B * C
function ChainRulesCore.rrule(
    ::typeof(_3arg_mul), A::Diagonal, B::AbstractMatrix{<:Complex}, C::Diagonal
)
    project_A = ProjectTo(A)  # projects back to a real Diagonal
    project_B = ProjectTo(B)
    project_C = ProjectTo(C)
    function _3arg_mul_pullback(ȳ)
        ȳ = unthunk(ȳ)
        dA = ȳ * (B * C)'
        dB = A' * ȳ * C'
        dC = (A * B)' * ȳ
        return NoTangent(), project_A(dA), project_B(dB), project_C(dC)
    end
    return A * B * C, _3arg_mul_pullback
end
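To check a workaround like this, you can compare Zygote's gradient, which goes through the custom rrule, with ForwardDiff's, which the thread reports works. The sketch below repeats the wrapper and rule so it runs on its own, and applies them to the MWE. The names f_workaround and dY and the random seed are illustrative choices, not part of the original.

```julia
using LinearAlgebra, Random
using ChainRulesCore, Zygote, ForwardDiff

# The workaround's wrapper and rule, repeated so this block is self-contained.
_3arg_mul(A::Diagonal, B::AbstractMatrix{<:Complex}, C::Diagonal) = A * B * C
function ChainRulesCore.rrule(::typeof(_3arg_mul), A::Diagonal, B::AbstractMatrix{<:Complex}, C::Diagonal)
    project_A, project_B, project_C = ProjectTo(A), ProjectTo(B), ProjectTo(C)
    function _3arg_mul_pullback(ȳ)
        dY = unthunk(ȳ)
        return NoTangent(), project_A(dY * (B * C)'), project_B(A' * dY * C'), project_C((A * B)' * dY)
    end
    return A * B * C, _3arg_mul_pullback
end

Random.seed!(1)
D = Diagonal(rand(3))
Ac = rand(ComplexF64, 3, 3)
f_workaround(x) = abs(sum(_3arg_mul(Diagonal(x), Ac, D)))

x0 = rand(3)
g_zygote = Zygote.gradient(f_workaround, x0)[1]      # reverse mode via the rrule
g_forward = ForwardDiff.gradient(f_workaround, x0)   # forward mode, rule not used
println(g_zygote ≈ g_forward)
```

The check is only as good as the ForwardDiff reference. Agreement on a few random inputs is evidence the rule is right, not a proof.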
