
Group Lasso - Pytree compatibility #78

Open
BalzaniEdoardo opened this issue Jan 5, 2024 · 0 comments

As of now, the group Lasso regularizer is compatible only with arrays, not with pytrees.
The fundamental difference is that when we pass a model matrix in array format, parameters are grouped by column indices; when we pass pytree parameters, the groups are given by the dictionary structure itself (i.e. `params["group_1"], ..., params["group_n"]` are the different groups).
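A minimal sketch of the two grouping conventions described above, using a plain dict in place of `FeaturePytree` and made-up toy shapes for illustration:

```python
import jax.numpy as jnp

# Array convention: one weight matrix plus a mask; each row of the mask
# marks which columns belong to one group.
weights = jnp.arange(6.0).reshape(1, 6)  # shape (n_neurons, n_features)
mask = jnp.array([[1, 1, 1, 0, 0, 0],
                  [0, 0, 0, 1, 1, 1]])   # 2 groups of 3 features each

# Pytree convention: the dict structure itself defines the groups.
weights_tree = {
    "group_1": weights[:, :3],
    "group_2": weights[:, 3:],
}
```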

We need to discuss if and how to maintain compatibility with both pytrees and arrays for group Lasso. Below is an implementation of the group Lasso proximal operator that works with pytrees only.

It is equivalent to the original array-based implementation, but assumes a tree representation of the regularizer strength, which could be more flexible for cases in which we want a variable-specific regularizer strength.

# should be added to src/nemos/proximal_operator.py
def prox_group_lasso_pytree(
    params: Tuple[DESIGN_INPUT_TYPE, jnp.ndarray], l2reg: DESIGN_INPUT_TYPE, scaling=1.0
):
    r"""Proximal operator for group Lasso, i.e., the block soft-thresholding operator.

    Parameters
    ----------
    params :
        The input. `params[0]` are the weights (a tree of JAX arrays or FeaturePytree).
        `params[1]` are the intercepts (a JAX array).
    l2reg :
        The regularization strength, which is a pytree with
        the same structure as `params[0]`.
    scaling :
        A scaling factor applied to the regularization term. Defaults to 1.0.

    Returns
    -------
    :
        The rescaled weights.

    Notes
    -----
    This function implements the proximal operator for a group-Lasso penalization which
    can be derived in analytical form.
    The proximal operator is defined as,

    $$
    \text{prox}(\hat{\beta}) = \arg\min_{\beta} \left[ \lambda \sum\_{g=1}^G \sqrt{p\_g} \Vert \beta\_g \Vert_2 +
     \frac{1}{2} \Vert \hat{\beta} - \beta \Vert_2^2
    \right],
    $$
    where $G$ is the number of groups, $p\_g$ is the dimensionality of the $g$-th group, and $\beta\_g$ is the
    parameter vector associated with the $g$-th group.
    The analytical solution[^1] for the $g$-th group is,

    $$
    \text{prox}(\beta\_g) = \max \left(1 - \frac{\lambda \sqrt{p\_g}}{\Vert \hat{\beta}\_g \Vert_2},
     0\right) \cdot \hat{\beta}\_g,
    $$
    where $\hat{\beta}$ is typically the gradient step of the un-regularized optimization objective
    function. The group-Lasso proximal operator thus acts as a shrinkage factor on the un-penalized
    update, and the half-rectification non-linearity effectively sets to zero any group of
    coefficients satisfying,
    $$
    \Vert \hat{\beta}\_g \Vert_2 \le \lambda \sqrt{p\_g}.
    $$

    [^1]:
        Yuan, Ming, and Yi Lin. "Model selection and estimation in regression with grouped variables."
        Journal of the Royal Statistical Society Series B: Statistical Methodology 68.1 (2006): 49-67.
    """
    # assume that the last axis indexes the features within each group
    l2_norm = jax.tree_util.tree_map(
        lambda w: jnp.linalg.norm(w, axis=-1, keepdims=True) / jnp.sqrt(w.shape[-1]),
        params[0],
    )
    # shrinkage factor, half-rectified so that small-norm groups are set to zero
    factor = jax.tree_util.tree_map(
        lambda reg, norm: jax.nn.relu(1 - reg * scaling / norm), l2reg, l2_norm
    )
    return jax.tree_util.tree_map(lambda f, w: f * w, factor, params[0]), params[1]
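A quick usage sketch of the operator above, self-contained with a copy of the function and plain dicts standing in for `FeaturePytree` (the toy weights and regularizer strengths are made-up values for illustration):

```python
import jax
import jax.numpy as jnp


def prox_group_lasso_pytree(params, l2reg, scaling=1.0):
    """Pytree block soft-thresholding, as proposed above."""
    weights, intercepts = params
    # per-group l2 norm, normalized by sqrt of the group dimensionality
    l2_norm = jax.tree_util.tree_map(
        lambda w: jnp.linalg.norm(w, axis=-1, keepdims=True) / jnp.sqrt(w.shape[-1]),
        weights,
    )
    # half-rectified shrinkage factor
    factor = jax.tree_util.tree_map(
        lambda reg, norm: jax.nn.relu(1 - reg * scaling / norm), l2reg, l2_norm
    )
    return jax.tree_util.tree_map(lambda f, w: f * w, factor, weights), intercepts


weights = {
    "group_1": jnp.array([[3.0, 4.0]]),  # norm 5: shrunk but kept
    "group_2": jnp.array([[0.1, 0.1]]),  # small norm: set to zero
}
l2reg = {"group_1": jnp.array([1.0]), "group_2": jnp.array([1.0])}
new_w, new_b = prox_group_lasso_pytree((weights, jnp.zeros(1)), l2reg)
```

With a regularizer strength of 1, `"group_2"` falls below the threshold and is zeroed out, while `"group_1"` is only rescaled.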

A test that checks the equivalence between the pytree-based and the array-based implementation is the following:

# this should be added to tests/test_proximal_operator.py
def test_compare_group_lasso(example_data_prox_operator):
    """Compare the group lasso prox operators."""
    params, regularizer_strength, mask, scaling = example_data_prox_operator
    # create a pytree version of params
    params_tree = FeaturePytree(
        **{f"{k}": params[0][:, jnp.array(msk, dtype=bool)] for k, msk in enumerate(mask)}
    )
    # create a regularizer tree with the same structure as params_tree
    treedef = jax.tree_util.tree_structure(params_tree)
    # make sure the leaves are arrays (otherwise FeaturePytree cannot be instantiated)
    alpha_tree = jax.tree_util.tree_unflatten(
        treedef, [jnp.atleast_1d(regularizer_strength)] * treedef.num_leaves
    )
    # compute updates using both functions
    updated_params = prox_group_lasso(params, regularizer_strength, mask, scaling)
    updated_params_tree = prox_group_lasso_pytree((params_tree, params[1]), alpha_tree, scaling)
    # check agreement
    check_updates = [
        jnp.all(updated_params[0][:, jnp.array(msk, dtype=bool)] == updated_params_tree[0][f"{k}"])
        for k, msk in enumerate(mask)
    ]
    assert all(check_updates)
    assert all(updated_params_tree[1] == updated_params[1])