Construction of ComponentArray inside of AD/Zygote #176

Open
frankschae opened this issue Nov 17, 2022 · 2 comments

Comments

@frankschae

I want to compute the gradient of a loss function with respect to a ComponentArray. In the loss function, I need to reconstruct a ComponentArray.
Based on @jonniedie's reply in #126 (comment), I tried

function my_sum(v)
    ax = getaxes(v)
    @unpack x, y = v
    ca = ComponentArray([x..., y...], ax)
    return sum(ca.x + ca.y)
end

Zygote.gradient(my_sum, ComponentArray(x=[0.0], y=[0.0]))

which fails with

ERROR: ArgumentError: indexed assignment with a single value to possibly many locations is not supported; perhaps use broadcasting `.=` instead?
Stacktrace:
  [1] setindex_shape_check(::ChainRulesCore.Tangent{Any, Tuple{Float64}}, ::Int64)
    @ Base ./indices.jl:261
  [2] _unsafe_setindex!(#unused#::IndexLinear, A::Vector{Float64}, x::ChainRulesCore.Tangent{Any, Tuple{Float64}}, I::UnitRange{Int64})
    @ Base ./multidimensional.jl:939
  [3] _setindex!
    @ ./multidimensional.jl:930 [inlined]
  [4] setindex!
    @ ./abstractarray.jl:1344 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/ComponentArrays/EjZNJ/src/array_interface.jl:0 [inlined]
  [6] _setindex!(x::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(x = 1:1, y = 2:2)}}}, v::ChainRulesCore.Tangent{Any, Tuple{Float64}}, idx::Val{:y})
    @ ComponentArrays ~/.julia/packages/ComponentArrays/EjZNJ/src/array_interface.jl:129
  [7] setproperty!
    @ ~/.julia/packages/ComponentArrays/EjZNJ/src/namedtuple_interface.jl:17 [inlined]
  [8] (::ComponentArrays.var"#getproperty_adjoint#87"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(x = 1:1, y = 2:2)}}}, Symbol})(Δ::ChainRulesCore.Tangent{Any, Tuple{Float64}})
    @ ComponentArrays ~/.julia/packages/ComponentArrays/EjZNJ/src/compat/chainrulescore.jl:4
  [9] ZBack
    @ ~/.julia/packages/Zygote/PD12J/src/compiler/chainrules.jl:206 [inlined]
 [10] Pullback
    @ ~/.julia/packages/UnPack/EkESO/src/UnPack.jl:34 [inlined]
 [11] (::typeof(∂(unpack)))(Δ::Tuple{Float64})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [12] macro expansion
    @ ~/.julia/packages/UnPack/EkESO/src/UnPack.jl:101 [inlined]
 [13] Pullback

pointing to the @unpack call. @avik-pal noted that the error also occurs without @unpack:

function my_sum(v)
    ax = getaxes(v)
    ca = ComponentArray([v.x..., v.y...], ax)
    return sum(ca.x + ca.y)
end

Zygote.gradient(my_sum, ComponentArray(x=[0.0], y=[0.0]))

but it is resolved by using vcat instead of splatting:

function my_sum(v)
    ax = getaxes(v)
    @unpack x, y = v
    ca = ComponentArray(vcat(x,y), ax)
    return sum(ca.x + ca.y)
end
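
For reference, a self-contained version of the vcat workaround might look like the sketch below. It assumes ComponentArrays, UnPack, and Zygote are installed; the name `my_sum_vcat` and the nonzero input values are made up for illustration, and behavior may differ between package versions.

```julia
using ComponentArrays, UnPack, Zygote

# Hypothetical name; the body matches the vcat version above.
function my_sum_vcat(v)
    ax = getaxes(v)
    @unpack x, y = v
    # vcat builds the flat data vector without splatting into an array literal
    ca = ComponentArray(vcat(x, y), ax)
    return sum(ca.x + ca.y)
end

# Illustrative input (not from the original report)
v = ComponentArray(x=[1.0], y=[2.0])

# Zygote.gradient returns a tuple with one entry per argument
g, = Zygote.gradient(my_sum_vcat, v)
```

Since the function just sums both components, each entry of `g` should be one if the workaround holds on the installed versions.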

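With splatting, the issue seems to be that Δ is a Tuple{Float64} (wrapped in a ChainRulesCore.Tangent, per the stack trace) rather than an array in the call

setproperty!(zero_x, s, Δ)

inside getproperty_adjoint.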

@kaandocal

Constructing a ComponentArray inside the differentiated function fails with Zygote in general:

using Zygote
using ComponentArrays

Zygote.gradient(x -> ComponentArray(a = [5])[1], [0.])

gives

ERROR: Mutating arrays is not supported -- called push!(Vector{Any}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/array.jl:86
  [3] (::Zygote.var"#397#398"{Vector{Any}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/array.jl:105
  [4] (::Zygote.var"#2508#back#399"{Zygote.var"#397#398"{Vector{Any}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [5] Pullback
    @ ./namedtuple.jl:309 [inlined]
  [6] (::typeof(∂(merge)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
  [7] Pullback
    @ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:161 [inlined]
  [8] (::typeof(∂(make_idx)))(Δ::Tuple{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:147 [inlined]
 [10] (::typeof(∂(make_carray_args)))(Δ::Tuple{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:139 [inlined]
 [12] (::typeof(∂(make_carray_args)))(Δ::Tuple{Vector{Float64}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:63 [inlined]
 [14] Pullback
    @ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:66 [inlined]
 [15] (::typeof(∂(#ComponentArray#21)))(Δ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(a = 1:1,)}}})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:66 [inlined]
 [17] (::typeof(∂(Type##kw)))(Δ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(a = 1:1,)}}})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [18] Pullback
    @ ./REPL[4]:1 [inlined]
 [19] (::typeof(∂(#3)))(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [20] (::Zygote.var"#60#61"{typeof(∂(#3))})(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
 [21] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
 [22] top-level scope
    @ REPL[4]:1

It seems that for ComponentArrays to work with Zygote, its constructors must entirely avoid mutating arrays (even appending to them with push!).

@Yuan-Ru-Lin
Copy link

Could the above error be resolved by using Zygote.Buffer, as illustrated here?
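
To make the question concrete, here is a minimal sketch of the Zygote.Buffer pattern on plain vectors. The function `concat_sum` and its inputs are hypothetical. Note the assumption involved: the push! in the trace above appears to come from inside the ComponentArray constructor path, so a Buffer in user code would probably not reach it. The pattern would more likely have to be applied inside the package, or the data built without mutation.

```julia
using Zygote

# Hypothetical helper: concatenate two vectors through a Zygote.Buffer
# instead of mutating a regular Array, then reduce to a scalar.
function concat_sum(x, y)
    n = length(x)
    # Buffer(x, dims...) allocates a writable buffer shaped like similar(x, dims...)
    buf = Zygote.Buffer(x, n + length(y))
    for i in eachindex(x)
        buf[i] = x[i]
    end
    for j in eachindex(y)
        buf[n + j] = y[j]
    end
    # copy(buf) freezes the buffer into an ordinary array that Zygote can track
    flat = copy(buf)
    return sum(abs2, flat)
end

gx, gy = Zygote.gradient(concat_sum, [1.0, 2.0], [3.0])
```

Writes to the Buffer are differentiable, and the `copy(buf)` result can be passed on like any other array, for example as the data argument of `ComponentArray(data, ax)`, although that combination is untested here.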
