As of now, the group Lasso regularizer is compatible only with arrays, not with pytrees.
The fundamental difference is that when we pass the model matrix as an array, parameters are grouped by column indices; when we pass pytree parameters, the groups are given by the dictionary structure itself (i.e. params["group_1"], ... , params["group_n"] are the different groups).
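To make the two conventions concrete, here is a minimal sketch (the group names, shapes, and mask below are made up for illustration, not taken from the codebase):

```python
import jax.numpy as jnp

# Array convention: one coefficient matrix, groups defined by a boolean column mask.
# Row g of `mask` marks which columns belong to group g.
coef = jnp.arange(12.0).reshape(2, 6)  # (n_neurons, n_features)
mask = jnp.array([[1, 1, 1, 0, 0, 0],
                  [0, 0, 0, 1, 1, 1]])
group_0 = coef[:, jnp.array(mask[0], dtype=bool)]  # columns 0-2

# Pytree convention: the dict structure itself defines the groups.
params = {"group_1": coef[:, :3], "group_2": coef[:, 3:]}

assert jnp.all(params["group_1"] == group_0)
```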
We need to discuss if and how to maintain compatibility with both pytrees and arrays for group Lasso. Below is an implementation of the group Lasso proximal operator that works with pytrees only:
This is a proximal operator equivalent to the original implementation, but operating on pytree parameters. It also assumes a tree representation of the regularizer strength, which could be more flexible for cases in which we want a variable-specific regularizer strength.
```python
# should be added to src/nemos/proximal_operator.py
def prox_group_lasso_pytree(
    params: Tuple[DESIGN_INPUT_TYPE, jnp.ndarray], l2reg: DESIGN_INPUT_TYPE, scaling=1.0
):
    r"""Proximal operator for the l2 norm, i.e., block soft-thresholding operator.

    Parameters
    ----------
    params :
        The input. ``params[0]`` are the weights (a tree of JAX arrays or a
        FeaturePytree); ``params[1]`` are the intercepts (a JAX array).
    l2reg :
        The regularization strength, a pytree with the same structure as ``params[0]``.
    scaling :
        A scaling factor applied to the regularization term. Defaults to 1.0.

    Returns
    -------
    :
        The rescaled weights.

    Notes
    -----
    This function implements the proximal operator for a group-Lasso penalization,
    which can be derived in analytical form. The proximal operator equation is

    $$
    \text{prox}(\hat{\beta}) = \arg\min_{\beta} \left[ \lambda \sum_{g=1}^G
    \Vert \beta_g \Vert_2 + \frac{1}{2} \Vert \hat{\beta} - \beta \Vert_2^2 \right],
    $$

    where $G$ is the number of groups and $\beta_g$ is the parameter vector
    associated with the $g$-th group. The analytical solution[^1] for each group is

    $$
    \text{prox}(\hat{\beta}_g) = \max\left(1 - \frac{\lambda \sqrt{p_g}}
    {\Vert \hat{\beta}_g \Vert_2}, 0\right) \cdot \hat{\beta}_g,
    $$

    where $p_g$ is the dimensionality of $\beta_g$ and $\hat{\beta}$ is typically
    the gradient step of the un-regularized optimization objective. The group-Lasso
    proximal operator thus acts as a shrinkage factor on the un-penalized update,
    and the half-rectification non-linearity effectively sets to zero any group of
    coefficients satisfying

    $$
    \Vert \hat{\beta}_g \Vert_2 \le \lambda \sqrt{p_g}.
    $$

    [^1]: Yuan, Ming, and Yi Lin. "Model selection and estimation in regression
        with grouped variables." Journal of the Royal Statistical Society Series B:
        Statistical Methodology 68.1 (2006): 49-67.
    """
    # assume that the last axis indexes the features within each group
    l2_norm = jax.tree_map(
        lambda xx: jnp.linalg.norm(xx, axis=-1, keepdims=True) / jnp.sqrt(xx.shape[-1]),
        params[0],
    )
    factor = jax.tree_map(lambda xx, yy: 1 - xx * scaling / yy, l2reg, l2_norm)
    factor = jax.tree_map(jax.nn.relu, factor)
    return jax.tree_map(lambda xx, yy: xx * yy, factor, params[0]), params[1]
```
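As a quick sanity check, the closed-form shrinkage can be verified on a made-up two-group pytree (the group names, values, and strengths are illustrative; the operator is repeated in condensed form, using `jax.tree_util.tree_map`, so the snippet runs standalone):

```python
import jax
import jax.numpy as jnp

# condensed, untyped copy of the operator above, so this sketch is self-contained
def prox_group_lasso_pytree(params, l2reg, scaling=1.0):
    tmap = jax.tree_util.tree_map
    l2_norm = tmap(
        lambda xx: jnp.linalg.norm(xx, axis=-1, keepdims=True) / jnp.sqrt(xx.shape[-1]),
        params[0],
    )
    factor = tmap(lambda xx, yy: 1 - xx * scaling / yy, l2reg, l2_norm)
    factor = tmap(jax.nn.relu, factor)
    return tmap(lambda xx, yy: xx * yy, factor, params[0]), params[1]

# made-up two-group weights: a plain dict stands in for a FeaturePytree
weights = {
    "group_1": jnp.array([[3.0, 4.0]]),  # l2 norm 5 along the feature axis
    "group_2": jnp.array([[0.1, 0.1]]),  # small norm: should be zeroed out
}
intercepts = jnp.zeros(1)
# per-group regularizer strengths, same tree structure as the weights
l2reg = {"group_1": jnp.array([0.5]), "group_2": jnp.array([0.5])}

new_weights, new_intercepts = prox_group_lasso_pytree((weights, intercepts), l2reg)

# group_1 survives, shrunk by max(1 - lambda * sqrt(p_g) / ||beta_g||, 0)
expected = (1.0 - 0.5 * jnp.sqrt(2.0) / 5.0) * weights["group_1"]
assert jnp.allclose(new_weights["group_1"], expected)
# group_2 falls below the threshold and is set exactly to zero
assert jnp.all(new_weights["group_2"] == 0.0)
```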
A test that checks the equivalence between the pytree-based and array-based implementations is the following:
```python
# this should be added to tests/test_proximal_operator.py
def test_compare_group_lasso(example_data_prox_operator):
    """Compare the group lasso prox operators."""
    params, regularizer_strength, mask, scaling = example_data_prox_operator
    # create a pytree version of params
    params_tree = FeaturePytree(
        **{f"{k}": params[0][:, jnp.array(msk, dtype=bool)] for k, msk in enumerate(mask)}
    )
    # create a regularizer tree with the same structure as params_tree
    treedef = jax.tree_util.tree_structure(params_tree)
    # make sure the leaves are arrays (otherwise FeaturePytree cannot be instantiated)
    alpha_tree = jax.tree_util.tree_unflatten(
        treedef, [jnp.atleast_1d(regularizer_strength)] * treedef.num_leaves
    )
    # compute updates using both functions
    updated_params = prox_group_lasso(params, regularizer_strength, mask, scaling)
    updated_params_tree = prox_group_lasso_pytree(
        (params_tree, params[1]), alpha_tree, scaling
    )
    # check agreement group by group
    check_updates = [
        jnp.all(
            updated_params[0][:, jnp.array(msk, dtype=bool)]
            == updated_params_tree[0][f"{k}"]
        )
        for k, msk in enumerate(mask)
    ]
    assert all(check_updates)
    assert all(updated_params_tree[1] == updated_params[1])
```