
slow/high allocation gradient with mapreduce and iterators #1487

Open
tcovert opened this issue Dec 31, 2023 · 11 comments

tcovert commented Dec 31, 2023

I've found that when I compute the gradient of a mapreduce expression whose inputs are slices, Zygote generates fast, low-allocation gradients. However, Zygote's gradient for the equivalent expression with iterator inputs is much slower and allocates far more. In the MWE below, every iterator result has the same (homogeneous) size, so in principle the code to run should be the same. However, I encountered this issue while planning code in which the iterators would generate results of heterogeneous sizes, where my seemingly fast slice/reshape strategy would no longer be feasible. Is this unavoidable? FWIW, ForwardDiff's gradients for the two approaches are equally fast and have similar allocation counts.

Is it possible this is related to #304?

Thanks in advance for any suggestions the Zygote community has here.

Here is an MWE:

using LogExpFunctions, SplitApplyCombine, LinearAlgebra, Zygote

# this does the main work on each observation (one observation of a multinomial logit negative log-likelihood)
function f(y_i, delta_i)
  return -1.0 * (dot(y_i, delta_i) - logsumexp(delta_i))
end

# this one has a fast/low allocation gradient
function g(y, X, theta, gs)
  deltas = X * theta

  nn, k = size(X)
  npeople = length(gs)
  nchoices = convert(Int64, nn / npeople)

  rdeltas = eachslice(reshape(deltas, nchoices, npeople), dims = 2)
  ry = eachslice(reshape(y, nchoices, npeople), dims = 2)

  return mapreduce(f, +, ry, rdeltas)
end

# this one has a slow/high allocation gradient
function h(y, X, theta, gs)
  deltas = X * theta

  rdeltas = (view(deltas, gidx) for gidx in gs)
  ry = (view(y, gidx) for gidx in gs)

  return mapreduce(f, +, ry, rdeltas)
end

# make test data and see if this works
nchoices = 100
npeople = 1000
ncols = 10
X = randn(nchoices * npeople, ncols)
theta = randn(size(X, 2))
ys = rand(1:nchoices, npeople)

# assign a single random choice to each individual
y = zeros(nchoices * npeople)
for j = 1:npeople
  y[(j-1) * nchoices + ys[j]] = 1.0
end

groups = collect(groupinds(repeat(1:npeople, inner = nchoices)));

# define the closures, verify they are equivalent
g0(x) = g(y, X, x, groups)
h0(x) = h(y, X, x, groups)

@time g0(theta)
@time g0(theta) # 16 allocations, ~790 KiB
@time h0(theta)
@time h0(theta) # 3 allocations, ~781 KiB

# define the gradients
gg0(x) = gradient(g0, x)[1]
gh0(x) = gradient(h0, x)[1]

@time gg0(theta)
@time gg0(theta) # ~0.01 seconds, 34k allocations, ~22 MiB

@time gh0(theta)
@time gh0(theta) # ~1.74 seconds, 395k allocations, ~3 GiB
@ToucheSir (Member)

Diffing through non-array iterators is going to be tough in general, because the interface itself is rather constrained yet allows almost arbitrary behaviour from implementations. This is less of a problem for forward-mode ADs like ForwardDiff because they essentially run the same function twice, but per the name reverse-mode ADs have to "reverse" all the operations and that can be tricky to do (let alone performantly). Zygote in particular is not well-equipped here because it works on unoptimized IR, whereas normal iteration code heavily relies on optimizations like inlining to have good performance. We've specialized operations for certain types like arrays because they're a known quantity, but the language doesn't provide us with many tools to do the same for looser types like Generators or user-defined iterators.


tcovert commented Dec 31, 2023

Thanks for the quick response, and for helping me understand the issue! I'll try to figure out how to reformulate my problem without iterators then...

@ToucheSir (Member)

I think if you're sticking with Zygote, the current approach you have of using operations with eager materialization/implicit vectorization is the way to go. Depending on the use case, it may also be worth trying out other ADs since they'll have different performance characteristics.
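To make "trying out other ADs" concrete, here is a self-contained sketch of a taped, compiled ReverseDiff gradient (assuming ReverseDiff.jl is installed; `obj` is a toy objective standing in for the `h0` closure from the MWE, not the real likelihood):

```julia
using ReverseDiff

# toy objective standing in for the h0 closure from the MWE
obj(x) = sum(abs2, x)

x0 = randn(10)

# record the tape once, compile it, then reuse it for every gradient call
tape = ReverseDiff.compile(ReverseDiff.GradientTape(obj, x0))
grad = similar(x0)
ReverseDiff.gradient!(grad, tape, x0)   # grad ≈ 2 .* x0 for this objective
```

Compiling the tape pays off when the same gradient is evaluated many times with inputs of fixed shape, which matches the repeated-evaluation pattern of a likelihood optimization.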


mcabbott commented Jan 2, 2024

I think this probably boils down to eachslice having a gradient rule. The rule for getindex, by contrast, (still) allocates a whole array of zeros every time.

mapreduce(f, +, ry, rdeltas) has no special rules. I see another 3x speedup from using sum(map(f, ry, rdeltas)) instead. (Perhaps a version of the advice to use eager methods above.)

I'm a bit too lazy to write this out, but I'm sure you can replace f with two operations each acting on the whole ry, rdeltas matrices. Something like sum(ry .* rdeltas; dims=1) .+ logsumexp(rdeltas; dims=1)?
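For the eachslice version, where every column has the same length, the per-slice reduction can indeed be collapsed into whole-matrix operations. A self-contained sketch checking the identity on random toy data (with a hand-rolled logsumexp in place of LogExpFunctions, so nothing beyond the standard library is needed):

```julia
using LinearAlgebra

# numerically stable logsumexp for a vector, and column-wise for a matrix
logsumexp_vec(v) = (m = maximum(v); m + log(sum(exp.(v .- m))))
function logsumexp_cols(D)
    m = maximum(D; dims = 1)
    m .+ log.(sum(exp.(D .- m); dims = 1))
end

# one observation of the multinomial-logit NLL, as in the issue
f(y_i, delta_i) = -(dot(y_i, delta_i) - logsumexp_vec(delta_i))

nchoices, npeople = 5, 7
D = randn(nchoices, npeople)
Y = zeros(nchoices, npeople)
for j in 1:npeople
    Y[rand(1:nchoices), j] = 1.0   # one choice per person
end

# per-slice reduction vs. whole-matrix formula: sum of logsumexps minus sum of y .* delta
per_slice  = mapreduce(f, +, eachcol(Y), eachcol(D))
vectorized = sum(logsumexp_cols(D)) - sum(Y .* D)
```

Both give the same scalar NLL, but the vectorized form avoids any per-column closure calls in the reverse pass.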


tcovert commented Jan 3, 2024

Here's the beginning of the error I get when I try to take the Zygote gradient of the sum(map(...)) approach:

ERROR: Need an adjoint for constructor Base.Generator{Vector{Vector{Int64}}, var"#8#10"{Vector{Float64}}}. Gradient is of type Vector{Vector{Float64}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] (::Zygote.Jnew{Base.Generator{Vector{Vector{Int64}}, var"#8#10"{Vector{Float64}}}, Nothing, false})(Δ::Vector{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/lib/lib.jl:330
  [3] (::Zygote.var"#2214#back#313"{Zygote.Jnew{Base.Generator{Vector{Vector{Int64}}, var"#8#10"{Vector{Float64}}}, Nothing, false}})(Δ::Vector{Vector{Float64}})
    @ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71
  [4] Pullback
    @ ./generator.jl:32 [inlined]
  [5] (::Zygote.Pullback{Tuple{Type{Base.Generator{Vector{Vector{Int64}}, var"#8#10"{Vector{Float64}}}}, var"#8#10"{Vector{Float64}}, Vector{Vector{Int64}}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2214#back#313"{Zygote.Jnew{Base.Generator{Vector{Vector{Int64}}, var"#8#10"{Vector{Float64}}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(convert), Type{Vector{Vector{Int64}}}, Vector{Vector{Int64}}}, Any}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(convert), Type{var"#8#10"{Vector{Float64}}}, var"#8#10"{Vector{Float64}}}, Tuple{}}}})(Δ::Vector{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./generator.jl:32 [inlined]
  [7] (::Zygote.Pullback{Tuple{Type{Base.Generator}, var"#8#10"{Vector{Float64}}, Vector{Vector{Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.Generator{Vector{Vector{Int64}}, var"#8#10"{Vector{Float64}}}}, var"#8#10"{Vector{Float64}}, Vector{Vector{Int64}}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2214#back#313"{Zygote.Jnew{Base.Generator{Vector{Vector{Int64}}, var"#8#10"{Vector{Float64}}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(convert), Type{Vector{Vector{Int64}}}, Vector{Vector{Int64}}}, Any}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(convert), Type{var"#8#10"{Vector{Float64}}}, var"#8#10"{Vector{Float64}}}, Tuple{}}}}}})(Δ::Vector{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0

The second suggestion unfortunately can't work in this case, since I need to logsumexp over individual groups of rdeltas.

FWIW, ReverseDiff seems to handle the mapreduce() approach just fine.


mcabbott commented Jan 3, 2024

Not sure how you got that; working code is:

julia> function myg(y, X, theta, gs)
         deltas = X * theta

         nn, k = size(X)
         npeople = length(gs)
         nchoices = convert(Int64, nn / npeople)

         rdeltas = eachslice(reshape(deltas, nchoices, npeople), dims = 2)
         ry = eachslice(reshape(y, nchoices, npeople), dims = 2)

         return sum(map(f, ry, rdeltas))  # only change
       end
myg (generic function with 1 method)

julia> myg0(x) = myg(y, X, x, groups)
myg0 (generic function with 1 method)

julia> gmyg0(x) = gradient(myg0, x)[1]
gmyg0 (generic function with 1 method)

julia> @btime gh0($theta);
  min 563.853 ms, mean 588.972 ms (417571 allocations, 3.01 GiB)

julia> @btime gg0($theta);
  min 9.171 ms, mean 12.937 ms (40337 allocations, 22.37 MiB)

julia> @btime gmyg0($theta);
  min 3.083 ms, mean 5.203 ms (8117 allocations, 13.76 MiB)

can't work in this case, since I need to logsumexp over individual groups of rdeltas.

Not sure I understand. Your code applies it after eachslice, which should be the same as using dims:

julia> x = randn(3, 4)
3×4 Matrix{Float64}:
  0.572076   0.97193    0.133754    0.307906
 -0.961497  -0.749721   1.12417    -0.0599307
  0.309945   0.651524  -0.0721172   0.947538

julia> map(logsumexp, eachcol(x))
4-element Vector{Float64}:
 1.2577822485059058
 1.616214911339637
 1.6392287587452563
 1.5855046640359527

julia> logsumexp(x; dims=1)
1×4 Matrix{Float64}:
 1.25778  1.61621  1.63923  1.5855


tcovert commented Jan 3, 2024

Ah, I see that you were referring to the eachslice approach.

I am trying to get the view based approach to work because in my application, the equivalent notion of a slice wouldn't actually be the same size across the things I'd like to map over. What I have in mind is a multi-class labeling problem, where the set of labels and number of choices is not necessarily constant across choice situations - so some observations have choices A, B and C, others have B, D, E, F, for example. Is there a version of slicing that somehow allows for such heterogeneously sized slices?


mcabbott commented Jan 3, 2024

Ok. It's sad that just indexing is still so expensive. There was a ton of code written to do this more efficiently via InplaceableThunks in ChainRules, but it's not used by Zygote (nor anywhere else).

What you can do is make something like eachslice which has a gradient rule, a bit like the CR rule for eachslice. Surely this exists in a package somewhere but I can't find it. It's possible that NNlib.gather ought to support this but it does not:

julia> NNlib.gather(10:10:100, [2,3,2])
3-element Vector{Int64}:
 20
 30
 20

julia> NNlib.gather(10:10:100, [[1,2,3], [3,4]])
ERROR: MethodError: no method matching typelength(::Type{Vector{Int64}})

Maybe MLUtils.chunk does what you need, or close? This has a gradient rule:

julia> MLUtils.chunk(collect(transpose(10:10:100)), size=[2, 3, 5])
3-element Vector{SubArray{Int64, 2, Matrix{Int64}, Tuple{Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}}:
 [10 20]
 [30 40 50]
 [60 70 … 90 100]
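A minimal sketch of such a custom rule, a ragged analogue of eachslice (the name `ragged_views` and the accumulation scheme are my own, assuming ChainRulesCore; unthunking of cotangents is omitted for brevity):

```julia
using ChainRulesCore

# hypothetical helper: one view of x per index set, like the gs groups in the MWE
ragged_views(x, idxs) = [view(x, i) for i in idxs]

function ChainRulesCore.rrule(::typeof(ragged_views), x, idxs)
    y = ragged_views(x, idxs)
    function ragged_views_pullback(ȳ)
        x̄ = zero(x)                 # one dense buffer shared by all slices
        for (i, ȳi) in zip(idxs, ȳ)
            view(x̄, i) .+= ȳi       # accumulate, so overlapping indices add up
        end
        return (NoTangent(), x̄, NoTangent())
    end
    return y, ragged_views_pullback
end
```

The point is that the pullback allocates a single dense gradient buffer and accumulates every slice's cotangent into it, instead of materializing a fresh array of zeros per getindex call.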


tcovert commented Jan 3, 2024

MLUtils.chunk does indeed work!

using MLUtils
function h2(y, X, theta, gs)
  deltas = X * theta
  mlg = map(length, gs)  # chunk sizes taken from the group index sets
  return mapreduce(f, +, chunk(y, size = mlg), chunk(deltas, size = mlg))
end

With this formulation, Zygote makes a gradient that is about as fast (with identical allocations) as ReverseDiff.gradient on the original h. ReverseDiff with a tape is faster still, but this is great progress. Thanks for your help.


mcabbott commented Jan 3, 2024

sum(map(f, chunk(y, size = mlg), chunk(deltas, size = mlg))) is 3x faster here too.

julia> function h2(y, X, theta, gs)
         deltas = X * theta
         mlg = map(length, gs)
         return mapreduce(f, +, chunk(y, size = mlg), chunk(deltas, size = mlg))
       end
h2 (generic function with 1 method)


julia> h20(x) = h2(y, X, x, groups)
h20 (generic function with 1 method)

julia> gh20(x) = gradient(h20, x)[1]
gh20 (generic function with 1 method)

julia> @btime gg0($theta);
  min 9.097 ms, mean 12.811 ms (40337 allocations, 22.37 MiB)

julia> @btime gh20($theta);
  min 10.395 ms, mean 14.487 ms (50348 allocations, 22.83 MiB)

julia> function h3(y, X, theta, gs)
         deltas = X * theta
         mlg = map(length, gs)
         return sum(map(f, chunk(y, size = mlg), chunk(deltas, size = mlg)))
       end
h3 (generic function with 1 method)

julia> h30(x) = h3(y, X, x, groups)
h30 (generic function with 1 method)

julia> gh30(x) = gradient(h30, x)[1]
gh30 (generic function with 1 method)

julia> @btime gh30($theta);
  min 3.629 ms, mean 5.583 ms (18130 allocations, 14.22 MiB)

julia> gh30(theta) ≈ gh20(theta) ≈ gh0(theta)
true


tcovert commented Jan 3, 2024

Oh cool. With that, Zygote is even a bit faster than a taped version with ReverseDiff. Thanks!
