Equinox models integration #1709

Open
juanitorduz opened this issue Dec 27, 2023 · 4 comments

Comments

@juanitorduz
Contributor

It would be nice to have equinox_module and random_equinox_module model functions in https://github.com/pyro-ppl/numpyro/blob/master/numpyro/contrib/module.py, since Equinox seems to be under quite active development.

Would this be a good addition?

I could give it a shot in the upcoming months, but I will need some guidance :) Still, I am also happy if a more experienced dev wants to give it a go. XD.

@fehiepsi
Member

Hi @juanitorduz, if you need this feature, please feel free to put it in contrib.module. I guess you can mimic random_flax_module for an implementation. If you need to clarify something, please leave a comment in this issue thread.

@juanitorduz
Contributor Author

Great! Makes sense. Thank you @fehiepsi! I'll give it a try in the upcoming months!

@danielward27
Contributor

I've been using this in my package flowjax to register parameters of equinox modules.


from collections.abc import Callable

import equinox as eqx
import numpyro
from jaxtyping import PyTree


def register_params(
    name: str,
    model: PyTree,
    filter_spec: Callable | PyTree = eqx.is_inexact_array,
):
    """Register numpyro params for an arbitrary pytree.

    This partitions the parameters and static components, registers the parameters using
    numpyro.param, then recombines them. This should be called from within an inference
    context to have an effect, e.g. within a numpyro model or guide function.

    Args:
        name: Name for the parameter set.
        model: The pytree (e.g. an equinox module, flowjax distribution/bijection).
        filter_spec: Equinox `filter_spec` for specifying trainable parameters. Either a
            callable `leaf -> bool`, or a PyTree with prefix structure matching `model`
            with True/False values. Defaults to `eqx.is_inexact_array`.

    """
    params, static = eqx.partition(model, filter_spec)
    if callable(params):
        # Wrap to avoid special handling of callables by numpyro. Numpyro expects a
        # callable to be used for lazy initialization, whereas in our case it is likely
        # a callable module we wish to train.
        params = numpyro.param(name, lambda _: params)
    else:
        params = numpyro.param(name, params)
    return eqx.combine(params, static)

It's not particularly well tested, and I'm not familiar with the implementations for other frameworks, but maybe it's another useful reference. After training, I just use eqx.combine(trained_params, model) to retrieve the trained module.

@juanitorduz
Contributor Author

Thank you @danielward27! This will be a great entry point! (I am planning to tackle this sometime in February.)


3 participants