Relax linear indexing requirement _slightly_ #216

Open
wants to merge 1 commit into
base: master

torfjelde

Currently TrackedArray requires the input to satisfy IndexStyle(x) === IndexLinear(), since ReverseDiff only has the capability of tracking, well, arrays supporting linear indexing.

But supporting linear indexing and having IndexStyle(x) === IndexLinear() are, IIUC, two different things: you can support linear indexing while still having IndexStyle(x) === IndexCartesian(), i.e. linear indexing is not the most efficient indexing.

For example, DifferentialEquations.DESolution supports linear indexing but has IndexStyle(x) === IndexCartesian().
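To make the distinction concrete, here is a minimal sketch using only Base. ColumnStack is a hypothetical type invented for illustration (it is not DESolution): it defines only Cartesian getindex, so its IndexStyle is the default IndexCartesian(), yet linear indexing still works through Base's fallback.

```julia
# Hypothetical matrix type for illustration: columns are stored as separate vectors.
struct ColumnStack{T} <: AbstractMatrix{T}
    cols::Vector{Vector{T}}
end

Base.size(A::ColumnStack) = (length(first(A.cols)), length(A.cols))
Base.getindex(A::ColumnStack, i::Int, j::Int) = A.cols[j][i]

A = ColumnStack([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

IndexStyle(A)  # IndexCartesian(), the default for AbstractArray subtypes
A[4]           # still works: Base converts the linear index 4 to (2, 2)
```

The conversion from a linear index to a Cartesian one is the extra cost that comes with indexing such an array linearly.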

Currently, this means that DiffEq has to hack around this constraint by converting into a Matrix, completely losing all the information related to the DESolution.

This PR adds a method supports_linear_indexing, which gives arrays such as DESolution a way to tell ReverseDiff that they support linear indexing even though it may not be the most efficient way to index into the array.
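Roughly, the opt-in pattern could look like the sketch below. This is not the PR's diff: only the name supports_linear_indexing comes from the description above, while check_trackable and SolutionLike are hypothetical stand-ins, and the default definition shown is an assumption about how such a trait might fall back to the existing IndexStyle check.

```julia
# Assumed default: mirror the existing IndexStyle check.
supports_linear_indexing(x::AbstractArray) = IndexStyle(x) === IndexLinear()

# Hypothetical stand-in for the check a tracking constructor might run.
function check_trackable(x::AbstractArray)
    supports_linear_indexing(x) ||
        throw(ArgumentError("$(typeof(x)) does not declare support for linear indexing"))
    return x
end

# Hypothetical IndexCartesian array type that opts in explicitly.
struct SolutionLike{T} <: AbstractMatrix{T}
    u::Vector{Vector{T}}
end
Base.size(s::SolutionLike) = (length(first(s.u)), length(s.u))
Base.getindex(s::SolutionLike, i::Int, j::Int) = s.u[j][i]
supports_linear_indexing(::SolutionLike) = true

s = SolutionLike([[1.0, 2.0], [3.0, 4.0]])
check_trackable(s)            # passes because of the explicit opt-in
check_trackable(zeros(2, 2))  # passes because Array is IndexLinear
```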

I honestly don't know 100% if this is the way to go, but it seems to do the trick locally (and seems to compute the correct gradients), so I figured I'd make a PR to at least get a discussion going.

@devmotion
Member

Isn't it much faster to use a Matrix with efficient linear indexing in the DESolution example? Why would you want to use the expensive linear -> cartesian computations every time you index the solution with ReverseDiff? I'm a bit worried that this leads to performance issues that are difficult to debug and surprising for users.

@torfjelde
Author

torfjelde commented Jan 17, 2023

> Isn't it much faster to use a Matrix with efficient linear indexing in the DESolution example?

Well, currently you end up with every solve call returning ODESolution except if you use ReverseDiff, in which case it returns a Matrix. Yeah it's more efficient, but it's very weird and confusing to the user 😕 You can't even check if the solver converged!

Personally I'd rather take slightly slower AD with ReverseDiff than AD with ReverseDiff that completely breaks the expectation of the user and functionality.

@torfjelde
Author

And, in the ODESolution example, there's nothing stopping us from converting the resulting TrackedArray(::ODESolution) into an ODESolution(::TrackedArray) if that is more efficient for subsequent computation. But as things are right now, we can't even construct the TrackedArray(::ODESolution), right? Unless I'm missing something (which is not very unlikely 🙃 ), of course.

@codecov-commenter

codecov-commenter commented Jan 17, 2023

Codecov Report

Base: 84.48% // Head: 84.51% // Increases project coverage by +0.03% 🎉

Coverage data is based on head (af4ced7) compared to base (f06b776).
Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #216      +/-   ##
==========================================
+ Coverage   84.48%   84.51%   +0.03%     
==========================================
  Files          18       18              
  Lines        1921     1925       +4     
==========================================
+ Hits         1623     1627       +4     
  Misses        298      298              
Impacted Files Coverage Δ
src/tracked.jl 92.33% <100.00%> (+0.02%) ⬆️
src/macros.jl 94.17% <0.00%> (+0.08%) ⬆️


@devmotion
Member

I wonder if one could just define

function TrackedArray(sol::ODESolution)
    ODESolution(TrackedArray(sol.u), sol.u_analytic, sol.errors, sol.t, sol.k, sol.prob, sol.alg,
                sol.interp, sol.dense, sol.tslocation, sol.destats, sol.alg_choice, sol.retcode)
end

(possibly one has to handle eltype(sol.u) <: Real and eltype(sol.u) <: AbstractArray{<:Real} separately)? But I'm still a bit confused where exactly the TrackedArray(sol) calls show up and if one could avoid them (at least in some cases) in the first place by constructing an ODESolution(TrackedArray(...), ...) directly.

@torfjelde
Author

torfjelde commented Jan 17, 2023

Something like that might be possible?
TrackedArray(sol) happens when we call ReverseDiff.track(solve_up, ...) no?

EDIT: Just TrackedArray(sol) won't work though since we need to propagate the gradient information. Something like

function ReverseDiff.track(sol::DiffEqBase.ODESolution, tp::Vector{ReverseDiff.AbstractInstruction}=ReverseDiff.InstructionTape())
    DiffEqBase.ODESolution(
        ReverseDiff.track(sol.u, tp),  # But this won't work because `sol.u` is a `Vector{<:Vector}`.
        sol.u_analytic,
        sol.errors,
        sol.t,
        sol.k,
        sol.prob,
        sol.alg,
        sol.interp,
        sol.dense,
        sol.tslocation,
        sol.destats,
        sol.alg_choice,
        sol.retcode
    )
end

@torfjelde
Author

But regardless of the discussion related to ODESolution, this feature should be useful in a broader context, no?

@devmotion
Member

I'm still not sure if disabling the check should be called a feature. I think it would be great though if ReverseDiff supported IndexCartesian, but I don't know what the challenges/problems are.

From a practical perspective, if you wanted to implement supports_linear_indexing in a downstream package, you would have to depend on ReverseDiff. So maybe it would be easier if it were possible to just redefine the constructor (I think currently that's not possible since it's an inner constructor?) instead of adding an additional function to the API. One approach to make this IndexLinear hack less official and not advertise it too much would be to move the assertions (which IMO should be changed to proper exceptions) to the outer constructor. Then downstream packages or users could add outer constructors for their array types if they want to.
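A toy sketch of that idea, using only Base: MyTracked and Columns are hypothetical names and this is not ReverseDiff's actual TrackedArray definition. The inner constructor does no validation, the generic outer constructor throws a proper exception instead of asserting, and a more specific outer constructor lets one array type bypass the check.

```julia
# Hypothetical wrapper standing in for a tracked array type.
struct MyTracked{A<:AbstractArray}
    value::A
    # Inner constructor: no validation here.
    MyTracked{A}(value::A) where {A<:AbstractArray} = new{A}(value)
end

# Generic outer constructor: validates and throws a proper exception.
function MyTracked(value::AbstractArray)
    IndexStyle(value) === IndexLinear() ||
        throw(ArgumentError("expected IndexLinear indexing, got $(typeof(value))"))
    return MyTracked{typeof(value)}(value)
end

# Hypothetical downstream array type with IndexCartesian indexing.
struct Columns{T} <: AbstractMatrix{T}
    cols::Vector{Vector{T}}
end
Base.size(c::Columns) = (length(first(c.cols)), length(c.cols))
Base.getindex(c::Columns, i::Int, j::Int) = c.cols[j][i]

# Downstream opt-in: a more specific outer constructor skips the check.
MyTracked(value::Columns) = MyTracked{typeof(value)}(value)

tracked = MyTracked(Columns([[1.0, 2.0], [3.0, 4.0]]))
```

In this toy version the opt-in still needs the wrapper type to be in scope, so the sketch only shows the mechanics of moving the check to an outer constructor.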

BTW could implementing the 3-arg outer constructor (with the derivative information and tape) fix the ODESolution issue?


3 participants