Reverse AD: more nuanced handling of consumed arrays. #1702

Open
zfnmxt opened this issue Jul 13, 2022 · 0 comments
Labels
AD Related to automatic differentiation compiler enhancement

Comments

@zfnmxt
Collaborator

zfnmxt commented Jul 13, 2022

In each scope in which arrays are consumed, the reverse AD pass makes a copy of each consumed array just before entering the scope. In the forward pass, the arrays are substituted with their respective copies so that the original arrays (with their original values) remain available to the reverse pass. [1]
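As a rough illustration of the current scheme, the Python sketch below models a scope as a list of statements in a hypothetical toy IR. `Stm` and `copy_consumed` are made up for this example and are not the compiler's actual representation. The sketch collects the arrays that the scope consumes but does not define, emits a copy of each just before the scope, and renames their uses in the forward body to the copies. The returned renaming stands in for the mapping that, per footnote [1], `returnSweepCode` undoes in the reverse pass.

```python
from dataclasses import dataclass, replace
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class Stm:
    """Hypothetical toy statement: dest = op(args); `consumes` lists args updated in place."""
    dest: str
    op: str
    args: Tuple[str, ...]
    consumes: Tuple[str, ...] = ()


def copy_consumed(body: List[Stm]) -> Tuple[List[Stm], List[Stm], Dict[str, str]]:
    """Copy arrays consumed in `body` (but defined outside it) and rename their uses."""
    defined = set()
    to_copy: List[str] = []
    for stm in body:
        for a in stm.consumes:
            # Only arrays coming from outside the scope need a copy.
            if a not in defined and a not in to_copy:
                to_copy.append(a)
        defined.add(stm.dest)

    renaming = {a: a + "_copy" for a in to_copy}
    # The copies are emitted just before entering the scope.
    copies = [Stm(renaming[a], "copy", (a,)) for a in to_copy]

    def rn(v: str) -> str:
        return renaming.get(v, v)

    # The forward sweep operates on the copies, so the originals keep their values.
    fwd = [replace(s, args=tuple(map(rn, s.args)), consumes=tuple(map(rn, s.consumes)))
           for s in body]
    # `renaming` lets a later step map the copies back to the originals.
    return copies, fwd, renaming


body = [
    Stm("x0", "index0", ("xs",)),                          # x0 = xs[0]
    Stm("z", "mul", ("x0", "x0")),                         # z = x0 * x0
    Stm("res", "update0", ("xs", "z"), consumes=("xs",)),  # res = xs with [0] = z
]
copies, fwd, renaming = copy_consumed(body)
```

Here `copies` holds a single `xs_copy = copy xs` statement, and every use of `xs` in `fwd` refers to `xs_copy`, while `body` itself is left unchanged.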

While this works, it's inefficient: the whole array is copied even when only a few of its elements are overwritten. For example, consider

let xs' = if it_is_raining
          then
            let z = xs[0] * xs[0]
            let res = xs with [0] = z
            in res
          else
            ...

After applying AD, we have

-- forward sweep
let xs_copy = copy xs
let xs' = if it_is_raining
          then
            let z = xs_copy[0] * xs_copy[0]
            let res = xs_copy with [0] = z
            in res
          else
            ...
     
-- reverse sweep
let xs_adj = if it_is_raining
             then
               let z = xs[0] * xs[0]
               let res = xs with [0] = z
               let z_adj = res_adj[0]  -- res_adj: the adjoint of res, i.e. of xs'
               let xs_adj = res_adj with [0] = 2 * xs[0] * z_adj
               in xs_adj
             else
               ...

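To see where the cost comes from, here is a small Python model of the transformed example above, using plain lists as stand-ins for Futhark arrays. The function names and the `res_adj` argument are hypothetical. `res_adj` is assumed to be the adjoint of `xs'`, supplied by the surrounding reverse sweep. The forward sweep copies all of `xs` just to overwrite one element, while the reverse sweep reads the untouched original.

```python
from typing import List, Tuple


def forward_with_copy(xs: List[float]) -> Tuple[List[float], int]:
    """Forward sweep: update a full copy so that `xs` keeps its original values."""
    xs_copy = list(xs)             # copies all len(xs) elements
    z = xs_copy[0] * xs_copy[0]
    xs_copy[0] = z                 # res = xs_copy with [0] = z
    return xs_copy, len(xs)        # result and number of elements copied


def reverse_with_original(xs: List[float], res_adj: List[float]) -> List[float]:
    """Reverse sweep: the original `xs` is intact, so xs[0] can be read directly."""
    z_adj = res_adj[0]             # adjoint flowing into z through res[0]
    xs_adj = list(res_adj)         # the other elements pass through unchanged
    xs_adj[0] = 2 * xs[0] * z_adj  # d(x0 * x0)/dx0 = 2 * x0
    return xs_adj


xs = [3.0, 1.0, 2.0]
res, n_copied = forward_with_copy(xs)            # res == [9.0, 1.0, 2.0], n_copied == 3
xs_adj = reverse_with_original(xs, [1.0, 1.0, 1.0])  # xs_adj == [6.0, 1.0, 1.0]
```

The copy grows with the length of the array even though only one element changes.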
Instead, we can save just the element(s) that get overwritten and avoid copying the entire array. In the reverse sweep, the saved element(s) are written back into the updated array to restore its original contents, so that all intermediate variables can be recomputed.

-- forward sweep
let (xs', xs_0) = if it_is_raining
                  then
                    let z = xs[0] * xs[0]
                    let xs_0 = xs[0] -- save the overwritten element
                    let res = xs with [0] = z
                    in (res, xs_0)
                  else
                    ...
     
-- reverse sweep
let xs_adj = if it_is_raining
             then
               let xs_restore = xs' with [0] = xs_0  -- restore the overwritten element
               let z = xs_restore[0] * xs_restore[0]
               let res = xs_restore with [0] = z
               let z_adj = res_adj[0]  -- res_adj: the adjoint of res, i.e. of xs'
               let xs_adj = res_adj with [0] = 2 * xs_0 * z_adj  -- xs was consumed; use the saved element
               in xs_adj
             else
               ...

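Here is the same example under the save/restore scheme, again as a Python model with hypothetical names. Python list mutation plays the role of Futhark's in-place update, and `res_adj` is assumed to be the adjoint of `xs'`. Only the overwritten element `xs_0` is kept from the forward sweep. The reverse sweep writes it back into the updated array, recomputes `z`, and reads `xs_0` rather than `xs[0]`, since `xs` no longer holds its original value.

```python
from typing import List, Tuple


def forward_with_save(xs: List[float]) -> Tuple[List[float], float]:
    """Forward sweep: update `xs` in place, saving only the overwritten element."""
    z = xs[0] * xs[0]
    xs_0 = xs[0]     # save the overwritten element (one value, not a copy)
    xs[0] = z        # res = xs with [0] = z, performed in place
    return xs, xs_0


def reverse_with_restore(res: List[float], xs_0: float,
                         res_adj: List[float]) -> Tuple[List[float], float, List[float]]:
    """Reverse sweep: restore xs from res, recompute intermediates, propagate adjoints."""
    xs_restore = res                    # reuse res's storage; res is not needed afterwards
    xs_restore[0] = xs_0                # restore the overwritten element
    z = xs_restore[0] * xs_restore[0]   # recomputed intermediate
    z_adj = res_adj[0]
    xs_adj = list(res_adj)
    xs_adj[0] = 2 * xs_0 * z_adj        # uses the saved element, not the updated array
    return xs_restore, z, xs_adj


xs = [3.0, 1.0, 2.0]
res, xs_0 = forward_with_save(xs)   # res is xs, now [9.0, 1.0, 2.0]; xs_0 == 3.0
restored, z, xs_adj = reverse_with_restore(res, xs_0, [1.0, 1.0, 1.0])
```

Only one float survives from the forward sweep here, regardless of the array's length.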
This is very similar to the saving/restoring we already do for scatter, except that it must work across scopes rather than only within a single scope. The main difference is that the forward re-execution in each scope of the reverse sweep must be modified to restore values at the right point: a restore must be placed before any statements that preceded the corresponding save in the forward sweep, since those statements read the array's original contents.
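For a scope with several in-place updates, one way to satisfy this placement rule is to save overwritten elements on a stack in forward order and restore them in reverse order, performing each restore before re-executing the statements that preceded its save. The following Python sketch assumes each update is a hypothetical function that reads the current array and returns an index and a new value. It only checks that re-execution sees the same array states as the forward sweep and leaves out the adjoint computation itself.

```python
from typing import Callable, List, Tuple

# Hypothetical: an update reads the current array and returns (index, new value).
Update = Callable[[List[float]], Tuple[int, float]]
Tape = List[Tuple[int, float]]


def forward(xs: List[float], updates: List[Update]) -> Tape:
    """Apply each update in place, saving the overwritten element on a tape."""
    tape: Tape = []
    for upd in updates:
        i, v = upd(xs)           # statements preceding the save read the current xs
        tape.append((i, xs[i]))  # save the element about to be overwritten
        xs[i] = v                # in-place update
    return tape


def reverse(xs: List[float], updates: List[Update], tape: Tape) -> List[List[float]]:
    """Undo updates last-to-first, restoring each saved element before
    re-executing the statements that preceded its save."""
    states_seen: List[List[float]] = []
    for upd, (i, old) in zip(reversed(updates), reversed(tape)):
        xs[i] = old                   # restore first ...
        states_seen.append(list(xs))  # ... so re-execution sees the forward-sweep state
        upd(xs)                       # re-execute; adjoint code would follow here
    return states_seen


updates = [lambda a: (0, a[0] * a[0]), lambda a: (1, a[1] + a[0])]
xs = [3.0, 1.0]
tape = forward(xs, updates)         # xs == [9.0, 10.0], tape == [(0, 3.0), (1, 1.0)]
seen = reverse(xs, updates, tape)   # seen == [[9.0, 1.0], [3.0, 1.0]], xs == [3.0, 1.0]
```

Restoring last-in, first-out also handles an index that is overwritten more than once, since each restore reinstates the value that was present just before the corresponding update.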

[1] returnSweepCode is responsible for substituting the names of the copies back to the originals in the reverse pass. There was a technical reason for using the originals in the reverse pass (instead of the copies), but unfortunately I've forgotten it.

@zfnmxt zfnmxt added enhancement compiler AD Related to automatic differentiation labels Jul 13, 2022