Strip zygote frames from mutation error stack trace #1501

Open
LilithHafner opened this issue Feb 16, 2024 · 1 comment
Labels
compiler help wanted Extra attention is needed

Comments

LilithHafner (Contributor) commented Feb 16, 2024

Motivation and description

When differentiating something complicated which contains mutation, it can be hard to know exactly where the mutation is. In this example, the mutation is tucked away inside the ComponentArray constructor, and in a larger example (e.g. DARPA-ASKEM/sciml-service#141) it might be hard to figure that out.

It would be very helpful if the stack trace provided the exact location of the mutation that triggers this error, rather than interleaving that stack trace with Zygote frames. Failing that, it would at least be nice to inform the user that they should look at every third frame to figure out where in their code the mutation is.

julia> using ComponentArrays, Zygote

julia> function f(x)
           ca = ComponentArray(var=x)
           ca.var
       end
f (generic function with 1 method)

julia> Zygote.jacobian(f, [1,2,3])
ERROR: Mutating arrays is not supported -- called push!(Vector{Any}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:70
  [3] (::Zygote.var"#547#548"{Vector{Any}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:89
  [4] (::Zygote.var"#2643#back#549"{Zygote.var"#547#548"{Vector{Any}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [5] merge
    @ ./namedtuple.jl:371 [inlined]
  [6] (::Zygote.Pullback{Tuple{typeof(merge), @NamedTuple{}, Base.Generator{…}}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [7] make_idx
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:170 [inlined]
  [8] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Vector{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [9] make_carray_args
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:151 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Vector{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [11] make_carray_args
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:144 [inlined]
 [12] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Vector{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [13] ComponentArray
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:64 [inlined]
 [14] #ComponentArray#21
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:67 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ComponentVector{Int64, Vector{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [16] ComponentArray
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:67 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ComponentVector{Int64, Vector{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [18] f
    @ ./REPL[2]:2 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [20] #291
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206 [inlined]
 [21] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{}, Zygote.Pullback{}}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [22] call_composed
    @ ./operators.jl:1045 [inlined]
 [23] (::Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{…}, Tuple{…}, @Kwargs{}}, Any})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [24] call_composed
    @ ./operators.jl:1044 [inlined]
 [25] #_#103
    @ ./operators.jl:1041 [inlined]
 [26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [27] #291
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206 [inlined]
 [28] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [29] ComposedFunction
    @ ./operators.jl:1041 [inlined]
 [30] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [31] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [32] withjacobian(f::Function, args::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/grad.jl:150
 [33] jacobian(f::Function, args::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/grad.jl:128
 [34] top-level scope
    @ REPL[3]:1
Some type information was truncated. Use `show(err)` to see complete types.
ToucheSir added the "help wanted" and "compiler" labels Feb 16, 2024

ToucheSir (Member) commented

The story with Zygote stacktraces is more complex than described and could use a little explaining. If we use this stacktrace for illustration:

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:70

The first two stackframes are what you'd expect: common error reporting code. More interesting are the next two:

  [3] (::Zygote.var"#547#548"{Vector{Any}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:89
  [4] (::Zygote.var"#2643#back#549"{Zygote.var"#547#548"{Vector{Any}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72

As you've correctly identified, [3] has the actual rule we should look at. So what is [4]? That would be the rule machinery itself at https://github.com/FluxML/ZygoteRules.jl/blob/f9bf0e367fa259c5aa68f0e14ccbf2125d734bd6/src/adjoint.jl#L72. Not very helpful.

Now for the surprising revelation: there is actually no "interleaving of Zygote frames in this stacktrace". From [2] to [33], it's all Zygote. But how can that be when we have frames like this?

  [5] merge
    @ ./namedtuple.jl:371 [inlined]
  [6] (::Zygote.Pullback{Tuple{typeof(merge), @NamedTuple{}, Base.Generator{…}}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0

Essentially, Zygote's generated functions can spoof line numbers from the original function, so that [5] merge frame is actually the same call as [6] (the AD-generated pullback).

I don't know why this was done. It was either intentional, to help with looking up the original function since the genfunc code provides little info to work with, or the line info is sticking around by accident. This snippet also shows the limitations Zygote has around stackframe printing. Ideally, we'd want the call info of [5] with the file name and line number of [6], so that stacktraces would be shorter and cleaner. I'm assuming this is possible, but the main problem is that Zygote's internals are a PITA to work with (mostly because of IRTools, IMO).
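Until something like that exists inside Zygote, a user-side workaround could post-process the captured stacktrace. The following is only a minimal sketch, not anything Zygote provides: `is_ad_machinery` and `strip_ad_frames` are hypothetical helper names, and the path-based heuristic (drop frames located at `compiler/interface2.jl:0`, like [6], and frames from ZygoteRules, like [4]) is an assumption drawn from the trace above. It also assumes `/` path separators. The frames are built by hand so the example runs without Zygote installed.

```julia
using Base.StackTraces: StackFrame

# Heuristic (an assumption, not a Zygote API): treat a frame as AD machinery if it
# sits at the generated-pullback placeholder location or comes from ZygoteRules.
function is_ad_machinery(fr::StackFrame)
    file = String(fr.file)
    return (endswith(file, "compiler/interface2.jl") && fr.line == 0) ||
           occursin("ZygoteRules", file)
end

# Keep only the frames that carry line info from the differentiated code.
strip_ad_frames(frames) = filter(!is_ad_machinery, frames)

# Hand-built frames mirroring [4] to [7] of the trace above.
frames = [
    StackFrame("back", "~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl", 72),
    StackFrame("merge", "./namedtuple.jl", 371),
    StackFrame("Pullback", "~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl", 0),
    StackFrame("make_idx", "~/.julia/dev/ComponentArrays/src/componentarray.jl", 170),
]

kept = strip_ad_frames(frames)
foreach(fr -> println(fr.func, " @ ", fr.file, ":", fr.line), kept)
```

For a real failure, the frames could be obtained with `stacktrace(catch_backtrace())` inside a `catch` block around the `Zygote.jacobian` call. Frames such as [2] and [3], which live in Zygote's own `src/lib` files, would survive this filter; that is probably desirable, since [3] identifies the rule that raised the error.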
