Confusing error / silent failure with broadcasted functions with type instability #1439

Open
DomCRose opened this issue Jul 10, 2023 · 3 comments · May be fixed by #1441
Labels
bug Something isn't working

Comments

@DomCRose
Contributor

When a broadcasted function is type-unstable on Dual inputs, there is a good chance that the element type of the output will be abstract, which breaks the logic at

T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)
(thanks to @ToucheSir, who helped debug this). When that check fails, the output is returned with the Duals still inside it, together with a pullback that returns nothing. The Duals then leak into the rest of the pullback, so an error can surface much later than the point where they originated, e.g. when the gradient of the output is pulled back onto the gradient of an input whose eltype is assumed to be non-Dual, which makes it confusing to debug. Perhaps even worse, in some cases the gradient fails silently, returning either nothing or Duals.
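To make the mechanism concrete without relying on ForwardDiff internals, here is a self-contained sketch. MiniDual is a hypothetical stand-in for ForwardDiff.Dual (one value, one partial) defined only for this illustration. Broadcasting the type-unstable f over it produces a mix of MiniDual and Float64 results, Base widens the container to their common supertype Real, and an eltype check of the same shape as the one quoted above then evaluates to false.

```julia
# MiniDual is a hypothetical stand-in for ForwardDiff.Dual (value + one partial),
# defined only for this illustration.
struct MiniDual <: Real
    val::Float64
    der::Float64
end

# Just enough arithmetic to evaluate f on a MiniDual.
Base.:>(a::MiniDual, b::Real) = a.val > b
Base.:^(a::MiniDual, n::Integer) = MiniDual(a.val^n, n * a.val^(n - 1) * a.der)

# Same type-unstable f as in the MWE: Float64 in one branch, MiniDual in the other.
f(x) = x > 1.0 ? 1.0 : x^2

xs = MiniDual.([0.5, 0.75, 1.0, 1.25, 1.5], 1.0)  # seed every partial with 1.0
out = f.(xs)

# Broadcast sees both MiniDual and Float64 results and widens to their typejoin.
println(eltype(out))  # Real

# A check shaped like `T <: Union{Dual, Complex{<:Dual}}` is false for Real,
# so the Dual-handling branch would be skipped.
println(eltype(out) <: Union{MiniDual, Complex{<:MiniDual}})  # false
```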

An MWE of the silent failure on Julia 1.9.0, in a temporary environment with only Zygote:

using Zygote
f(x) = x > 1.0 ? 1.0 : x^2  # the Float64 literal makes f type-unstable on Dual inputs
g(x) = sum(f.(x))
gradient(g, collect(0.5:0.25:1.5)) # (nothing,)

Compare this with the expected behaviour when f returns one(x), which is type-stable:

using Zygote
f(x) = x > 1.0 ? one(x) : x^2
g(x) = sum(f.(x))
gradient(g, collect(0.5:0.25:1.5)) # ([1.0, 1.5, 2.0, 0.0, 0.0],)

An MWE of the error, using repeat with the inner keyword as an example of an operation whose pullback does not let the Duals pass through silently:

using Zygote
f(x) = x > 1.0 ? 1.0 : x^2
g(x) = sum(repeat(x, inner=2) .* f.(repeat(x, inner=2)))
gradient(g, collect(0.5:0.25:1.5))

results in:

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing, Float64, 1})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
   @ Base char.jl:50
  ...

Stacktrace:
 [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{Nothing, Float64, 1})
   @ Base .\number.jl:7
 [2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{Nothing, Float64, 1}, i1::Int64)
   @ Base .\array.jl:969
 [3] (::Zygote.var"#626#634"{Int64, Vector{Float64}})(Δ::Vector{ForwardDiff.Dual{Nothing, Float64, 1}})
   @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\lib\array.jl:137
 [4] (::Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}})(Δ::Vector{ForwardDiff.Dual{Nothing, Float64, 1}})
   @ Zygote C:\Users\domin\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:80
 [5] Pullback
   @ .\REPL[29]:1 [inlined]
 [6] (::Zygote.Pullback{Tuple{typeof(g), Vector{Float64}}, Tuple{Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Vector{Real}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Real}}, Tuple{}}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Vector{Float64}, Vector{Real}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(f), Vector{Float64}}, Tuple{Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4123#back#1356"{Zygote.var"#1386#1388"}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2881#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}, Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1162"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2017#back#200"{typeof(identity)}}}}})(Δ::ForwardDiff.Dual{Nothing, Float64, 1})
   @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface2.jl:0
 [7] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(g), Vector{Float64}}, Tuple{Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Vector{Real}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Real}}, Tuple{}}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Vector{Float64}, Vector{Real}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(f), Vector{Float64}}, Tuple{Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4123#back#1356"{Zygote.var"#1386#1388"}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2881#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}, Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1162"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2017#back#200"{typeof(identity)}}}}}})(Δ::ForwardDiff.Dual{Nothing, Float64, 1})
   @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface.jl:45
 [8] gradient(f::Function, args::Vector{Float64})
   @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface.jl:97
 [9] top-level scope
   @ REPL[30]:1

This leaves it unclear where the Duals originate, since the forward pass succeeds but returns a Dual where a Float64 is expected:

julia> pullback(g, collect(0.5:0.25:1.5))
(Dual{Nothing}(8.59375,7.25), Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(g), Vector{Float64}}, Tuple{Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Vector{Real}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Real}}, Tuple{}}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Vector{Float64}, Vector{Real}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(f), Vector{Float64}}, Tuple{Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4123#back#1356"{Zygote.var"#1386#1388"}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2881#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}, Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1162"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2017#back#200"{typeof(identity)}}}}}}(∂(g)))

In the long run it would be better to fix this properly. In the short term, however, simply adding an error before

T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)
when the element type is abstract, warning that the function needs to be made type-stable on Dual inputs, would at least make this much easier to debug. Happy to do a PR adding that.

@DomCRose DomCRose changed the title Confusing failure with broadcasted functions with type instability Confusing error / silent failure with broadcasted functions with type instability Jul 10, 2023
@mcabbott
Member

An error would be better than the present state, e.g. isconcretetype(T) || error(...)

When T is abstract, could it just assume that there are Duals in there? If not, could it construct an array of zeros instead of nothing?
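A minimal sketch of the isconcretetype guard suggested above, written as a standalone function. check_broadcast_eltype is a hypothetical name, not part of Zygote, and exactly where such a check would sit in Zygote's broadcast code is an assumption here.

```julia
# Hypothetical helper (not part of Zygote): raise a descriptive error when a
# broadcast over Dual inputs produced an abstract element type, instead of
# silently falling through to the non-Dual branch.
function check_broadcast_eltype(out::AbstractArray)
    T = eltype(out)
    isconcretetype(T) || error(
        "broadcasted function produced abstract eltype $T; ",
        "make it type-stable on Dual inputs")
    return T
end

check_broadcast_eltype([1.0, 2.0])      # returns Float64
# check_broadcast_eltype(Real[1.0, 2])  # would throw an ErrorException
```

This rejects every abstract eltype, not only mixed Dual/non-Dual ones, which is the conservative short-term behaviour being discussed.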

@mcabbott mcabbott added the bug Something isn't working label Jul 10, 2023
@DomCRose
Contributor Author

DomCRose commented Jul 10, 2023

An array of zeros doesn't seem quite right: if I understand correctly, in the first MWE above that would lead to incorrect zero gradients.

Assuming Duals seems like it might work, since calling partials on a plain real or complex number simply returns 0.0 anyway, although it might require reworking the branching on complex inputs. However, I don't think the compiler will optimise this away when no element is a Dual, so the faster branch should be kept for cases where the compiler can confirm that the eltype isn't Dual.
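A sketch of the "assume Duals" idea under stated assumptions: MiniDual is again a hypothetical stand-in for ForwardDiff.Dual, and value_of/partial_of are hypothetical element-wise accessors that treat a plain real as carrying a zero partial, mirroring the observation above about partials on non-Duals. Applied to the first MWE's f, reading partials element by element recovers the expected gradient even though the output container is a Vector{Real}.

```julia
# MiniDual is a hypothetical stand-in for ForwardDiff.Dual (value + one partial).
struct MiniDual <: Real
    val::Float64
    der::Float64
end
Base.:>(a::MiniDual, b::Real) = a.val > b
Base.:^(a::MiniDual, n::Integer) = MiniDual(a.val^n, n * a.val^(n - 1) * a.der)

# Hypothetical element-wise accessors: a plain real carries no derivative
# information, so its value is itself and its partial is zero.
value_of(x::MiniDual) = x.val
value_of(x::Real) = x
partial_of(x::MiniDual) = x.der
partial_of(x::Real) = zero(x)

f(x) = x > 1.0 ? 1.0 : x^2  # type-unstable f from the first MWE

xs = [0.5, 0.75, 1.0, 1.25, 1.5]
out = f.(MiniDual.(xs, 1.0))        # Vector{Real}: mixes MiniDual and Float64
y = value_of.(out)                  # primal values
ybar = ones(length(out))            # cotangent of sum(f.(x))
xbar = ybar .* partial_of.(out)     # [1.0, 1.5, 2.0, 0.0, 0.0]
```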

Perhaps the dispatch on complex outputs could be moved inside the _broadcast_forward and _broadcast_forward_complex loops, using another internal function? For example, on this line

unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1,i), ȳ, out))

split on complex o1 and instead do
unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*partials(real(o1),i) + imag(y1)*partials(imag(o1), i)), ȳ, out))

when required, so that it is dispatched element-wise. Presumably that would produce the same code when the eltype is uniform?
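A sketch of the element-wise dispatch proposed above, under the same assumptions: MiniDual and partial_of are hypothetical stand-ins, and pull_elem is a hypothetical helper rather than Zygote's code. Each output element selects the real or complex formula by dispatch, so a container mixing real Duals, plain complex numbers and complex Duals is handled entry by entry.

```julia
# MiniDual is a hypothetical stand-in for ForwardDiff.Dual (value + one partial).
struct MiniDual <: Real
    val::Float64
    der::Float64
end
partial_of(x::MiniDual) = x.der
partial_of(x::Real) = zero(x)  # plain reals contribute no partial

# Hypothetical helper: choose the real or complex formula per element.
pull_elem(y1, o1::Real) = y1 * partial_of(o1)
pull_elem(y1, o1::Complex) =
    real(y1) * partial_of(real(o1)) + imag(y1) * partial_of(imag(o1))

# A deliberately mixed output: a real Dual, a plain complex number, a complex Dual.
out = Number[MiniDual(1.0, 2.0), 2.0 + 3.0im, Complex(MiniDual(0.0, 5.0), MiniDual(1.0, 7.0))]
ybar = Number[1.0, 1.0 + 1.0im, 1.0 + 1.0im]
grads = pull_elem.(ybar, out)  # [2.0, 0.0, 12.0]
```

When every element of out has the same concrete type, the pull_elem method is known at compile time, which is presumably why the uniform-eltype case would compile to the same code as before; this sketch does not verify that.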

@DomCRose
Contributor Author

Small update: I think I have a fix for this written; I just need to add tests.

@DomCRose DomCRose linked a pull request Jul 13, 2023 that will close this issue