Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use zygote2differential to wrap chainrules inputs #1057

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

oxinabox
Copy link
Member

@oxinabox oxinabox commented Aug 25, 2021

@mzgubic implemented zygote2differential as a better version of wrap_chainrules_inputs and added it to use in the code for rrule_via_ad.
But it was not added to the normal path for when Zygote uses ChainRules.
I guess because it requires keeping the primal values in memory.
Which is probably a lot?

Anyway this would give us more consistent chainrules types.
No more Tangent{Any} or nothings that are hidden with-in arrays.

We probably do not want to merge this as is because of the extra memory use.
or maybe it is not too bad. Do we have a benchmark for it?

But hopefully this will fix the problems in TuringLang/DistributionsAD.jl#197
cc @devmotion .
If it does we can look at reworking zygote2differential to not have to store so much.
We learnt a lot about doing that for ProjectTo
same techniques can be applied here.

NB: I am putting this PR up at 9:30 at night, and I have not even run it locally.
Might have typos etc and just not work.
It also has no tests, yet.

src/compiler/chainrules.jl Outdated Show resolved Hide resolved
@oxinabox
Copy link
Member Author

well this breaks some tests in weird ways.
Clearly some edge cases that zygote2differential doesn't yet cover.
Still might work for TuringLang/DistributionsAD.jl#197
and we can fix those other things

@devmotion
Copy link
Collaborator

I just checked TuringLang/DistributionsAD.jl#198 (the CR1 version) locally and it still fails with the same error messages ("adjoint for constructor ..."), even with this PR.

@oxinabox
Copy link
Member Author

Yeah won't fix that.
I meant fixing
TuringLang/DistributionsAD.jl#197 (comment)

@oxinabox
Copy link
Member Author

There was a matching differential2zygote that@mzgubic wrote.
Which might fix that, if the underlying cause was that the Tangent escaped by hiding in an arrays.

@devmotion
Copy link
Collaborator

Ah sorry, I misunderstood your comment. Unfortunately, the example is not fixed either.

@mzgubic
Copy link
Collaborator

mzgubic commented Aug 26, 2021

Here it is, in case you find it useful:

differential2legacy(x) = unthunk(x) # TODO eventually remove this
differential2legacy(::AbstractZero) = nothing
differential2legacy(t::Union{Tuple, NamedTuple}) = map(differential2legacy, t)
differential2legacy(::Nothing) = (legacytype_warn(Nothing); return nothing)
differential2legacy(a::AbstractArray) = differential2legacy.(a) # TODO: what to do with arrays with nothing?
differential2legacy(a::AbstractArray{<:Number}) = a
for T_outer in (:Tuple, :NamedTuple)
  # we create separate methods rather than using a `Union` + an `if` so that we avoid a
  # branch that changes output type, because nested AD on that kinda thing makes Zygote less
  # than happy.
  @eval @inline function differential2legacy(x::Composite{P, T}) where {P, T<:$T_outer}
    xp = map(differential2legacy, canonicalize(x))
    convert($T_outer, xp)
  end
end

I do recall getting into some kind of trouble when using this instead of wrap_chainrules_outputs. Don't think I've solved it though

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ChainRules adjoint -> rrule, and further integration
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants