Missing support for muladd when broadcasting with a complex argument #1461

Open
rcalxrc08 opened this issue Oct 7, 2023 · 1 comment
Labels
bug Something isn't working

Comments

@rcalxrc08

Hi all,

I noticed the following failure when I combine complex numbers, muladd, and forward mode (I think forward mode is involved because I am broadcasting a function over a vector).
I am using Julia 1.9.3:

Julia Version 1.9.3
Commit bed2cd540a (2023-08-24 14:43 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 8 × Intel(R) Core(TM) i7-8550U CPU @ 1.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
  Threads: 1 on 8 virtual cores

And Zygote v0.6.64.
The MWE is the following:

using Zygote

# Version without muladd: differentiates without problems.
function f_no_muladd(x, a, b)
    vec_res = @. real(a * exp(x + b * im))
    return sum(vec_res)
end

# Version using muladd(b, im, x), which computes b*im + x.
function f_muladd(x, a, b)
    vec_res = @. real(a * exp(muladd(b, im, x)))
    return sum(vec_res)
end

x = ones(Float64, 10)
a = 1.0
b = 2.0
Zygote.gradient(f_no_muladd, x, a, b)  # completely fine
Zygote.gradient(f_muladd, x, a, b)     # this call fails

The error is actually raised inside ForwardDiff.jl:

ERROR: MethodError: no method matching calc_muladd_xyz(::ForwardDiff.Dual{Nothing, Bool, 6}, ::ForwardDiff.Dual{Nothing, Float64, 3}, ::ForwardDiff.Dual{Nothing, Float64, 3})

Closest candidates are:
  calc_muladd_xyz(::ForwardDiff.Dual{T, <:Any, N}, ::ForwardDiff.Dual{T, <:Any, N}, ::ForwardDiff.Dual{T, <:Any, N}) where {T, N}
   @ ForwardDiff C:\Users\Nicola\.julia\packages\ForwardDiff\PcZ48\src\dual.jl:637

Stacktrace:
  [1] muladd
    @ C:\Users\Nicola\.julia\packages\ForwardDiff\PcZ48\src\dual.jl:155 [inlined]
  [2] muladd(z::Complex{ForwardDiff.Dual{Nothing, Bool, 6}}, x::ForwardDiff.Dual{Nothing, Float64, 3}, y::ForwardDiff.Dual{Nothing, Float64, 3})
    @ Base .\complex.jl:340
  [3] muladd(x::ForwardDiff.Dual{Nothing, Float64, 3}, z::Complex{ForwardDiff.Dual{Nothing, Bool, 6}}, y::ForwardDiff.Dual{Nothing, Float64, 3})
    @ Base .\complex.jl:339
  [4] (::Zygote.var"#1404#1405"{typeof(muladd)})(::Float64, ::Complex{Bool}, ::Float64)
    @ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\broadcast.jl:276
  [5] _broadcast_getindex_evalf
    @ .\broadcast.jl:683 [inlined]
  [6] _broadcast_getindex
    @ .\broadcast.jl:656 [inlined]
  [7] getindex
    @ .\broadcast.jl:610 [inlined]
  [8] copy
    @ .\broadcast.jl:912 [inlined]
  [9] materialize
    @ .\broadcast.jl:873 [inlined]
 [10] broadcast_forward
    @ C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\broadcast.jl:282 [inlined]
 [11] _broadcast_generic
    @ C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\broadcast.jl:212 [inlined]
 [12] adjoint
    @ C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\broadcast.jl:205 [inlined]
 [13] _pullback(::Zygote.Context{false}, ::typeof(Base.Broadcast.broadcasted), ::Base.Broadcast.DefaultArrayStyle{1}, ::typeof(muladd), ::Float64, ::Complex{Bool}, ::Vector{Float64})
    @ Zygote C:\Users\Nicola\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:66
 [14] _apply(::Function, ::Vararg{Any})
    @ Core .\boot.jl:838
 [15] adjoint
    @ C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\lib.jl:203 [inlined]
 [16] _pullback
    @ C:\Users\Nicola\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:66 [inlined]
 [17] _pullback
    @ .\broadcast.jl:1317 [inlined]
 [18] _pullback
    @ .\REPL[3]:2 [inlined]
 [19] _pullback(::Zygote.Context{false}, ::typeof(f_muladd), ::Vector{Float64}, ::Float64, ::Float64)
    @ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\compiler\interface2.jl:0
 [20] pullback(::Function, ::Zygote.Context{false}, ::Vector{Float64}, ::Vararg{Any})
    @ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\compiler\interface.jl:44
 [21] pullback(::Function, ::Vector{Float64}, ::Float64, ::Vararg{Float64})
    @ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\compiler\interface.jl:42
 [22] gradient(::Function, ::Vector{Float64}, ::Vararg{Any})
    @ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\compiler\interface.jl:96
 [23] top-level scope
    @ REPL[8]:1
@mcabbott
Member

mcabbott commented Oct 7, 2023

Notice that the error shows Duals with 6 and 3 partials together, which doesn't make sense for ForwardDiff:

julia> using ForwardDiff: Dual

julia> muladd(Dual(1,2), Dual(3,4), Dual(5,6))
Dual{Nothing}(8,16)

julia> muladd(Dual(1,2,0), Dual(3,4), Dual(5,6))
ERROR: MethodError: no method matching calc_muladd_xyz(::Dual{Nothing, Int64, 2}, ::Dual{Nothing, Int64, 1}, ::Dual{Nothing, Int64, 1})

So the bug is here somehow?

The use of Dual for complex-number broadcasting was added in #1324; it would be worth checking whether #1441 changes anything.
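
For anyone wanting to poke at the partial-count mismatch in isolation, here is a minimal sketch. It assumes the unexported helper `ForwardDiff.npartials` is available in the installed version, and the names `d1`, `d2`, `ok`, and `mixed` are purely illustrative. The mixed call is wrapped in `try`/`catch` because the exact failure may differ between ForwardDiff versions.

```julia
using ForwardDiff: Dual, npartials

# The number of partials N is part of the type: Dual{Tag, V, N}.
d1 = Dual(3.0, 4.0)        # value 3.0 with one partial
d2 = Dual(1.0, 2.0, 0.0)   # value 1.0 with two partials

println(npartials(d1))
println(npartials(d2))

# All three arguments carry the same number of partials, so this dispatches.
ok = muladd(d1, d1, d1)
println(ok)

# Mixed partial counts. On the version reported in this issue this throws a
# MethodError; the error is caught here rather than assumed.
mixed = try
    muladd(d2, d1, d1)
catch err
    err
end
println(typeof(mixed))
```

In the stack trace above, the two Float64 arguments are wrapped with 3 partials each while the `Complex{Bool}` argument gets 6, which is the same kind of mismatch.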

@mcabbott mcabbott added the bug Something isn't working label Oct 7, 2023