Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ad extension [WIP] #85

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open

Ad extension [WIP] #85

wants to merge 15 commits into from

Conversation

Jutho
Copy link
Owner

@Jutho Jutho commented May 13, 2024

No description provided.

Copy link

codecov bot commented May 13, 2024

Codecov Report

Attention: Patch coverage is 90.65109% with 56 lines in your changes are missing coverage. Please review.

Project coverage is 84.31%. Comparing base (da91706) to head (06a36a6).
Report is 1 commits behind head on master.

Files Patch % Lines
ext/KrylovKitChainRulesCoreExt/eigsolve.jl 90.22% 26 Missing ⚠️
ext/KrylovKitChainRulesCoreExt/svdsolve.jl 90.27% 21 Missing ⚠️
ext/KrylovKitChainRulesCoreExt/linsolve.jl 89.79% 5 Missing ⚠️
src/linsolve/linsolve.jl 50.00% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master      #85      +/-   ##
==========================================
+ Coverage   82.05%   84.31%   +2.26%     
==========================================
  Files          27       31       +4     
  Lines        2753     3271     +518     
==========================================
+ Hits         2259     2758     +499     
- Misses        494      513      +19     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

pbrehmer and others added 3 commits May 13, 2024 21:15
* Add untested svdsolve rrule

* Fix typos

* Add svdsolve rrule to extension folder

* Delete src/adrules/svdsolve.jl

---------

Co-authored-by: Jutho <Jutho@users.noreply.github.com>
@Jutho Jutho mentioned this pull request May 25, 2024
@Jutho
Copy link
Owner Author

Jutho commented May 25, 2024

I think this is now mostly ready, up to some cleanup and streamlining of the interface. Maybe @lkdvos wants to review?

@Jutho
Copy link
Owner Author

Jutho commented May 25, 2024

There is one more significant TODO:

The eigenvalue approach for solving the linear problem/Sylvester problem in both the rrule of eigsolve and svdsolve is nonhermitian, which means that the results are always obtained in complex arithmetic, even when the forward calculation can be completely real (namely for a real symmetric eigenvalue problem or a real singular value problem). This so far does not cause problems, as apparently the imaginary parts of the computed quantities is exactly zero and therefore it is implicitly converted back to real vectors. However, this might break with custom types, so it would be better to explicitly restrict to real arithmetic by using schursolve and a custom routine to extract the real eigenvectors associated with the results coming out of schursolve.

@lkdvos
Copy link
Collaborator

lkdvos commented May 25, 2024

There is one more significant TODO:

The eigenvalue approach for solving the linear problem/Sylvester problem in both the rrule of eigsolve and svdsolve is nonhermitian, which means that the results are always obtained in complex arithmetic, even when the forward calculation can be completely real (namely for a real symmetric eigenvalue problem or a real singular value problem). This so far does not cause problems, as apparently the imaginary parts of the computed quantities is exactly zero and therefore it is implicitly converted back to real vectors. However, this might break with custom types, so it would be better to explicitly restrict to real arithmetic by using schursolve and a custom routine to extract the real eigenvectors associated with the results coming out of schursolve.

I haven't looked into this in detail, but this sounds like it could also be solved with appropriate calls to ProjectTo, which should work even for custom types as there it gives a hook to add the correct projection methods.

Copy link
Collaborator

@lkdvos lkdvos left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is definitely a lot of nice work, looks great!

I think I mainly have some minor typos and small nitpicking things, but maybe as a more general comment:
I am not a huge fan of the dummy (; alg_rrule=nothing) in the keyword arguments of the forward passes. Conceptually, I quite like that the method definition does not need to know anything about the AD that may or not happen, and it feels a bit unnatural that if I were to decide to implement eg a MinRes linsolver, I need to remember to add the keyword argument.

I think, once we have some benchmarks, it should be possible to have decent default rrule algorithm defaults based on the forward algorithm, and then any expert user who still wants to play around with the different implementations, or experiment with new ones could do the (little bit of) extra work of doing something along the lines of hook_pullback(f, args...; kwargs..., alg_rrule=my_alg), which we can even hide in a macro: @hook_pullback f(args...; kwargs..., alg_rrule=myalg) or @alg_rrule f(args...; kwargs..., alg_rrule=myalg) or something similar.
This being said, this could also just be me, and if it does not bother you as much, it might not be worth it to change it.

Finally,

ext/KrylovKitChainRulesCoreExt/eigsolve.jl Outdated Show resolved Hide resolved
Comment on lines 42 to 45
if n == 0
∂f = ZeroTangent()
return ∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a type instability (I assume this is what you were referring to this week?)
Do you think we can avoid this by:

  1. inserting pullback_eigsolve(ΔX::Tuple{AbstractZero, AbstractZero, Any}) = [...] (with ∂f = ZeroTangent())
  2. throwing a warning and explicitly computing the zero pullback

I am honestly not sure if the second case ever happens, as I think this implies that the dependence on the eigenvalues and eigenvectors is exactly zero, which sounds incredibly implausible with floating point accuracy. This would both mean that the regular (most common) case is now type-stable, and that the case where n = 0 gets handled properly when both inputs are AbstractZero (which afaik e.g. Zygote would never even generate either)

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I will get rid of this.

ext/KrylovKitChainRulesCoreExt/eigsolve.jl Outdated Show resolved Hide resolved
ext/KrylovKitChainRulesCoreExt/eigsolve.jl Outdated Show resolved Hide resolved
function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, fᴴ, T,
alg_primal::Arnoldi, alg_rrule::Arnoldi)
n = length(Δvecs)
G = zeros(T, n, n)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
G = zeros(T, n, n)
G = Matrix{T}(undef, n, n) # eigenvector overlap matrix

Comment on lines +205 to +206
eigsolve(W₀, n, :LR, alg_rrule) do w
x, y, z = w
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
eigsolve(W₀, n, :LR, alg_rrule) do w
x, y, z = w
eigsolve(W₀, n, :LR, alg_rrule) do (x, y, z)

ext/KrylovKitChainRulesCoreExt/svdsolve.jl Outdated Show resolved Hide resolved
end
return xs, ys
end

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be possible to merge some of the underlying methods by explicitly using rrule_via_ad(config, apply_adjoint, ...) and the apply_normal variants.
Similarly, it might make sense to define rrules for these functions, as we have access to both normal and adjoint function applications? This helps with the compile times and can be a lot more convenient in allowing mutable steps within the function, even when the total function is non-mutating. (e.g. preallocating an array and then filling it)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Throughout this file I think the syntax f = x -> body is used over f(x) = body. Semantically, they are the same, but the latter associates the name with the anonymous function, such that debugging becomes a little clearer:

julia> buildfun() = h(x) = x * 3;
julia> buildfun()
(::var"#h#11") (generic function with 1 method)
julia> buildfun2() = x -> x * 3;
julia> buildfun2()
#12 (generic function with 1 method)

@@ -95,3 +96,476 @@ end
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there are some tolerances missing above this line :)

@Jutho
Copy link
Owner Author

Jutho commented May 30, 2024

Ok, I think this is mostly ready. Maybe I can add a few more tests to improve coverage (e.g. print warnings and test for them). The TODO for not having to go via complex values in case of Arnoldi rrule for svdsolve or hermitian eigsolve is also still open.

About the interface; I did not really study the comment of @lkdvos above. What do you dislike about the keyword argument ; alg_rrule = ... in the methods? Whether the actual values to this keyword need to have the values they currently have (recycling existing structures) or some new values is certainly open for debate. And a sensible default definitely needs to be in place after benchmarking. This also reminds me, documentation about this needs to be added.

@lkdvos
Copy link
Collaborator

lkdvos commented May 30, 2024

I think my main argument against the alg_rrule kwarg is that it "pollutes" the method definition with information that is only relevant to AD. In other words, if I were to for example write an implementation of minres, I now need to remember to add that keyword argument, even though this has nothing to do with the primal computation.

I would much rather have a different implementation that keeps the primal computations clean. One such way is to simply add a wrapper function with the alg_rrule kwarg added, which can then correctly distribute the args and kwargs to their relevant places.

I am a bit more in favour of something like:

vals, vecs, info = eigsolve(A, x, num, which, alg) # can infer default AD algorithm

# option 1:
vals, vecs, info = hook_pullback(eigsolve, A, x, num, which, alg; alg_rrule) # expert mode -- specifies rrule algorithm

# option 2:
@alg_rrule vals, vecs, info = eigsolve(A, x, num, which, alg; alg_rrule) # looks like current implementation, but expands to option 1

This being said, in principle this is mostly a conceptual issue, and this does not really change all that much. I like keeping AD separated, but this is obviously strictly necessary.


Concerning the test coverage, it might be a good idea to add a sparse matrix to the set of tests. This seems like a good candidate for something that checks if our assumptions about what we can/cannot do with AbstractMatrix types are fair

@Jutho
Copy link
Owner Author

Jutho commented May 30, 2024

Ok I see; I agree that it ads a keyword that is irrelevant to the forward computation. As that is mostly a "burden" on the developer side, I don't mind too much. If there would be some centralised infrastructure to do the hook or macro solution, I would use it, but for now, I think there is less overhead in simply adding those keywords rather than developing it within the scope of this package.

From the user side, I think the current approach is fine right? If they don't need AD, they don't need to care about this keyword, and if they do, it is easy to try out the different choices without any significant change to the code.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants