Implementing an adjoint calculation for backprop-ing through time #1

ianwilliamson opened this issue Apr 24, 2019 · 8 comments


ianwilliamson commented Apr 24, 2019

We should consider the performance benefit of implementing an adjoint calculation for the backward pass through the forward() method in WaveCell. This could save us memory during gradient computation because PyTorch wouldn't need to construct as large a graph.

The approach is described here: https://pytorch.org/docs/stable/notes/extending.html
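For illustration, a hand-coded time-step primitive could look something like the sketch below. This is not the WaveCell update itself; the 1D damped update, the periodic-boundary `laplacian` helper, and the names `u_cur`, `u_prev`, `c`, `b`, `dt`, `dx` are all assumptions made for the example.

```python
import torch


def laplacian(u, dx):
    # 1D Laplacian with periodic boundaries; the kernel is symmetric, so the
    # operator is self-adjoint and reappears unchanged in backward().
    return (torch.roll(u, 1, dims=-1) - 2 * u + torch.roll(u, -1, dims=-1)) / dx ** 2


class WaveStep(torch.autograd.Function):
    """One step of a damped scalar wave update with a hand-coded backward pass."""

    @staticmethod
    def forward(ctx, u_cur, u_prev, c, b, dt, dx):
        # (1 + b*dt/2) u_next = 2 u_cur - (1 - b*dt/2) u_prev + (c*dt)^2 lap(u_cur)
        a1 = 1.0 / (1.0 + 0.5 * dt * b)
        a2 = 1.0 - 0.5 * dt * b
        u_next = a1 * (2 * u_cur - a2 * u_prev + (c * dt) ** 2 * laplacian(u_cur, dx))
        # Only what backward() needs is saved; no graph is recorded for the
        # intermediate operations above.
        ctx.save_for_backward(u_cur, c)
        ctx.b, ctx.dt, ctx.dx = b, dt, dx
        return u_next

    @staticmethod
    def backward(ctx, grad_out):
        u_cur, c = ctx.saved_tensors
        b, dt, dx = ctx.b, ctx.dt, ctx.dx
        a1 = 1.0 / (1.0 + 0.5 * dt * b)
        a2 = 1.0 - 0.5 * dt * b
        g = a1 * grad_out
        # Transposed Jacobians of the update above (c is a spatial speed field
        # with the same shape as u).
        grad_u_cur = 2 * g + laplacian((c * dt) ** 2 * g, dx)
        grad_u_prev = -a2 * g
        grad_c = 2 * c * dt ** 2 * laplacian(u_cur, dx) * g
        # b, dt, dx are treated as non-differentiable constants here.
        return grad_u_cur, grad_u_prev, grad_c, None, None, None


# usage: u_next = WaveStep.apply(u_cur, u_prev, c, b, dt, dx)
```

The hand-coded gradients could then be checked against autograd with torch.autograd.gradcheck on small double-precision inputs.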

ianwilliamson added the enhancement label on Apr 24, 2019
twhughes self-assigned this on Apr 30, 2019
@parenthetical-e

Sorry to pop in, but on the off chance you folks haven't seen this library/paper:

https://github.com/rtqichen/torchdiffeq
https://arxiv.org/pdf/1806.07366.pdf

It implements an ODE solver and uses adjoint methods for the backward pass. Is this what you need?

I was already thinking about porting WaveCell to it for my own use. Collaborate?

@ianwilliamson

Thanks for your interest in this! We are aware of that paper, but unfortunately we can't apply the scheme they propose here because the wave equation with loss (from the absorbing layer) is not reversible.

The "adjoint calculation" I'm referring to here is basically just hard coding the gradient for the time step using the pytorch API documented here: https://pytorch.org/docs/stable/notes/extending.html The motivation for this is that we can potentially save a bunch of memory because pytorch doesn't need to store the fields at every sub-operation of each time step. However, it still needs to store the fields at each time step (there's no getting around this when the differential equation isn't reversible.) In contrast, the neural ODE paper reconstructs these fields by reversing the forward equation during backpropagation, thus, avoiding the need to store the fields from the forward pass.

We actually have this adjoint approach implemented; I just need to push the commits to this repository.

@ianwilliamson

I'm definitely interested in learning about your project and what you hope to do. We would certainly be open to collaboration if there's an opportunity.

@parenthetical-e


Ah. I understand. Thanks for the explanation.

@parenthetical-e

I sent you an email about the project I'm pondering. :)


twhughes commented Aug 26, 2019 via email

@parenthetical-e

Done, @twhughes

@ianwilliamson

This is now partially implemented: the individual time step is a primitive. This seems to help with memory utilization during training, especially when the nonlinearity is enabled. Perhaps we could investigate whether there would be significant performance benefits from applying the adjoint treatment to the whole time loop as well.
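Purely as a sketch of what treating the time loop itself as a primitive could look like (hypothetical code reusing the illustrative update and `laplacian` helper from the earlier sketch, not the repository's implementation):

```python
class WaveLoop(torch.autograd.Function):
    """Hypothetical: the whole time loop as one primitive with an explicit adjoint sweep."""

    @staticmethod
    def forward(ctx, u0, u1, c, b, dt, dx, n_steps):
        a1 = 1.0 / (1.0 + 0.5 * dt * b)
        a2 = 1.0 - 0.5 * dt * b
        history = [u0, u1]
        for _ in range(n_steps):
            u_prev, u_cur = history[-2], history[-1]
            history.append(a1 * (2 * u_cur - a2 * u_prev
                                 + (c * dt) ** 2 * laplacian(u_cur, dx)))
        # The fields at every step still have to be stored (the lossy equation
        # can't be stably run in reverse), but nothing else is.
        ctx.save_for_backward(c, *history)
        ctx.b, ctx.dt, ctx.dx = b, dt, dx
        return history[-1]

    @staticmethod
    def backward(ctx, grad_out):
        c, *history = ctx.saved_tensors
        b, dt, dx = ctx.b, ctx.dt, ctx.dx
        a1 = 1.0 / (1.0 + 0.5 * dt * b)
        a2 = 1.0 - 0.5 * dt * b
        grad_c = torch.zeros_like(c)
        lam = grad_out                      # adjoint of history[n+1]
        carry = torch.zeros_like(grad_out)  # deferred contribution to the adjoint of history[n]
        # Reverse sweep over the stored history, applying the transposed per-step Jacobians.
        for n in range(len(history) - 2, 0, -1):
            g = a1 * lam
            grad_c += 2 * c * dt ** 2 * laplacian(history[n], dx) * g
            lam, carry = 2 * g + laplacian((c * dt) ** 2 * g, dx) + carry, -a2 * g
        # After the sweep, lam is the adjoint of u1 and carry is the adjoint of u0.
        return carry, lam, grad_c, None, None, None, None
```

The field history still has to be stored, so memory grows with the number of steps either way, but the backward pass becomes a single explicit reverse sweep instead of an autograd graph recorded over the whole loop.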
