This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Cholesky passing CI #217

Member

@rofinn rofinn commented Jun 22, 2022

I don't think type piracy is the right solution here, but it narrows down the specific changes that broke our codebase. Perhaps we should re-add these methods to ChainRules?

Closes #216

@rofinn rofinn changed the title from "Narrowed down the minimum changes needed to get CI passing again" to "Cholesky passing CI" Jun 22, 2022
Comment on lines +56 to +58
idx = hasfield(T, :factors) && sym in (:U, :L) ? :factors : sym
hasfield(T, idx) || return ZeroTangent()
return unthunk(getfield(ChainRulesCore.backing(tangent), idx))

Shouldn't this return an upper/lower triangular matrix if tangent.U/tangent.L is requested? It seems this implementation would just return tangent.factors in both cases.
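For illustration, here is a minimal sketch of what returning a triangular matrix could look like. It mirrors how `LinearAlgebra` builds `C.U` and `C.L` from `C.factors` and the `uplo` flag; whether the same construction is correct for a tangent is exactly the open question in this thread. The helper name `triangular_tangent` and the sample values are hypothetical and are not part of ChainRules, ChainRulesCore, or Nabla.

```julia
using LinearAlgebra

# Hypothetical helper (not part of ChainRules, ChainRulesCore, or Nabla).
# `factors` is the tangent of the Cholesky `factors` field, `uplo` is the
# primal's storage flag ('U' or 'L'), and `sym` is the requested property.
function triangular_tangent(factors::AbstractMatrix, uplo::Char, sym::Symbol)
    if sym === :U
        # Stored upper: wrap directly; stored lower: transpose first.
        return UpperTriangular(uplo == 'U' ? factors : copy(factors'))
    elseif sym === :L
        return LowerTriangular(uplo == 'L' ? factors : copy(factors'))
    else
        throw(ArgumentError("expected :U or :L, got :$sym"))
    end
end

dF = [1.0 2.0; 3.0 4.0]                # made-up tangent for `factors`
dU = triangular_tangent(dF, 'U', :U)   # keeps only the upper triangle
dL = triangular_tangent(dF, 'U', :L)   # transposes, then keeps the lower triangle
```

With a wrapper like this, the entries outside the requested triangle are no longer exposed to downstream code, which is the difference from returning `factors` unchanged.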

IIRC in the CR PR it was also discussed if getproperty for these tangents should be added to CR.

Member Author

@rofinn rofinn Jun 22, 2022

> Shouldn't this return an upper/lower triangular matrix if tangent.U/tangent.L is requested?

Yeah, this was just to see what was needed to make tests pass. AFAIK, tangent.U and tangent.L have just been renamed to factors, so shouldn't this just work as is in most cases? I guess the concern is that factors would be the incorrect type?

> IIRC in the CR PR it was also discussed if getproperty for these tangents should be added to CR.

Ideally, this is something that should be added to CR, but if folks disagree it can live here to keep things working.

I think it's worth at least opening an issue on ChainRules to discuss adding this there.

Member Author

# https://github.com/JuliaDiff/ChainRules.jl/pull/630

# Single arg function was dropped
function ChainRules.rrule(::typeof(cholesky), A::AbstractMatrix{<:Real})

Why is this necessary? In LinearAlgebra, the single-arg method calls the 2-arg method.

Member Author

@rofinn rofinn Jun 22, 2022

I'm still trying to wrap my head around how all the overloading works here, but I think it has to do with how Nabla.jl generates overloads for Nabla.Node types based on existing rrule signatures. I could probably dig a bit deeper into how to define explicit overloads in Nabla, but the rrule solution seems easier :)

https://github.com/invenia/Nabla.jl/blob/f12de3ea148f1b348615b1ee24ab2a63e68d92d5/src/sensitivities/chainrules.jl
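To make the signature mismatch concrete, here is a toy sketch of how a defaulted positional argument produces two distinct method signatures in Julia. The names `factorize_toy` and `Pivot` are hypothetical stand-ins, not LinearAlgebra API. If overloads are generated per rrule signature, as suggested above, a rule written only against the two-argument form would not cover calls made through the one-argument form.

```julia
# Toy stand-ins; `Pivot` and `factorize_toy` are hypothetical names.
struct Pivot{B} end

# One definition with a defaulted positional argument.
factorize_toy(A::AbstractMatrix, ::Pivot{false}=Pivot{false}()) = "two-arg body"

# Julia lowers the default into a separate one-argument method that forwards
# to the two-argument one, so the function ends up with two signatures.
n_methods = length(methods(factorize_toy))
has_one_arg = hasmethod(factorize_toy, Tuple{Matrix{Float64}})
has_two_arg = hasmethod(factorize_toy, Tuple{Matrix{Float64}, Pivot{false}})
```

Both `has_one_arg` and `has_two_arg` are true here, but they correspond to different entries in the method table, which is why a tool keyed on exact signatures can see one and miss the other.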

Perhaps, but this issue would likely then crop up elsewhere, since ChainRules tries to define rrules for methods that are called by other methods with fewer arguments. e.g. the rrule for lu would probably also be missed by Nabla: https://github.com/JuliaDiff/ChainRules.jl/blob/6ff4c319f8fd25f27636d28144d78c92f81d8753/src/rulesets/LinearAlgebra/factorization.jl#L134-L136

Maybe the "missing" single-arg method is only breaking some very specific tests such as

X_ = Matrix{Float64}(I, 5, 5)
X = Leaf(Tape(), X_)
C = cholesky(X)
@test C isa Branch{<:Cholesky}
but AD is still working? That is, maybe one just has to update the tests?

Member

Nabla.jl's Node type doesn't subtype AbstractMatrix (or Number).
This means it had problems going through things that have that kind of type restriction.
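A small sketch of that failure mode, using a hypothetical `ToyNode` wrapper rather than Nabla's actual `Node` type: a method restricted to `AbstractMatrix` simply does not apply to a wrapper that holds a matrix but is not itself one.

```julia
# `ToyNode` is a hypothetical stand-in for a tracked value; it is not Nabla's Node.
struct ToyNode{T}
    value::T
end

# A method with the kind of type restriction discussed above.
restricted(A::AbstractMatrix) = size(A)

wrapped = ToyNode([1.0 0.0; 0.0 1.0])

# The wrapper is not an AbstractMatrix, so dispatch finds no applicable method.
threw_method_error = try
    restricted(wrapped)
    false
catch err
    err isa MethodError
end

# Calling on the unwrapped value works as usual.
unwrapped_size = restricted(wrapped.value)
```

This is why such a wrapper needs its own overloads (or rules keyed on matching signatures) for every restricted method it is expected to pass through.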

Member

oxinabox commented Jun 24, 2022

Well done.
Nice detective work.
I think these are the right fixes and we should add the piratical ones to CR.jl

I wonder why we didn't spot them in the reverse dependency checks?

(I am still away sick)

Development

Successfully merging this pull request may close these issues.

Breaking ChainRules 1.35.3 release
4 participants