
NumpyroPlayerModel doesn't work after updating numpyro #611

Open
jack89roberts opened this issue Aug 7, 2023 · 1 comment
Labels
bug Something isn't working

Comments

@jack89roberts
Contributor

test_get_fitted_player_model_numpyro has been marked as xfail.

The error is:

./airsenal/tests/test_score_predictions.py::test_get_fitted_player_model_numpyro Failed: [undefined]RuntimeError: Cannot find valid initial parameters. Please check your model again.
def test_get_fitted_player_model_numpyro():
        pm = NumpyroPlayerModel()
        assert isinstance(pm, NumpyroPlayerModel)
        with test_past_data_session_scope() as ts:
>           fpm = fit_player_data("FWD", "1819", 12, model=pm, dbsession=ts)

airsenal/tests/test_score_predictions.py:269: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
airsenal/framework/prediction_utils.py:525: in fit_player_data
    fitted_model = model.fit(data)
airsenal/framework/player_model.py:177: in fit
    mcmc.run(
.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py:628: in run
    states_flat, last_state = partial_map_fn(map_args)
.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py:410: in _single_chain_mcmc
    new_init_state = self.sampler.init(
.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py:713: in init
    init_params = self._init_state(
.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py:657: in _init_state
    ) = initialize_model(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

rng_key = Array([3923418436, 1366451097], dtype=uint32)
model = <numpyro.handlers.substitute object at 0x2880dec50>

    def initialize_model(
        rng_key,
        model,
        *,
        init_strategy=init_to_uniform,
        dynamic_args=False,
        model_args=(),
        model_kwargs=None,
        forward_mode_differentiation=False,
        validate_grad=True,
    ):
        """
        (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
        and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
        to return a tuple of (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).
    
        :param jax.random.PRNGKey rng_key: random number generator seed to
            sample from the prior. The returned `init_params` will have the
            batch shape ``rng_key.shape[:-1]``.
        :param model: Python callable containing Pyro primitives.
        :param callable init_strategy: a per-site initialization function.
            See :ref:`init_strategy` section for available functions.
        :param bool dynamic_args: if `True`, the `potential_fn` and
            `constraints_fn` are themselves dependent on model arguments.
            When provided a `*model_args, **model_kwargs`, they return
            `potential_fn` and `constraints_fn` callables, respectively.
        :param tuple model_args: args provided to the model.
        :param dict model_kwargs: kwargs provided to the model.
        :param bool forward_mode_differentiation: whether to use forward-mode differentiation
            or reverse-mode differentiation. By default, we use reverse mode but the forward
            mode can be useful in some cases to improve the performance. In addition, some
            control flow utility on JAX such as `jax.lax.while_loop` or `jax.lax.fori_loop`
            only supports forward-mode differentiation. See
            `JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
            for more information.
        :param bool validate_grad: whether to validate gradient of the initial params.
            Defaults to True.
        :return: a namedtupe `ModelInfo` which contains the fields
            (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
            `param_info` is a namedtuple `ParamInfo` containing values from the prior
            used to initiate MCMC, their corresponding potential energy, and their gradients;
            `postprocess_fn` is a callable that uses inverse transforms
            to convert unconstrained HMC samples to constrained values that
            lie within the site's support, in addition to returning values
            at `deterministic` sites in the model.
        """
        model_kwargs = {} if model_kwargs is None else model_kwargs
        substituted_model = substitute(
            seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
            substitute_fn=init_strategy,
        )
        (
            inv_transforms,
            replay_model,
            has_enumerate_support,
            model_trace,
        ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
        # substitute param sites from model_trace to model so
        # we don't need to generate again parameters of `numpyro.module`
        model = substitute(
            model,
            data={
                k: site["value"]
                for k, site in model_trace.items()
                if site["type"] in ["param"]
            },
        )
        constrained_values = {
            k: v["value"]
            for k, v in model_trace.items()
            if v["type"] == "sample"
            and not v["is_observed"]
            and not v["fn"].support.is_discrete
        }
    
        if has_enumerate_support:
            from numpyro.contrib.funsor import config_enumerate, enum
    
            if not isinstance(model, enum):
                max_plate_nesting = _guess_max_plate_nesting(model_trace)
                _validate_model(model_trace, plate_warning="error")
                model = enum(config_enumerate(model), -max_plate_nesting - 1)
        else:
            _validate_model(model_trace, plate_warning="loose")
    
        potential_fn, postprocess_fn = get_potential_fn(
            model,
            inv_transforms,
            replay_model=replay_model,
            enum=has_enumerate_support,
            dynamic_args=dynamic_args,
            model_args=model_args,
            model_kwargs=model_kwargs,
        )
    
        init_strategy = (
            init_strategy if isinstance(init_strategy, partial) else init_strategy()
        )
        if (init_strategy.func is init_to_value) and not replay_model:
            init_values = init_strategy.keywords.get("values")
            unconstrained_values = transform_fn(inv_transforms, init_values, invert=True)
            init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
        prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
        (init_params, pe, grad), is_valid = find_valid_initial_params(
            rng_key,
            substitute(
                model,
                data={
                    k: site["value"]
                    for k, site in model_trace.items()
                    if site["type"] in ["plate"]
                },
            ),
            init_strategy=init_strategy,
            enum=has_enumerate_support,
            model_args=model_args,
            model_kwargs=model_kwargs,
            prototype_params=prototype_params,
            forward_mode_differentiation=forward_mode_differentiation,
            validate_grad=validate_grad,
        )
    
        if not_jax_tracer(is_valid):
            if device_get(~jnp.all(is_valid)):
                with numpyro.validation_enabled(), trace() as tr:
                    # validate parameters
                    substituted_model(*model_args, **model_kwargs)
                    # validate values
                    for site in tr.values():
                        if site["type"] == "sample":
                            with warnings.catch_warnings(record=True) as ws:
                                site["fn"]._validate_sample(site["value"])
                            if len(ws) > 0:
                                for w in ws:
                                    # at site information to the warning message
                                    w.message.args = (
                                        "Site {}: {}".format(
                                            site["name"], w.message.args[0]
                                        ),
                                    ) + w.message.args[1:]
                                    warnings.showwarning(
                                        w.message,
                                        w.category,
                                        w.filename,
                                        w.lineno,
                                        file=w.file,
                                        line=w.line,
                                    )
>               raise RuntimeError(
                    "Cannot find valid initial parameters. Please check your model again."
                )
E               RuntimeError: Cannot find valid initial parameters. Please check your model again.

.venv/lib/python3.11/site-packages/numpyro/infer/util.py:745: RuntimeError
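For context on what this error means, here is a simplified, pure-Python sketch of the check that fails. It is not numpyro's implementation: the real find_valid_initial_params is vectorised JAX and also validates gradients. The sketch assumes a toy model with a Normal(0, 1) prior on an unconstrained log_rate and a Poisson likelihood on the observations. Candidates are drawn uniformly in (-2, 2) in unconstrained space, which is what init_to_uniform does by default. A candidate is rejected while the potential energy (negative log joint) is non-finite, and the retry count of 100 is an arbitrary choice for the sketch. A single out-of-support observation makes the likelihood -inf for every candidate, so every retry fails and the same RuntimeError is raised.

```python
import math
import random


def poisson_log_prob(k, rate):
    """Log pmf of Poisson(rate); -inf if k is outside the support {0, 1, 2, ...}."""
    if k < 0 or not float(k).is_integer():
        return -math.inf
    return k * math.log(rate) - rate - math.lgamma(k + 1)


def potential_energy(log_rate, observations):
    """Negative log joint: Normal(0, 1) prior on log_rate plus Poisson likelihood."""
    log_prior = -0.5 * log_rate**2 - 0.5 * math.log(2 * math.pi)
    rate = math.exp(log_rate)
    log_lik = sum(poisson_log_prob(k, rate) for k in observations)
    return -(log_prior + log_lik)


def find_valid_initial_param(observations, num_tries=100, seed=0):
    """Retry uniform(-2, 2) draws in unconstrained space until the potential energy is finite."""
    rng = random.Random(seed)
    for _ in range(num_tries):
        candidate = rng.uniform(-2.0, 2.0)
        pe = potential_energy(candidate, observations)
        if math.isfinite(pe):
            return candidate, pe
    # Every candidate gave a non-finite potential energy.
    raise RuntimeError(
        "Cannot find valid initial parameters. Please check your model again."
    )


print(find_valid_initial_param([0, 2, 1]))  # a candidate and its finite potential energy
try:
    find_valid_initial_param([0, -1, 1])  # -1 is outside the Poisson support
except RuntimeError as err:
    print("RuntimeError:", err)
```

With valid counts the first draw is accepted. Changing one count to -1 means no choice of log_rate can rescue the likelihood, which is why trying a different init strategy alone is unlikely to fix this kind of failure.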
@jack89roberts added the bug label (Something isn't working) on Aug 7, 2023
@jack89roberts
Contributor Author

The test also emits this warning:

airsenal/tests/test_score_predictions.py::test_get_fitted_player_model_numpyro
  /Users/jroberts/GitHub/AIrsenal/airsenal/framework/player_model.py:177: UserWarning: Site obs: Out-of-support values provided to log prob method. The value argument should be within the support.
    mcmc.run(
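The warning points at a likely cause. Some value passed to the obs site lies outside its distribution's support, which would make the log likelihood -inf for every parameter value and would explain the initialisation failure above. Below is a minimal sketch for locating offending observations before fitting. The issue doesn't show which distribution obs uses, so is_nonnegative_integer is a hypothetical placeholder for count data and should be swapped for the real support check. find_out_of_support is likewise a hypothetical helper, not part of AIrsenal or numpyro.

```python
from typing import Callable, Iterable, List


def is_nonnegative_integer(value: float) -> bool:
    """Placeholder support check for count data: 0, 1, 2, ... (NaN and inf fail)."""
    return value >= 0 and float(value).is_integer()


def find_out_of_support(
    values: Iterable[float], in_support: Callable[[float], bool]
) -> List[int]:
    """Return the indices of observations that fail the support check."""
    return [i for i, value in enumerate(values) if not in_support(value)]


observed = [0, 3, -1, 2.5, 4.0]  # made-up values for illustration
print(find_out_of_support(observed, is_nonnegative_integer))  # [2, 3]
```

Running this over the observed values in the data handed to model.fit(data) would give the indices to inspect. Any hits suggest the problem is in the data preparation rather than in the sampler.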

Projects
None yet
Development

No branches or pull requests

1 participant