
"masked" leads to unexpected behavior with callable Pytrees #913

Open
JadM133 opened this issue Apr 8, 2024 · 6 comments

Comments


JadM133 commented Apr 8, 2024

Good morning!

The "masked" function finds the "mask_tree" as follows:

 mask_tree = mask(params) if callable(mask) else mask                                                             (1)

This expression is used twice (in "init_fn" and "update_fn").

However, in some cases we can easily end up with a mask that is a callable PyTree, for example a mask created from an equinox model.

Since such a mask is both callable and a PyTree, either branch of (1) could be intended, but the check always takes the callable branch and calls the mask.
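To make the ambiguity concrete, here is a minimal plain-Python sketch (MaskModule is a hypothetical stand-in for an equinox Module whose attributes are boolean mask leaves):

```python
# Hypothetical stand-in for an equinox Module with boolean leaves:
# it holds mask values, yet also defines __call__ like a model does.
class MaskModule:
    def __init__(self, weight_mask, bias_mask):
        self.weight_mask = weight_mask  # boolean "leaves" of the mask
        self.bias_mask = bias_mask

    def __call__(self, x):
        # Models (and masks built from models) define __call__,
        # so the whole object is callable.
        return x

mask = MaskModule(weight_mask=True, bias_mask=False)

# The check in (1) sees a callable and takes the first branch,
# calling the mask on the params instead of using its leaves directly:
print(callable(mask))  # True
```

Since `callable(mask)` is True, `masked` calls the object on the params rather than treating its boolean leaves as the mask tree, which is the source of the unexpected behavior.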

To avoid unexpected behavior, I propose adding an optional argument to "masked" (e.g. call_mask=False) that lets the user specify whether the mask should be called or used as is. To avoid confusion, we could add a check that raises an informative error, something like:

if call_mask and not callable(mask):
    raise ValueError("Cannot set call_mask to True if the provided mask is not callable.")

What do you think?

vroulet (Collaborator) commented Apr 8, 2024

Hello @JadM133,

Great point. Could you make a minimum working example (MWE) to pinpoint the error we would get?

I'm wondering whether we could simply make our own definition of "callable" to handle such cases. For example, if I understand you correctly, one could look at the tree_leaves of the mask and check whether those leaves are callable, but correct me if I'm wrong (an MWE would be useful to design that).

The proposed flag logic is somewhat confusing in my opinion (the user will wonder where it comes from, and we would need to explain the equinox-specific issue at length).

Thanks again for pointing this out!

JadM133 (Author) commented Apr 14, 2024

Hello @vroulet, thank you for your answer! You're right, I should have included an MWE. Below you can find typical code written in equinox, where the model is simply an MLP with one hidden layer. The idea is to take the gradients of the biases as they are. The creation of the mask is based on the documentation of equinox here.

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import jax.tree_util as jtu
import optax

@eqx.filter_value_and_grad
def grad_loss(model, input, output):
    pred = model(input)
    mse = lambda x, y : jnp.mean(jnp.square(x-y))
    return mse(pred, output)

@eqx.filter_jit
def make_step(input, output, model, opt_state):
    loss, grads = grad_loss(model, input, output)
    updates, opt_state = optim.update(grads, opt_state)
    is_working = (grads.layers[0].bias == updates.layers[0].bias).all()
    jax.debug.print("Is it working? {is_working}", is_working=is_working)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

key, subkey = jax.random.split(jrandom.PRNGKey(0))
xs = jnp.ones((100,))
ys = jax.random.normal(key, (1,))

model = eqx.nn.MLP(xs.shape[-1], ys.shape[-1], 10, 1, key=subkey)

lr = 1e-2
filter_spec = jtu.tree_map(lambda _: True, model)
filter_spec = eqx.tree_at(
    lambda tree: (tree.layers[0].bias, tree.layers[1].bias),
    filter_spec,
    replace=(False, False),
)

optim = optax.masked(optax.adabelief(lr), filter_spec)

opt_state =  optim.init(eqx.filter(model, eqx.is_inexact_array))

for epoch in range(2):
    loss, model, opt_state = make_step(xs, ys, model, opt_state)
    print(f"Epoch {epoch}: {loss}")

Running the code will raise the following error,

TypeError: unsupported operand type(s) for @: 'bool' and 'MLP'

After taking a closer look at where the error comes from, the problem is that the mask "filter_spec" is callable, so "optax.masked" does not use it as a mask; instead it calls it on the inputs, which results in boolean values being matrix-multiplied with MLP class instances (hence the @ operator error).

I like the idea of defining a special callable check to resolve the issue. What about something like this:

def mask_callable(x):
    # True only if every leaf of the pytree is callable.
    return all(jtu.tree_leaves(jtu.tree_map(callable, x)))

(assuming jax.tree_util is already imported as jtu in the file).

Replacing "callable" with "mask_callable" in both init_fn and update_fn inside the masked function resolves the issue. The output reads:

Is it working? True
Epoch 0: 0.0036341827362775803
Is it working? True
Epoch 1: 0.309080570936203

Just to check whether this would work with a callable mask as well, I modified the code slightly to create the same mask as a callable function:

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import jax.tree_util as jtu
import optax

@eqx.filter_value_and_grad
def grad_loss(model, input, output):
    pred = model(input)
    mse = lambda x, y : jnp.mean(jnp.square(x-y))
    return mse(pred, output)

@eqx.filter_jit
def make_step(input, output, model, opt_state):
    loss, grads = grad_loss(model, input, output)
    updates, opt_state = optim.update(grads, opt_state)
    is_working = (grads.layers[0].bias == updates.layers[0].bias).all()
    jax.debug.print("Is it working? {is_working}", is_working=is_working)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

key, subkey = jax.random.split(jrandom.PRNGKey(0))
xs = jnp.ones((100,))
ys = jax.random.normal(key, (1,))

model = eqx.nn.MLP(xs.shape[-1], ys.shape[-1], 10, 1, key=subkey)

lr = 1e-2

def create_filter(model):
    filter_spec = jtu.tree_map(lambda _: True, model)
    filter_spec = eqx.tree_at(
        lambda tree: (tree.layers[0].bias, tree.layers[1].bias),
        filter_spec,
        replace=(False, False),
    )
    return filter_spec

filter_spec = lambda tree: create_filter(tree)

optim = optax.masked(optax.adabelief(lr), filter_spec)

opt_state =  optim.init(eqx.filter(model, eqx.is_inexact_array))

for epoch in range(2):
    loss, model, opt_state = make_step(xs, ys, model, opt_state)
    print(f"Epoch {epoch}: {loss}")

Again, with the proposed "mask_callable" inside both init_fn and update_fn, we get the following:

Is it working? True
Epoch 0: 0.0036341827362775803
Is it working? True
Epoch 1: 0.309080570936203

I also verified that "mask_callable" returns False in the first case and True in the second.
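To illustrate that verification, here is a small self-contained sketch of the check, using a plain dict of booleans in place of the equinox filter_spec (hypothetical leaf names):

```python
import jax.tree_util as jtu

def mask_callable(x):
    # True only if every leaf of the pytree is itself callable.
    return all(jtu.tree_leaves(jtu.tree_map(callable, x)))

# A pytree of booleans (like filter_spec): the leaves are not callable.
print(mask_callable({"weight": True, "bias": False}))  # False

# A plain function is a single callable leaf.
print(mask_callable(lambda tree: tree))                # True
```

The first case corresponds to a pytree mask (use as is), the second to a genuine mask function (call it on the params).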

What do you think?

Thanks in advance!

vroulet (Collaborator) commented Apr 21, 2024

Hello @JadM133,
Sorry for the delay, it's been a busy week.
Thank you for the detailed example! It really helps.
So the issue is that in equinox the model is a callable and represents the params.
I think your workaround does not exactly work. Consider modifying it with

filter_spec = eqx.tree_at(
    lambda tree: (tree.layers[0].bias, tree.layers[1].bias),
    filter_spec,
    replace=(True, False),
)

@eqx.filter_jit
def make_step(input, output, model, opt_state):
    loss, grads = grad_loss(model, input, output)
    updates, opt_state = optim.update(grads, opt_state)
    is_working = (grads.layers[0].bias != updates.layers[0].bias).all()
    jax.debug.print("Is it working? {is_working}", is_working=is_working)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

This does not give the right answer.

I think it would be best to upstream this issue to equinox if you don't mind.
We would prefer not to add equinox as a dependency but would be happy to find a workaround to ensure compatibility.

randomekek commented May 14, 2024

There is a potential backward compatible solution: add explicit keyword parameters in addition to the existing positional parameter: (mask, *, mask_fn, mask_value).

So the existing calling convention would remain - mask is inferred. However, if it's explicitly passed as mask_fn or mask_value, we know whether to evaluate it or use it as is.

Something like the following, although we would probably need to validate that only one of the three options is set.

def masked(mask=None, *, mask_fn=None, mask_value=None):
  def init_fn(params):
    if mask_fn is not None:
        mask_tree = mask_fn(params)
    elif mask_value is not None:
        mask_tree = mask_value
    else:
        mask_tree = mask(params) if callable(mask) else mask  # existing code
    ...

randomekek commented

I just realised that you can wrap your value in a constant callable like lambda _: mask (it must accept the params argument that masked passes in); a bit ugly, but it gets the job done with no changes.

# fails because the mask is callable but you just want the value
masked(x)
# succeeds because masked will call it once to unwrap the value
masked(lambda _: x)

JadM133 (Author) commented May 24, 2024

Hello @vroulet, @randomekek, thank you for your replies, and apologies for the delay.

@vroulet, I don't see the problem in the example you proposed. What we want is for the gradients to be taken as is when the mask is False; so in your example, while the bias in layer 0 won't be taken as is (since you changed it to True), the bias in layer 1 will be. This is due to the following function:

def mask_pytree(pytree, mask_tree):
    return jax.tree_util.tree_map(
        lambda m, p: p if m else MaskedNode(), mask_tree, pytree
    )

If this is not the expected behavior, we can flip the condition. So for now: False -> gradient passed through as is.
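The same masking behavior can be sketched with plain dicts (no JAX needed; this is a hypothetical dict-only stand-in for the real mask_pytree, which works on arbitrary pytrees):

```python
class MaskedNode:
    """Placeholder for leaves hidden from the inner transformation."""
    def __repr__(self):
        return "MaskedNode()"

def mask_pytree(pytree, mask_tree):
    # Leaves where the mask is True stay visible to the inner transformation;
    # leaves where it is False are replaced by a MaskedNode placeholder.
    return {k: (v if mask_tree[k] else MaskedNode()) for k, v in pytree.items()}

grads = {"weight": 0.5, "bias": 0.1}
mask = {"weight": True, "bias": False}
masked_grads = mask_pytree(grads, mask)

print(masked_grads["weight"])  # 0.5
print(masked_grads["bias"])    # MaskedNode()
```

Only the True-masked leaf is handed to the inner transformation; the False-masked leaf is hidden behind a MaskedNode, so its gradient later passes through untouched.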

I modified the printing statements to the following:

is_working = (grads.layers[0].bias == updates.layers[0].bias).all()
jax.debug.print("Is it working for bias 0? {is_working}", is_working=is_working)
is_working = (grads.layers[1].bias == updates.layers[1].bias).all()
jax.debug.print("Is it working for bias 1? {is_working}", is_working=is_working)

And now I tried both of these scenarios:

Scenario 1:

filter_spec = eqx.tree_at(
    lambda tree: (tree.layers[0].bias, tree.layers[1].bias),
    filter_spec,
    replace=(True, False),
)

With the following output:

Is it working for bias 0? False
Is it working for bias 1? True
Epoch 0: 0.0036341827362775803
Is it working for bias 0? False
Is it working for bias 1? True
Epoch 1: 0.33416470885276794

Scenario 2:

filter_spec = eqx.tree_at(
    lambda tree: (tree.layers[0].bias, tree.layers[1].bias),
    filter_spec,
    replace=(False, True),
)

With the following output:

Is it working for bias 0? True
Is it working for bias 1? False
Epoch 0: 0.0036341827362775803
Is it working for bias 0? True
Is it working for bias 1? False
Epoch 1: 0.47130754590034485

So it is indeed working as expected: False -> the update is left as the raw gradient, untouched by the inner transformation. That said, I think the relationship False = masked is counterintuitive; if you agree, we could change that as well.

I think the response provided by @randomekek solves this problem, but it would be nicer to use pytree properties if possible. What do you all think?
