Construction of ComponentArray inside of AD/Zygote #176
Constructing an array in general fails with Zygote:
using Zygote
using ComponentArrays
Zygote.gradient(x -> ComponentArray(a = [5])[1], [0.])
gives
ERROR: Mutating arrays is not supported -- called push!(Vector{Any}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)
Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] _throw_mutation_error(f::Function, args::Vector{Any})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/array.jl:86
[3] (::Zygote.var"#397#398"{Vector{Any}})(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/array.jl:105
[4] (::Zygote.var"#2508#back#399"{Zygote.var"#397#398"{Vector{Any}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[5] Pullback
@ ./namedtuple.jl:309 [inlined]
[6] (::typeof(∂(merge)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[7] Pullback
@ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:161 [inlined]
[8] (::typeof(∂(make_idx)))(Δ::Tuple{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[9] Pullback
@ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:147 [inlined]
[10] (::typeof(∂(make_carray_args)))(Δ::Tuple{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[11] Pullback
@ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:139 [inlined]
[12] (::typeof(∂(make_carray_args)))(Δ::Tuple{Vector{Float64}, Nothing})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[13] Pullback
@ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:63 [inlined]
[14] Pullback
@ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:66 [inlined]
[15] (::typeof(∂(#ComponentArray#21)))(Δ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(a = 1:1,)}}})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[16] Pullback
@ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:66 [inlined]
[17] (::typeof(∂(Type##kw)))(Δ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(a = 1:1,)}}})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[18] Pullback
@ ./REPL[4]:1 [inlined]
[19] (::typeof(∂(#3)))(Δ::Int64)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[20] (::Zygote.var"#60#61"{typeof(∂(#3))})(Δ::Int64)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
[21] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
[22] top-level scope
@ REPL[4]:1
It seems that to make ComponentArrays work with Zygote, it must entirely avoid mutating arrays (even appending to arrays)...
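As a rough illustration of the remark about mutation, the sketch below shows one general way to keep mutating bookkeeping out of Zygote's view: wrapping it in `@ignore_derivatives` from ChainRulesCore. The function `odd_sum` is hypothetical and does not use ComponentArrays; it is not a confirmed fix for this issue, only an example of the pattern under the assumption that the mutated values (here, integer indices) carry no gradient.

```julia
using Zygote
using ChainRulesCore: @ignore_derivatives

# Hypothetical example: sum the entries of x at odd positions.
function odd_sum(x)
    # The index vector is built with push!, which Zygote would refuse to
    # differentiate. Indices carry no gradient, so the whole block is
    # hidden from AD.
    idx = @ignore_derivatives begin
        acc = Int[]
        for i in eachindex(x)
            isodd(i) && push!(acc, i)
        end
        acc
    end
    # Only this indexing and the sum are differentiated.
    return sum(x[idx])
end

g = Zygote.gradient(odd_sum, [1.0, 2.0, 3.0])[1]
```

This only helps when the mutated data is index or layout information; values that need gradients still have to be built without mutation.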
Is there any chance to resolve the above error by using
I want to compute the gradient of a loss function with respect to a ComponentArray. In the loss function, I need to reconstruct a ComponentArray. Based on @jonniedie's reply in #126 (comment), I tried an approach which fails with an error pointing to the @unpack call. @avik-pal noted that it also happens even without the @unpack but is resolved by using vcat.
The issue seems to be that Δ is a Tuple{Float64} in ComponentArrays.jl/src/compat/chainrulescore.jl (Line 4 in cbb24ef) for splatting.
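To make the vcat suggestion concrete, here is a minimal sketch on plain vectors of rebuilding a parameter vector inside a loss with `vcat`, which Zygote differentiates without mutation. The `loss` function and its slicing are assumptions made for illustration; the example does not use ComponentArrays, so whether the same pattern avoids the Tuple cotangent problem there rests only on the report above.

```julia
using Zygote

# Hypothetical loss: split the parameters, transform one piece,
# and reassemble them with vcat instead of a mutating constructor.
function loss(p)
    a = p[1:2]          # first block, passed through unchanged
    b = 2 .* p[3:3]     # second block, scaled by 2
    q = vcat(a, b)      # non-mutating reconstruction
    return sum(abs2, q)
end

grad = Zygote.gradient(loss, [1.0, 2.0, 3.0])[1]
# grad == [2.0, 4.0, 24.0]
```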