Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow gradient transform parameters to be dynamic #516

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

NeilGirdhar
Copy link
Contributor

No description provided.

@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented Mar 30, 2023

@hawkinsp Pinging you since you recently repaired some type annotation errors. The optimizer classes accepting only float breaks type annotations for the Tjax shim classes (https://github.com/NeilGirdhar/tjax/blob/main/tjax/_src/gradient/transforms.py). Tjax provides a parallel set of optimizers, identical in functionality, except they support dynamic optimizer parameters. They do this by storing dynamic fields in a dataclass rather than closing over parameters.

However, the optimizer functionality is delegated to Optax, which means calling Optax update methods with Jax arrays. Is there any reason Optax methods can't accept such arrays? Would it be possible to widen these parameter types to jax.Array | float?

@NeilGirdhar
Copy link
Contributor Author

@mtthss Would you mind taking a look at this?

@mtthss
Copy link
Collaborator

mtthss commented Oct 10, 2023

Hello. I was on paternity leave for most of the past year. Are you still having this issue? Happy to look into it if that's the case

@NeilGirdhar
Copy link
Contributor Author

@mtthss Hello, yes I'm still getting the type errors. (Congrats on becoming a father!)

@mtthss
Copy link
Collaborator

mtthss commented Oct 10, 2023

which arguments are causing errors to you?

@NeilGirdhar
Copy link
Contributor Author

which arguments are causing errors to you?

All of the ones I changed. I maintain a shim library so that I can use optax with dynamic, inspectable parameters. What I ended up doing for the time being is to mark every use of optax with pyright: ignore.

Thanks for taking a look at this.

@NeilGirdhar
Copy link
Contributor Author

(Of course, my dream would be that you adopt the dynamic design so that I don't have to maintain my shim library 😄.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants