
Type errors for general pytrees #384

Open
brentyi opened this issue Jul 31, 2022 · 1 comment

Comments

@brentyi
Contributor

brentyi commented Jul 31, 2022

Hello!

When it comes to annotations, optax currently relies heavily on optax.Updates and optax.Params, which are both aliases for chex.ArrayTree.

This makes sense, but for folks who run type checkers, it means that many type errors occur when working with pytrees that aren't strictly nested Iterable or Mapping types as specified in chex. For example:

```python
from typing import Tuple
import optax
from jax import numpy as jnp
import flax.struct


@flax.struct.dataclass
class Params:
    weights: jnp.ndarray
    bias: jnp.ndarray


def make_optimizer(
    params: Params,
) -> Tuple[optax.GradientTransformation, optax.OptState]:
    """Make an optimizer."""
    optimizer = optax.sgd(learning_rate=1e-3)
    # Params is a valid pytree at runtime, but it is neither an Iterable nor a
    # Mapping, so a static type checker does not accept it as a chex.ArrayTree.
    state = optimizer.init(params)  # Type error.
    return optimizer, state
```
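
As a runtime analogue of what the type checker is doing, the minimal sketch below uses a plain dataclasses.dataclass registered by hand with jax.tree_util.register_pytree_node (an assumption standing in for flax.struct.dataclass, which performs an equivalent registration automatically). JAX flattens the object into two leaves, but it is neither an Iterable nor a Mapping, so it matches neither container branch of chex.ArrayTree.

```python
import dataclasses
from collections.abc import Iterable, Mapping

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Params:
    weights: jnp.ndarray
    bias: jnp.ndarray


# Register Params as a pytree by hand (stand-in for flax.struct.dataclass).
# flatten returns (children, aux_data); unflatten rebuilds from the children.
jax.tree_util.register_pytree_node(
    Params,
    lambda p: ((p.weights, p.bias), None),
    lambda _, children: Params(*children),
)

params = Params(weights=jnp.ones((3,)), bias=jnp.zeros(()))

# JAX treats Params as a container with two array leaves...
num_leaves = len(jax.tree_util.tree_leaves(params))
# ...but it implements neither the Iterable nor the Mapping interface.
is_iterable = isinstance(params, Iterable)
is_mapping = isinstance(params, Mapping)
print(num_leaves, is_iterable, is_mapping)  # 2 False False
```

A static checker reaches the same verdict structurally, which is why the call to optimizer.init(params) is flagged even though it works at runtime.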

This raises a few questions:

  • Is this considered a bug, or something that the optax team would be open to supporting? Are there better solutions for suppressing this error than simply adding a # type: ignore?
  • It seems like type safety with optax could benefit immensely from support for generics, which have been available since Python 3.5 (typing.Generic, typing.TypeVar). Would optax be open to supporting this?
    • Simple example: with ArrayTreeT = TypeVar("ArrayTreeT", bound=chex.ArrayTree), optax.apply_updates() could be annotated as optax.apply_updates(params: ArrayTreeT, updates: ArrayTreeT) -> ArrayTreeT to indicate that the argument and return types should all be the same.
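
A minimal sketch of the generic-signature idea, written as a standalone helper: apply_updates_typed is hypothetical and not part of optax, and it only adds leaves with jax.tree_util.tree_map, skipping any extra handling the real optax.apply_updates may do. It uses an unbounded TypeVar rather than one bound to chex.ArrayTree, because with that bound a dataclass pytree like Params would still be rejected, since it is not a subtype of the chex.ArrayTree union.

```python
from typing import TypeVar

import jax
import jax.numpy as jnp

# Hypothetical, unbounded TypeVar. Binding it to chex.ArrayTree would still
# reject dataclass pytrees, because they are not members of that union.
PyTreeT = TypeVar("PyTreeT")


def apply_updates_typed(params: PyTreeT, updates: PyTreeT) -> PyTreeT:
    """Hypothetical generic variant of optax.apply_updates.

    Adds each update leaf to the matching parameter leaf. The signature tells
    a type checker that the result has the same type as the inputs.
    """
    # tree_map is annotated as returning Any, which is compatible with PyTreeT.
    return jax.tree_util.tree_map(lambda p, u: p + u, params, updates)


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
updates = {"w": jnp.array([0.1, 0.1]), "b": jnp.array(-0.5)}
new_params = apply_updates_typed(params, updates)  # same dict type as params
```

With this signature, passing a registered dataclass such as Params would give back Params rather than a generic tree type. Note that a TypeVar shared by two parameters is solved to a common type, so it mainly guarantees the return type rather than strictly rejecting mismatched arguments.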
@mkunesch
Member

Hi,

thanks a lot for pointing this out! This is definitely something we should discuss, especially if supporting these types would be convenient for flax users.

I think we should stick to chex types as the standard, since that makes it easier to ensure safe interoperability with other JAX libraries that use chex (e.g. this bug isn't directly related to typing, but it shows the kinds of bugs that can arise from differences in how the libraries treat more complicated pytrees). We should also avoid defining our own versions of common types in optax if possible.

@hbq1: has there been any discussion in chex about extending ArrayTree to include some common dataclass implementations?
