Pullback in constructor rrule should unwrap a NamedDimsArray gradient back to an Array #179

Open
mzgubic opened this issue Sep 8, 2021 · 4 comments
Labels: bug (Something isn't working), good first issue (Good for newcomers)


@mzgubic
Collaborator

mzgubic commented Sep 8, 2021

From @mcabbott: #178 (comment)

I am too slow, but shouldn't this remove names from the array it gets? That is on the assumption that similar, or broadcasting in the computation of a gradient for the array with names, may (probably should) produce another array with names.

julia> nda = NamedDimsArray(rand(Int8,3,4), (:x, :y));

julia> gradient(x -> sum(x), nda)  # does not know about names
(Fill(1, 3, 4),)

julia> gradient(x -> sum(x; dims=1)[1], nda)  # involves broadcasting, so might keep names?
ERROR: DimensionMismatch("destination axes (Base.OneTo(3), Base.OneTo(4)) are not compatible with source axes (Base.OneTo(1), Base.OneTo(4))")

julia> gradient(x -> sum(x; dims=1)[1], nda.data)
(Int8[1 0 0 0; 1 0 0 0; 1 0 0 0],)
@mcabbott
Collaborator

mcabbott commented Sep 8, 2021

The error in my example seems to have gone away. But better examples would be:

julia> using NamedDims, Zygote

julia> nda = NamedDimsArray(rand(Int8,3,4), (:x, :y));

julia> gradient(x -> sum(NamedDimsArray(x, (:x, :y)); dims=1)[1], rand(3,4))[1]
3×4 NamedDimsArray{(:x, :y), Float64, 2, Matrix{Float64}}:
 1.0  0.0  0.0  0.0
 1.0  0.0  0.0  0.0
 1.0  0.0  0.0  0.0

julia> gradient(x -> sum(NamedDimsArray(x, (:x, :y)) * nda'), rand(3,4))[1]
3×4 NamedDimsArray{(:_, :y), Float64, 2, Matrix{Float64}}:
 94.0  121.0  -116.0  -165.0
 94.0  121.0  -116.0  -165.0
 94.0  121.0  -116.0  -165.0
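A minimal sketch of the behaviour the title asks for, in plain Julia with no packages. ToyNamed, unwrap and toy_named_with_pullback are hypothetical stand-ins, not the NamedDims.jl type or the ChainRulesCore rrule: the point is only that the constructor's pullback hands back a plain Array for the wrapped data, even when the incoming cotangent arrives with names attached.

```julia
# Hypothetical stand-in for NamedDimsArray; NOT the NamedDims.jl type.
struct ToyNamed{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
    names::NTuple{N,Symbol}
end
Base.size(x::ToyNamed) = size(x.data)
Base.getindex(x::ToyNamed, i::Int...) = x.data[i...]

# Strip the names if present; leave plain arrays untouched.
unwrap(x::ToyNamed) = x.data
unwrap(x::AbstractArray) = x

# Hypothetical constructor "rrule": returns the wrapped primal and a pullback.
# The pullback returns an unnamed array for `data` and `nothing` for `names`.
function toy_named_with_pullback(data::Array, names::Tuple)
    y = ToyNamed(data, names)
    toy_named_pullback(ȳ::AbstractArray) = (unwrap(ȳ), nothing)
    return y, toy_named_pullback
end

# Usage, mirroring sum(x; dims=1)[1]: the cotangent comes back with names.
y, back = toy_named_with_pullback(rand(3, 4), (:x, :y))
ȳ = ToyNamed([1.0 0 0 0; 1 0 0 0; 1 0 0 0], (:x, :y))
ddata, _ = back(ȳ)
```

Here ddata is a plain Matrix{Float64}, so a caller that passed an Array in gets an Array gradient back rather than a NamedDimsArray like the ones printed above.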

@oxinabox
Member

oxinabox commented Sep 9, 2021

The tangent should have the same names as the primal, no?

@mcabbott
Collaborator

mcabbott commented Sep 9, 2021

Yes, ideally. Many rules won't do that, but ProjectTo could enforce it.
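One way to read "ProjectTo could enforce it": project each tangent onto the primal's own form, re-attaching the primal's names when the primal is named and stripping them when it is a plain Array. This is a plain-Julia sketch under that assumption; ToyNamed, unwrap and project_like are hypothetical, and this is not how ChainRulesCore's ProjectTo is actually implemented.

```julia
# Hypothetical stand-in for NamedDimsArray; NOT the NamedDims.jl type.
struct ToyNamed{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
    names::NTuple{N,Symbol}
end
Base.size(x::ToyNamed) = size(x.data)
Base.getindex(x::ToyNamed, i::Int...) = x.data[i...]

unwrap(x::ToyNamed) = x.data
unwrap(x::AbstractArray) = x

# Named primal: the tangent gets exactly the primal's names.
# (This sketch overwrites names such as :_ instead of checking compatibility.)
project_like(primal::ToyNamed, dx::AbstractArray) =
    ToyNamed(collect(unwrap(dx)), primal.names)

# Plain primal: the tangent is returned as a plain Array with no names.
project_like(primal::Array, dx::AbstractArray) = collect(unwrap(dx))

# Usage: a rule produced a tangent named (:_, :y), as in the second example above.
dx = ToyNamed(ones(3, 4), (:_, :y))
named = project_like(ToyNamed(rand(3, 4), (:x, :y)), dx)
plain = project_like(rand(3, 4), dx)
```

A real projection would probably want to error on genuinely clashing names rather than silently overwrite them; the sketch only shows where such a check would sit.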

@oxinabox
Member

oxinabox commented Sep 9, 2021

I do not understand the title of this issue.

@mcabbott mcabbott changed the title from "constructor rrule should remove names" to "Pullback in constructor rrule should unwrap a NamedDimsArray gradient back to an Array" Sep 9, 2021
@mzgubic mzgubic added the labels bug (Something isn't working) and good first issue (Good for newcomers) Sep 9, 2021