
Vectorised AD #1697

Open
athas opened this issue Jul 2, 2022 · 2 comments
Labels: AD (related to automatic differentiation), enhancement

athas (Member) commented Jul 2, 2022

I would like the following functions to be made available:

val jvp2_vec 'a 'b [n] : (a -> b) -> a -> [n]a -> (b, [n]b)

val vjp2_vec 'a 'b [n] : (a -> b) -> a -> [n]b -> (b, [n]a)

The names are open to bikeshedding. The idea is to let AD compute n (co)tangents in a single pass, which avoids n separate executions of the primal function. In some cases the compiler might be able to optimise away the replicated work, but I wouldn't want to rely on that in general.
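As a sketch of the intended meaning (not a proposed implementation), both functions can be given a naive reference semantics in terms of the existing prelude functions `jvp` and `vjp`. This version runs the differentiated function once per (co)tangent, which is exactly the replicated work a native implementation would avoid. Note that the cotangents produced by `vjp` live in the input type `a`, so the sketch gives `vjp2_vec` the result type `(b, [n]a)`.

```futhark
-- Naive reference semantics for the proposed functions. Each (co)tangent is
-- handled by a separate call to jvp/vjp, so the primal work is replicated n
-- times; a native vectorised AD pass would do it once.

-- Forward mode: one primal result and n tangents of type b.
def jvp2_vec 'a 'b [n] (f: a -> b) (x: a) (dxs: [n]a) : (b, [n]b) =
  (f x, map (\dx -> jvp f x dx) dxs)

-- Reverse mode: one primal result and n cotangents of type a
-- (one per output cotangent of type b).
def vjp2_vec 'a 'b [n] (f: a -> b) (x: a) (dys: [n]b) : (b, [n]a) =
  (f x, map (\dy -> vjp f x dy) dys)
```

A typical use would be computing a whole Jacobian in one call: passing the standard basis vectors of an `[m]f64` input as the tangents yields all m columns at once.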

I think this is fairly straightforward to implement: we just need to teach the AD passes that the (co)tangent of a primal variable of type t is not necessarily of type t, but can also be an array of type [n]t (where n is fixed for any single instance of AD).
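To illustrate this change of representation, here is a hand-written sketch (not actual compiler output) of what vectorised forward-mode code could look like for the hypothetical function `f x = x * f64.sin x`. The primal statements run once, and each tangent statement becomes a `map` over the n tangents, so the tangent of an `f64` variable is an `[n]f64` array.

```futhark
-- Hand-written sketch, not compiler output: vectorised forward-mode AD for the
-- hypothetical function f x = x * f64.sin x. The tangent of each f64 primal
-- variable is represented as an [n]f64 array.
def f_jvp_vec [n] (x: f64) (dxs: [n]f64) : (f64, [n]f64) =
  -- Primal computation (plus cos x, needed by the derivative): done once,
  -- independent of n.
  let s = f64.sin x
  let c = f64.cos x
  let y = x * s
  -- Tangents of s = sin x: ds = cos x * dx, one per input tangent.
  let dss = map (\dx -> c * dx) dxs
  -- Tangents of y = x * s by the product rule: dy = dx * s + x * ds.
  let dys = map2 (\dx ds -> dx * s + x * ds) dxs dss
  in (y, dys)
```

The point of the design is visible in the sketch: the cost of the primal part does not grow with n, while the tangent part is a data-parallel `map` that the compiler can execute efficiently.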

athas added the AD (related to automatic differentiation) label Jul 2, 2022
zfnmxt (Collaborator) commented Jul 3, 2022

Is this inspired by the work Martin is/was doing?

melsman (Contributor) commented Jul 3, 2022 via email
