nothing in output of a pullback #1464

Open
simsurace opened this issue Oct 12, 2023 · 2 comments

@simsurace
Contributor

I'm opening this issue to track a problem that has been discussed across several repositories: nothing appearing in the output of a pullback.
According to @ToucheSir, this should not happen: JuliaStats/Distances.jl#256 (comment)
The problem was exposed by the ChainRules 1.53.0 update.
Simple reproducer:

julia> using Distances, Zygote

julia> x = rand(10);

julia> f(x) = iszero(x) ? zero(x) : x;

julia> Zygote.gradient(_x -> sum(f, pairwise(Euclidean(), reshape(_x, :, 1); dims=1)), x)
ERROR: MethodError: no method matching *(::Nothing, ::Float64)

Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...)
   @ Base operators.jl:578
  *(::T, ::T) where T<:Union{Float16, Float32, Float64}
   @ Base float.jl:410
  *(::StridedArray{P}, ::Real) where P<:Dates.Period
   @ Dates ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/Dates/src/deprecated.jl:44
  ...

Stacktrace:
  [1] (::Zygote.var"#1412#1416"{Int64})(y1::Nothing, o1::ForwardDiff.Dual{Nothing, Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/lib/broadcast.jl:298
  [2] _broadcast_getindex_evalf
    @ ./broadcast.jl:683 [inlined]
  [3] _broadcast_getindex
    @ ./broadcast.jl:656 [inlined]
  [4] getindex
    @ ./broadcast.jl:610 [inlined]
  [5] macro expansion
    @ ./broadcast.jl:974 [inlined]
  [6] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [7] copyto!
    @ ./broadcast.jl:973 [inlined]
  [8] copyto!
    @ ./broadcast.jl:926 [inlined]
  [9] copy
    @ ./broadcast.jl:898 [inlined]
 [10] materialize
    @ ./broadcast.jl:873 [inlined]
 [11] broadcast(::Zygote.var"#1412#1416"{Int64}, ::Matrix{Union{Nothing, Float64}}, ::Matrix{ForwardDiff.Dual{Nothing, Float64, 2}})
    @ Base.Broadcast ./broadcast.jl:811
 [12] #1411
    @ ~/.julia/packages/Zygote/XJ8pP/src/lib/broadcast.jl:298 [inlined]
 [13] ntuple
    @ ./ntuple.jl:49 [inlined]
 [14] bc_fwd_back
    @ ~/.julia/packages/Zygote/XJ8pP/src/lib/broadcast.jl:297 [inlined]
 [15] #4155#back
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
 [16] #291
    @ ~/.julia/packages/Zygote/XJ8pP/src/lib/lib.jl:206 [inlined]
 [17] #2173#back
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
 [18] Pullback
    @ ./broadcast.jl:1317 [inlined]
 [19] Pullback
    @ ~/.julia/packages/Zygote/XJ8pP/ext/ZygoteDistancesExt.jl:104 [inlined]
 [20] (::ZygoteDistancesExt.var"#pairwise_Euclidean_pullback#52"{Zygote.Pullback{…}})(Δ::Matrix{Union{Nothing, Float64}})
    @ ZygoteDistancesExt ~/.julia/packages/Zygote/XJ8pP/ext/ZygoteDistancesExt.jl:107
 [21] Pullback
    @ ./REPL[86]:1 [inlined]
 [22] (::Zygote.Pullback{Tuple{var"#89#90", Vector{Float64}}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [23] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#89#90", Vector{Float64}}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface.jl:45
 [24] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface.jl:97
 [25] top-level scope
    @ REPL[86]:1
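The trace shows the pairwise pullback in ZygoteDistancesExt receiving Δ::Matrix{Union{Nothing, Float64}}. Presumably the nothing entries come from elements where f takes the zero(x) branch, such as the zero diagonal of the 10×10 distance matrix. The failure in frame [1] can be reproduced without Zygote in a minimal sketch; the small matrices below are made-up stand-ins, not values taken from the reproducer.

```julia
# Made-up cotangent mixing `nothing` (a zero gradient) with numbers,
# standing in for the Matrix{Union{Nothing, Float64}} seen in the trace.
Δ = Union{Nothing, Float64}[nothing 1.0; 1.0 nothing]

# Made-up values standing in for partials(o1, i).
p = [2.0 3.0; 4.0 5.0]

# Elementwise multiplication, as in the unpatched bc_fwd_back closure.
# `nothing * 2.0` has no method, so this throws a MethodError.
failed = try
    Δ .* p
    false
catch err
    err isa MethodError
end

println(failed)  # true
```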
@devmotion
Collaborator

The problem can be fixed by the following change to src/lib/broadcast.jl:
@@ -295,9 +295,14 @@ end
   y = broadcast(x -> value(x), out)
   function bc_fwd_back(ȳ)
     dargs = ntuple(valN) do i
-      unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1,i), ȳ, out))
+      unbroadcast(args[i], broadcast((y1, o1) -> y1 === nothing ? nothing : y1 * partials(o1,i), ȳ, out))
+    end
+    # Collapse all `nothing`
+    if dargs isa Tuple{Vararg{Nothing}}
+        return nothing
+    else
+        (nothing, nothing, dargs...) # nothings for broadcasted & f
     end
-    (nothing, nothing, dargs...) # nothings for broadcasted & f
   end
   return y, bc_fwd_back
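The guarded closure in the patch can be checked in isolation. This is a minimal sketch with a hypothetical helper name, guarded_mul, and made-up matrices; it only illustrates the elementwise step, not the unbroadcast summation that follows it.

```julia
# Hypothetical helper mirroring the patched closure:
# propagate `nothing` (a zero gradient) instead of multiplying it.
guarded_mul(y1, p) = y1 === nothing ? nothing : y1 * p

Δ = Union{Nothing, Float64}[nothing 1.0; 1.0 nothing]
p = [2.0 3.0; 4.0 5.0]

# Entries that were `nothing` stay `nothing`; numeric entries are multiplied.
g = guarded_mul.(Δ, p)
println(g)
```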

Similar fixes could (should?) be applied to many more functions that currently assume the input to the pullback is entirely numeric. That assumption does not hold in cases such as the example above, where the input is an array in which some elements are nothing and others are numbers. IIRC, Zygote catches inputs that are entirely nothing at a higher level, but that seems difficult (impossible?) for such mixed and nested types.

Unfortunately, one additional fix is still required: the summation of the broadcast results in unbroadcast breaks because of the nothing entries. I have to leave for lunch now and will continue debugging later (if the issue is not solved by then 🙂).
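The summation failure can be illustrated the same way: summing nothing with a number has no method. One possible workaround, sketched below, replaces nothing with a zero before summing, since Zygote treats nothing as a zero gradient. The name sum_nothing_as_zero is hypothetical, and this is not necessarily how unbroadcast would be fixed in Zygote.

```julia
A = Union{Nothing, Float64}[nothing 1.0; 1.0 nothing]

# Plain summation fails: `nothing + 1.0` has no method.
sum_failed = try
    sum(A)
    false
catch err
    err isa MethodError
end

# Hypothetical helper: replace each `nothing` with the zero `z`
# (via Base.something), then sum over `dims` as usual.
sum_nothing_as_zero(A; dims = :, z = 0.0) = sum(something.(A, z); dims = dims)

println(sum_failed)                       # true
println(sum_nothing_as_zero(A))           # 2.0
println(sum_nothing_as_zero(A; dims = 1))
```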

@ToucheSir
Member

One idea for a more general solution would be to add an overload in ZygoteRules here that collapses Tuple{Vararg{Nothing}} into a single nothing. Any rule declared with @adjoint should then inherit this behaviour.
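The collapsing behaviour can be sketched as a plain function wrapper. Both collapse_nothings and wrap_pullback are hypothetical names; the sketch does not touch ZygoteRules itself and only shows why applying the collapse in one shared place would let every rule inherit it.

```julia
# Collapse a gradient tuple whose entries are all `nothing` into a single `nothing`,
# mirroring the `dargs isa Tuple{Vararg{Nothing}}` check in devmotion's patch.
collapse_nothings(g::Tuple{Vararg{Nothing}}) = nothing
collapse_nothings(g) = g

# Hypothetical stand-in for a single place where every pullback is wrapped,
# so each rule gets the collapsing behaviour without changing its own code.
wrap_pullback(back) = Δ -> collapse_nothings(back(Δ))

all_nothing = wrap_pullback(Δ -> (nothing, nothing))
mixed       = wrap_pullback(Δ -> (nothing, 2Δ))

println(all_nothing(1.0) === nothing)  # true
println(mixed(1.0))                    # (nothing, 2.0)
```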
