mapreduce with accumulation inside is broken #1448

Open
Red-Portal opened this issue Aug 17, 2023 · 6 comments
Labels: bug (Something isn't working), ChainRules (adjoint -> rrule, and further integration), compiler

Comments

@Red-Portal

Red-Portal commented Aug 17, 2023

Hi, the following use of `mapreduce` doesn't work:

```julia
using Zygote

gradient(randn(10)) do x
    y₀ = Float64[]
    ∑x = 0.0
    ys = mapreduce(vcat, x, 1:length(x); init = y₀) do xᵢ, r
        yᵢ = xᵢ.^2
        ∑x += xᵢ
        [yᵢ]
    end
    sum(ys) + ∑x
end
```

It seems the `∑x += xᵢ` part is at fault here, since the call fails with or without `init`:

```
(vcat, [[0.7008872619503351], [0.057800842475147274], [0.4508806424034738], [6.360461041381114], [8.229642138382558e-5], [0.43781177206525196], [1.7425577575168238], [0.8947064561514089], [0.678655434187004], [0.10421486484899199]])
(init = Float64[],)
ERROR: MethodError: no method matching iterate(::Nothing)

Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen})
   @ Base range.jl:880
  iterate(::Union{LinRange, StepRangeLen}, ::Integer)
   @ Base range.jl:880
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}}
   @ Base dict.jl:698
  ...

Stacktrace:
  [1] indexed_iterate(I::Nothing, i::Int64)
    @ Base ./tuple.jl:91
  [2] chain_rrule_kw(::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::Function, ::NamedTuple{(:init,), Tuple{Vector{Float64}}}, ::Function, ::Function, ::Vararg{Any})
    @ Zygote ./REPL[7]:5
  [3] macro expansion
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:init,), Tuple{Vector{Float64}}}, ::typeof(reduce), ::typeof(vcat), ::Vector{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101
  [5] _pullback
    @ ./reducedim.jl:359 [inlined]
  [6] _pullback(::Zygote.Context{false}, ::Base.var"##mapreduce#801", ::Base.Pairs{Symbol, Vector{Float64}, Tuple{Symbol}, NamedTuple{(:init,), Tuple{Vector{Float64}}}}, ::typeof(mapreduce), ::var"#24#26", ::typeof(vcat), ::Vector{Float64}, ::UnitRange{Int64})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
  [7] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
  [8] adjoint
    @ ~/.julia/packages/Zygote/4rucm/src/lib/lib.jl:203 [inlined]
  [9] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [10] _pullback
    @ ./reducedim.jl:359 [inlined]
 [11] _pullback
    @ ./REPL[8]:4 [inlined]
 [12] _pullback(ctx::Zygote.Context{false}, f::var"#23#25", args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
 [13] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:44
 [14] pullback
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:42 [inlined]
 [15] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:96
 [16] top-level scope
    @ REPL[8]:1
```

This used to work and broke at some point. Is this an rrule problem? The same code currently works without problems with ReverseDiff and ForwardDiff.
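For reference, with the mutation unrolled, the closure above computes `sum(x.^2) + sum(x)` (`ys` collects each `xᵢ^2` and `∑x` ends up as `sum(x)`), so a working AD backend should return `2x .+ 1`. The minimal Base-only sketch below checks that value with central finite differences. `f_ref` and `fd_gradient` are hypothetical helpers for illustration, not part of Zygote, ReverseDiff, or ForwardDiff.

```julia
# Reference function: the same computation as the closure, without mutation
# or mapreduce. ys would be x.^2 and ∑x would be sum(x).
f_ref(x) = sum(x .^ 2) + sum(x)

# Hypothetical helper: central finite-difference gradient, one coordinate at a time.
function fd_gradient(f, x::Vector{Float64}; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = randn(10)
g_fd = fd_gradient(f_ref, x)
g_exact = 2 .* x .+ 1          # analytic gradient of sum(x.^2) + sum(x)
println(maximum(abs.(g_fd .- g_exact)))
```

Because the function is quadratic, the central difference matches `2x .+ 1` up to rounding error, which gives a target to compare any backend's output against.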

@Red-Portal
Author

Related issues in Bijectors.jl and Turing.jl

@ToucheSir
Member

That code path has failed in the past because of ambiguities in which rrules might apply to a given call. I'm not sure that is the culprit here, however. I believe the problem is that ChainRules does not have an rrule for `reduce(vcat, ...; init=...)`, yet the `has_chain_rrule` detection logic somehow reports that it does.
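The top stack frame, `indexed_iterate(I::Nothing, ...)` called from `chain_rrule_kw`, fits this explanation. ChainRulesCore's convention is that `rrule` returns `nothing` when no rule applies, and destructuring `nothing` as `(y, pullback)` raises exactly the `MethodError` shown. The Base-only sketch below reproduces that failure mode with a hypothetical `lookup_rule` that mimics the convention. It is not Zygote's actual code. It also shows a guarded caller that checks for `nothing` before destructuring.

```julia
# Hypothetical stand-in for a rule lookup: returns (value, pullback) when a rule
# "exists" and `nothing` otherwise, mimicking the ChainRulesCore convention.
lookup_rule(has_rule::Bool, x) = has_rule ? (x^2, Δ -> 2x * Δ) : nothing

# Unguarded destructuring: fine when a rule exists, but `nothing` is not
# iterable, so `y, back = nothing` throws a MethodError (iterate(::Nothing)).
function unguarded(has_rule, x)
    y, back = lookup_rule(has_rule, x)
    return (y, back(1.0))
end

# Guarded version: check for `nothing` first and fall back
# (here, hypothetically, to a hand-written derivative of x^2).
function guarded(has_rule, x)
    res = lookup_rule(has_rule, x)
    res === nothing && return (x^2, 2x)
    y, back = res
    return (y, back(1.0))
end

err = try
    unguarded(false, 3.0)
    nothing
catch e
    e
end
println(err isa MethodError)   # true
println(guarded(false, 3.0))   # (9.0, 6.0)
println(guarded(true, 3.0))    # (9.0, 6.0)
```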

ToucheSir added the bug label Aug 18, 2023
@mcabbott
Member

`mapreduce(f, vcat, x, 1:length(x); init = y₀)` could probably be plumbed to `reduce(vcat, foldl(f, x, 1:length(x); init = y₀))`. Perhaps that would be one way to work around this.

Note also that `reduce(vcat, xs; init)` and `mapreduce(f, vcat, xs)` are always pairwise: they never hit the specialized fast path of `reduce(vcat, xs)`.
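One way to read this suggestion is sketched below, assuming the per-element outputs can be computed independently. Build the pieces with `map`, concatenate them with a single `reduce(vcat, pieces)` (the form Base specializes for a vector of vectors), and compute the accumulated sum outside the closure instead of mutating `∑x`. Judging from the `reducedim.jl` frames in the trace, Base's multi-array `mapreduce` already forwards to `reduce(vcat, map(...); init)`, so the key difference is dropping `init` and the mutation. This uses `map` in place of the `foldl` mentioned above, and `body` is a hypothetical name for the original `do` block. Only Base is used, so the sketch checks forward values, not the Zygote gradient.

```julia
x = [0.5, -1.0, 2.0, 3.0]

# Hypothetical name for the original do-block body, minus the mutation of ∑x.
body(xᵢ, r) = [xᵢ^2]

# Original formulation: mapreduce with init, plus a mutated accumulator.
function with_mutation(x)
    ∑x = 0.0
    ys = mapreduce(vcat, x, 1:length(x); init = Float64[]) do xᵢ, r
        ∑x += xᵢ
        body(xᵢ, r)
    end
    return sum(ys) + ∑x
end

# Rewrite: collect the pieces, concatenate once without init, sum x separately.
function without_mutation(x)
    pieces = map(body, x, 1:length(x))   # Vector{Vector{Float64}}
    ys = reduce(vcat, pieces)            # single call on a vector of vectors
    return sum(ys) + sum(x)
end

println(with_mutation(x))      # 18.75
println(without_mutation(x))   # 18.75
```

Whether Zygote differentiates the rewrite cleanly is not established by this sketch. It only shows that the two formulations compute the same value.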

@Red-Portal
Author

@torfjelde Is there a reason we compute the first element first and then use that to initialize `mapreduce` in `Stacked`?

@torfjelde
Contributor

> Is there a reason we compute the first element first and then use that to initialize `mapreduce` in `Stacked`?

Type-stability issues, in particular when combined with AD. We very often ran into instabilities without `init`, so I believe this was a way to work around them (type stability is quite crucial here, in particular with Zygote).
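A hypothetical sketch of the pattern being asked about follows: the first piece is computed on its own and passed as `init`. It is written in plain Base Julia. `fs`, `ranges`, and `apply_each` are illustrative names, not Bijectors.jl's actual `Stacked` implementation. Seeding `init` with a real output means the starting value's type automatically matches whatever the pieces return, and the single-piece case simply returns that first output.

```julia
# Hypothetical per-block transformations and index ranges (illustrative only,
# not Bijectors.jl's actual Stacked code).
fs = (v -> 2 .* v, v -> exp.(v), v -> v .+ 1)
ranges = (1:2, 3:3, 4:5)

# Compute the first piece on its own and pass it as `init`, so the reduction
# starts from an actual output rather than a separately constructed empty array.
function apply_each(fs, ranges, x)
    y₁ = fs[1](x[ranges[1]])
    rest = zip(fs[2:end], ranges[2:end])
    return mapreduce(vcat, rest; init = y₁) do (g, r)
        g(x[r])
    end
end

x = [1.0, 2.0, 0.0, 3.0, 4.0]
println(apply_each(fs, ranges, x))            # [2.0, 4.0, 1.0, 4.0, 5.0]
println(apply_each(fs[1:1], ranges[1:1], x))  # [2.0, 4.0]: only the init piece
```

Note that, per the reply below, this still goes through a keyword call, so under Zygote it remains subject to the kwargs caveat.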

@ToucheSir
Member

About type stability: note that any call to a method with kwargs (whether or not they're provided in the call) will be type-unstable unless there's an rrule defined for that particular method. In this case there is not.
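The `Core.kwcall` frame in the stack trace reflects how keyword calls are lowered in Julia 1.9 and later. A call `f(args...; kws...)` becomes `Core.kwcall(kws, f, args...)`, with the keywords packed into a `NamedTuple` passed first. That is the call Zygote is differentiating in frame [4], and it is what a rule has to match for `reduce(vcat, ...; init)`. The minimal sketch below shows that the two spellings are the same call. It assumes Julia 1.9 or later, and the data are arbitrary.

```julia
# Assumes Julia ≥ 1.9, where keyword calls lower to Core.kwcall.
@assert VERSION >= v"1.9"

xs = [[1.0, 2.0], [3.0]]

# Ordinary keyword-argument call ...
a = reduce(vcat, xs; init = Float64[])

# ... and the lowered form: keywords packed into a NamedTuple, passed first.
b = Core.kwcall((init = Float64[],), reduce, vcat, xs)

println(a == b)   # true
println(a)        # [1.0, 2.0, 3.0]
```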

4 participants