Missing ChainRule for ComponentVector(a; b...) #207

Open
scheidan opened this issue May 20, 2023 · 4 comments

Comments

@scheidan
Contributor

I've tried to use Zygote with ComponentArrays, but it cannot get through this code:

c = ComponentVector(a; b...)

I think it's just a missing ChainRules rule. I've tried, but unfortunately writing rules is black magic to me...

@scheidan
Contributor Author

Here is a (failing) attempt of mine:

using ComponentArrays
import ChainRulesCore
import Zygote

# -----------
# rule

function ChainRulesCore.rrule(::typeof(ComponentArrays.ComponentVector),
                              x::ComponentVector; kwargs...)
    res = ComponentVector(x; kwargs...)
    function pullback(Δ)
        one_x = zero(similar(x, eltype(Δ))) .+ 1
        one_y = zero(ComponentVector{eltype(Δ)}(kwargs)) .+ 1
        return ChainRulesCore.NoTangent(), one_x, one_y
    end
    return res, pullback
end

# -----------
# test

function mymerge(x::ComponentVector, y::ComponentVector)
    z = ComponentVector(x; y...)
    z
end


x = ComponentVector(a=1.0, b=2, c=(e=3, f=4))
y = ComponentVector(a = 11, e=4.0, d=5.0)
mymerge(x, y)

Zygote.gradient(a -> sum(mymerge(a, y)), x)[1] # fails with StackOverflowError
Zygote.gradient(a -> sum(mymerge(x, a)), y)[1] # fails with StackOverflowError

Not sure why this is causing a StackOverflowError. A test version without the kwargs seemed to work.

@jonniediegelman
Collaborator

jonniediegelman commented Jul 28, 2023

It looks like you were super close. You just needed to splat out the keyword arguments in the pullback.

function ChainRulesCore.rrule(::typeof(ComponentArrays.ComponentVector),
                              x::ComponentVector; kwargs...)
    res = ComponentVector(x; kwargs...)
    function pullback(Δ)
        one_x = zero(similar(x, eltype(Δ))) .+ 1
        one_y = zero(ComponentVector{eltype(Δ)}(; kwargs...)) .+ 1
        return ChainRulesCore.NoTangent(), one_x, one_y
    end
    return res, pullback
end

Thanks, though! I'll add it as soon as I get a chance.
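Some background on why splatting matters here: inside a function declared with ; kwargs..., the name kwargs refers to a single collection of name => value pairs (a Base.Pairs), not to separate keywords. Writing f(kwargs) hands that whole collection to f as one positional argument, whereas f(; kwargs...) forwards each pair as a keyword argument again. A minimal Base-only illustration (the function name capture is made up for this example):

```julia
# `capture` is a hypothetical helper that just returns the keywords it received.
capture(; kwargs...) = kwargs

kw = capture(e = 4.0, d = 5.0)

# `kw` is one object holding the pairs :e => 4.0 and :d => 5.0.
println(typeof(kw))
println(kw[:e])            # look up a single keyword by name: 4.0

# Splatting with `; kw...` turns the pairs back into keyword arguments,
# collected here into a NamedTuple.
nt = (; kw...)
println(nt)                # (e = 4.0, d = 5.0)
```

So ComponentVector{T}(kwargs) and ComponentVector{T}(; kwargs...) are two different calls, and only the second one passes the fields as keywords.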

@jonniediegelman
Collaborator

Wait, no, that gives the wrong answer.

@jonniedie
Owner

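Interesting: ChainRules doesn't differentiate with respect to keyword arguments; a pullback only returns tangents for the function itself and its positional arguments. We may want to instead define the behavior in a merge method that takes both vectors as positional arguments, so it's compatible with ChainRules.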
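To illustrate what such a positional merge rule has to compute, here is a minimal sketch in plain Base Julia. It uses flat NamedTuples instead of ComponentVectors and does not use ChainRulesCore. The name merge_with_pullback is hypothetical: it only mimics the (value, pullback) shape of an rrule, and nested components (like c in the example above) are not handled. The pullback returns one entry per positional argument and routes the incoming cotangent Δ instead of returning ones:

```julia
# Hypothetical rrule-shaped merge for flat NamedTuples of numbers.
# Returns the merged value and a pullback mapping an output cotangent Δ to
# (tangent for the function itself, cotangent for x, cotangent for y).
function merge_with_pullback(x::NamedTuple, y::NamedTuple)
    z = merge(x, y)  # fields of y override same-named fields of x

    function merge_pullback(Δ::NamedTuple)
        # Fields of x that y overrides never reach z, so their cotangent is zero;
        # every other field of x passes its output cotangent straight through.
        dx = NamedTuple{keys(x)}(map(k -> haskey(y, k) ? zero(Δ[k]) : Δ[k], keys(x)))
        # Every field of y appears in z unchanged.
        dy = NamedTuple{keys(y)}(map(k -> Δ[k], keys(y)))
        # `nothing` stands in for ChainRulesCore's NoTangent() for the function itself.
        return (nothing, dx, dy)
    end

    return z, merge_pullback
end

x = (a = 1.0, b = 2.0)
y = (a = 11.0, d = 5.0)
z, back = merge_with_pullback(x, y)
println(z)   # (a = 11.0, b = 2.0, d = 5.0)

# Cotangent of sum(z) with respect to z: a one for every output field.
Δ = map(one, z)
_, dx, dy = back(Δ)
println(dx)  # (a = 0.0, b = 1.0)
println(dy)  # (a = 1.0, d = 1.0)
```

The field a of x is overridden by y, so its gradient is zero; an all-ones pullback reports 1 for it, which is one way the earlier rule gives the wrong answer. An actual rrule for merging ComponentVectors would also have to convert Δ and the results to the appropriate tangent types, which this sketch does not attempt.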
