
Add optional support of JAX to accelerate some partial derivatives #418

Open
kanekosh (Contributor) opened this issue Dec 19, 2023 · 0 comments

Description of feature

When using a dense VLM mesh, compute_partials in some components (e.g., eval_mtx in aerodynamics) becomes a bottleneck of the derivative computation. These partials could be accelerated by replacing the current hand-coded analytical derivatives with algorithmic differentiation (AD).
Aditya Deshpande and Sriram Bommakanti prototyped this for the AE588 project and showed that AD does accelerate these partials. Their prototype implementation can be found in their fork.
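
A minimal sketch of the pattern, assuming a toy kernel: the influence_matrix function, the EvalMtxAD component, and the shapes below are made up for illustration and are not the actual eval_mtx math in OpenAeroStruct. It only shows how a hand-coded compute_partials could be swapped for a JAX-computed Jacobian.

```python
import jax
import jax.numpy as jnp
import numpy as np
import openmdao.api as om

jax.config.update("jax_enable_x64", True)  # match OpenMDAO's double precision

def influence_matrix(mesh_pts, eval_pts):
    # Toy stand-in for an eval_mtx-like kernel (not the real OpenAeroStruct math).
    diff = eval_pts[:, None, :] - mesh_pts[None, :, :]
    return 1.0 / jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + 1e-8)

class EvalMtxAD(om.ExplicitComponent):
    def setup(self):
        self.add_input('mesh_pts', shape=(10, 3))
        self.add_input('eval_pts', shape=(10, 3))
        self.add_output('mtx', shape=(10, 10))
        self.declare_partials('mtx', ['mesh_pts', 'eval_pts'])
        # Build and JIT the Jacobian function once; forward mode is just one of the
        # AD options (jacfwd / jacrev / vjp) that would need to be profiled.
        self._jac = jax.jit(jax.jacfwd(influence_matrix, argnums=(0, 1)))

    def compute(self, inputs, outputs):
        outputs['mtx'] = influence_matrix(inputs['mesh_pts'], inputs['eval_pts'])

    def compute_partials(self, inputs, partials):
        d_mesh, d_eval = self._jac(inputs['mesh_pts'], inputs['eval_pts'])
        # OpenMDAO stores dense partials as 2-D (n_output_vals, n_input_vals) arrays.
        partials['mtx', 'mesh_pts'] = np.asarray(d_mesh).reshape(100, 30)
        partials['mtx', 'eval_pts'] = np.asarray(d_eval).reshape(100, 30)
```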

AD support should be optional: we don't want to add JAX as a hard dependency (for now), and AD likely doesn't offer a performance benefit for moderate mesh sizes.
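
For the optional part, one common pattern (sketched here; the flag and option names are placeholders, not a decided interface) is a guarded import, so the package keeps working when JAX is absent:

```python
# Guarded import, e.g. in a small utility module (module and flag names are hypothetical).
try:
    import jax
    jax.config.update("jax_enable_x64", True)  # keep double precision consistent with NumPy
    HAS_JAX = True
except ImportError:
    HAS_JAX = False

# A component could then fall back to the existing analytical partials, e.g.:
#   use_ad = self.options['use_jax'] and HAS_JAX
```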

Potential solution

  1. Run profiling and identify the components that can be accelerated by AD. eval_mtx is one, but there could be others.
  2. Replace (part of) the compute_partials method with AD. We'll need to try out multiple AD options as Aditya and Sriram did.
  3. Add an optional dependency on JAX in setup.py (see the sketch after this list).
  4. Add a documentation page on AD - ideally, suggest a mesh size threshold at which AD becomes faster than the default analytical partials.
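
For step 3, a sketch of what the optional dependency could look like in setup.py (the extra's name 'jax' and the listed packages are assumptions, not a decided interface):

```python
from setuptools import setup

setup(
    name='openaerostruct',
    # ... existing arguments unchanged ...
    extras_require={
        # installed only via `pip install openaerostruct[jax]`
        'jax': ['jax', 'jaxlib'],
    },
)
```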