Request for more masking tutorials #3187

Open
6 tasks
LysSanzMoreta opened this issue Mar 10, 2023 · 13 comments
Labels
documentation, help wanted, usability

Comments

@LysSanzMoreta
Contributor

LysSanzMoreta commented Mar 10, 2023

Hi!

As discussed here https://forum.pyro.ai/t/more-doubts-on-masking-runnable-example/5044/6 and here https://forum.pyro.ai/t/vae-classification/5017/10, it is not always clear when and how to use the different masking options, especially the differences between masking in the model vs. the guide, or masking combined with enumeration. It would be great to have tutorials covering:

  • Explain the practical differences between all the masking options (poutine.mask, obs_mask, mask())
  • Usage example with VAE model-guide architecture (supervised, semi-supervised, unsupervised approaches)
  • How to use masks when there are event vs batch shapes
  • Usage with enumeration
  • Nested masks
  • Meaning of the "mask" field in the trace: why does (i) using the poutine.mask context manager show the used mask tensor under "mask" in the trace, while (ii) using .mask() or obs_mask shows "None" for the trace's "mask"?

Thanks! :)

@fritzo added the help wanted, documentation, and usability labels Mar 14, 2023
@LysSanzMoreta
Contributor Author

LysSanzMoreta commented Mar 15, 2023

@fritzo I am particularly puzzled by this: why, when (i) using the context manager poutine.mask, does the trace show the used mask tensor under "mask", while (ii) using .mask() or obs_mask leaves the trace's "mask" as None?

I am getting different results when using something like this:

mask_t = torch.tensor([True,True])
logits = torch.tensor([3.,4.])

Case A: The trace shows the tensor mask_t under "mask" and 'fn': Independent(Categorical(logits: torch.Size([2])), 1)

with pyro.poutine.mask(mask=mask_t):
       pyro.sample("c",dist.Categorical(logits=logits).to_event(1))

Case B: The trace shows None under "mask" and then 'fn': Independent(MaskedDistribution(), 1)

pyro.sample("c",dist.Categorical(logits=logits).mask(mask_t).to_event(1))

Is this expected behaviour? Shouldn't there be 2 different mask types?

Thanks :)

@fritzo
Member

fritzo commented Mar 20, 2023

IIRC using Distribution.mask() stores the mask internally to the distribution, rather than in the trace; you should be able to see this with trace.nodes[name]["fn"] being a MaskedDistribution. By contrast poutine.mask() preserves the original distribution in the "fn" slot and stores the mask in the "mask" slot. The reason for the difference is that the trace based version is usually nicer and clearer, but the distribution mask is needed when you want to mask out part of the event shape, since the trace mask must be broadcastable with batch_shape and can have no event_shape.
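
A quick way to see where the mask ends up (a minimal sketch, not from the thread; the 2D logits are made up so that the Categorical has a batch dimension):

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

mask_t = torch.tensor([True, True])
logits = torch.tensor([[3., 4.], [1., 2.]])  # batch_shape (2,)

def model_poutine():
    with poutine.mask(mask=mask_t):
        pyro.sample("c", dist.Categorical(logits=logits).to_event(1))

def model_dist():
    pyro.sample("c", dist.Categorical(logits=logits).mask(mask_t).to_event(1))

tr_a = poutine.trace(model_poutine).get_trace()
tr_b = poutine.trace(model_dist).get_trace()

print(tr_a.nodes["c"]["mask"])          # the mask tensor; "fn" is the original Independent(Categorical)
print(tr_b.nodes["c"]["mask"])          # None; the mask lives inside the distribution instead
print(tr_b.nodes["c"]["fn"].base_dist)  # MaskedDistribution wrapping the Categorical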

@LysSanzMoreta
Contributor Author

@fritzo Thanks, that is very nice to know. I was aware that Distribution.mask() is needed when an event_shape is involved. However, since I could not "find" where the mask values went, I did not know whether it was actually working (using the given mask as intended).

This also leads me to my next concern (which is what led me to try to find the mask above): I get different results when using

mask_t = torch.tensor([True,True])
logits = torch.tensor([3.,4.])
targets = torch.tensor([0.,1.])

The mask is all True values, therefore indicating that all values should be used for the marginal --> Good results

with pyro.poutine.mask(mask=mask_t): #the mask is all True
       pyro.sample("c",dist.Categorical(logits=logits).to_event(1),obs=targets)

No poutine mask (therefore all values should be used in the computation?) --> Bad results

pyro.sample("c",dist.Categorical(logits=logits).to_event(1),obs=targets)

I am guessing this has to do with the event shape, but I do not understand how, since in my head they should be equivalent ... unless without using poutine.mask the default is False?

@fehiepsi
Member

fehiepsi commented Mar 23, 2023

Just curious: why does dist.Categorical(logits=logits).to_event(1) not raise an error? dist.Categorical(logits=logits) does not have a batch shape.

Regarding mask, the rule of thumb is that mask only applies to batch dimensions. Assume you have a univariate distribution d and a mask with shape (3,): d.expand([3]).mask(mask).to_event(1) is different from d.expand([3]).to_event(1).mask(mask). The former has no batch dimension and event shape (3,). The latter has batch shape (3,) due to the final mask operator (whose mask has shape (3,)).

d.expand([3]) --> batch_shape: (3,), event_shape: ()
d.expand([3]).mask(mask) --> batch_shape: (3,), event_shape: ()
d.expand([3]).mask(mask).to_event(1) --> batch_shape: (), event_shape: (3,)
d.expand([3]).to_event(1) --> batch_shape: (), event_shape: (3,)
d.expand([3]).to_event(1).mask(mask) -> batch_shape: (3,), event_shape: (3,)
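
A tiny script (a sketch, assuming a univariate Normal as the base distribution d) prints exactly these shapes:

import torch
import pyro.distributions as dist

mask = torch.tensor([True, False, True])
d = dist.Normal(0., 1.)  # univariate: batch_shape (), event_shape ()

for label, d2 in [
    ("expand([3])",                        d.expand([3])),
    ("expand([3]).mask(mask)",             d.expand([3]).mask(mask)),
    ("expand([3]).mask(mask).to_event(1)", d.expand([3]).mask(mask).to_event(1)),
    ("expand([3]).to_event(1)",            d.expand([3]).to_event(1)),
    ("expand([3]).to_event(1).mask(mask)", d.expand([3]).to_event(1).mask(mask)),
]:
    print(label, "batch_shape:", tuple(d2.batch_shape), "event_shape:", tuple(d2.event_shape))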

@LysSanzMoreta
Contributor Author

LysSanzMoreta commented Mar 23, 2023

@fehiepsi "Just curious, why dist.Categorical(logits=logits).to_event(1) does not raise an error? dist.Categorical(logits=logits) does not have batch shape."

Well, that is actually a relief to hear, because the .to_event(1) is doing something (in combination with poutine.mask), but I am not sure what. And I did not expect that to happen (I am not familiar enough with which distributions have a batch or event shape, though). I have Pyro 1.8.2.

Yes, I figured that the order of d.expand([3]).mask(mask).to_event(1) vs d.expand([3]).to_event(1).mask(mask) is important, but this is definitely a confusing factor.

And I now understand that poutine.mask applies over the batch dimensions only, but it is still hard to know when to use which masking method.

@fehiepsi
Member

poutine.mask and Distribution.mask have the same role: masking the log probabilities of a distribution. log_prob of a distribution has shape batch_shape; when masked, its value is log_prob * mask (note that broadcasting rules apply to this multiplication).
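
A minimal sketch of that rule (not from the thread): for a batched distribution, the masked log_prob equals the unmasked log_prob multiplied elementwise by the mask (masked-out entries contribute exactly zero):

import torch
import pyro.distributions as dist

mask = torch.tensor([True, False, True])
d = dist.Normal(torch.zeros(3), torch.ones(3))  # batch_shape (3,)
x = torch.tensor([0.5, -1.0, 2.0])

masked_lp = d.mask(mask).log_prob(x)  # shape (3,), zero where mask is False
manual_lp = d.log_prob(x) * mask      # same values (for finite log-probs)
print(torch.allclose(masked_lp, manual_lp))  # True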

obs_mask is used for partially observed data (we have a separate feature request for its tutorial, #1676).
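
For context, a rough sketch of what obs_mask does, as I understand it: the site is split into a "_observed" part (masked to the observed entries) and a "_unobserved" latent part (which a guide would typically need to handle), plus a merged site with the original name:

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

data = torch.tensor([0.5, -1.0, 2.0])
observed = torch.tensor([True, True, False])  # the last entry is missing

def model():
    with pyro.plate("data", 3):
        pyro.sample("x", dist.Normal(0., 1.), obs=data, obs_mask=observed)

tr = poutine.trace(model).get_trace()
print([name for name in tr.nodes if name.startswith("x")])
# expected: ['x_observed', 'x_unobserved', 'x']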

because the .to_event(1) is doing something (in combination with the poutine.mask), but not sure what

Looking at your code,

mask_t = torch.tensor([True,True])
logits = torch.tensor([3.,4.])

with pyro.poutine.mask(mask=mask_t):
       pyro.sample("c",dist.Categorical(logits=logits).to_event(1))

will raise an error, see this line. If it works in your code, then please open a separate issue with a small reproducible example.

@LysSanzMoreta
Contributor Author

@fehiepsi Noted, looking into making a reproducible example, brb

@LysSanzMoreta
Contributor Author

@fehiepsi Never mind, it did not raise an error because I had pyro.enable_validation(False).

It is still weird that it gives good results with pyro.enable_validation(False) when using

with pyro.poutine.mask(mask=mask_t):
       pyro.sample("c",dist.Categorical(logits=logits).to_event(1))

However when using

pyro.sample("c",dist.Categorical(logits=logits).to_event(1))

or

with pyro.poutine.mask(mask=mask_t):
       pyro.sample("c",dist.Categorical(logits=logits))

The results are random.

I will try to code everything back again with pyro.enable_validation(True)

@fehiepsi
Member

fehiepsi commented Mar 24, 2023

I think you can't use to_event here:

import torch
import pyro
import pyro.distributions as dist
pyro.enable_validation(False)
logits = torch.tensor([3.,4.])
dist.Categorical(logits=logits).to_event(1)

would raise an error. Maybe you can check the shapes of your logits again to better understand why there is no error in your system. I don't think that it's due to enable_validation. If your distribution has no batch dimensions, .to_event(1) will raise an error whether or not we enable validation.

@LysSanzMoreta
Contributor Author

@fehiepsi Oh, OK, interesting. Well, my logits simply have shape [N, num_classes], where N is the number of data points and num_classes is the number of classes. Let me think about it more and then come back.

@LysSanzMoreta
Contributor Author

LysSanzMoreta commented Mar 24, 2023

@fehiepsi Should I open another issue with these examples? They fail when enable_validation is True, not otherwise.

import torch
from torch import tensor
from pyro import sample,plate
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import SVI,Trace_ELBO
from pyro.optim import ClippedAdam
import pyro


def model1(x,obs_mask,x_class,class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    :return:
    """
    z = sample("z",dist.Normal(torch.zeros((2,5)),torch.ones((2,5))).to_event(2))
    logits =  torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                            [[1,2,7],[0,2,1],[2,7,8]]])
    aa = sample("x",dist.Categorical(logits= logits),obs=x)
    with pyro.poutine.mask(mask=class_mask):
        c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1), obs=x_class)
    return z,c,aa

def model2(x,obs_mask,x_class,class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    :return:
    """
    z = sample("z",dist.Normal(torch.zeros((2,5)),torch.ones((2,5))).to_event(2))
    logits =  torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                            [[1,2,7],[0,2,1],[2,7,8]]])
    aa = sample("x",dist.Categorical(logits= logits).mask(obs_mask).to_event(1),obs=x)
    c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1), obs=x_class)
    return z,c


def model3(x,obs_mask,x_class,class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    :return:
    """
    z = sample("z",dist.Normal(torch.zeros((2,5)),torch.ones((2,5))).to_event(2))
    logits =  torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                            [[1,2,7],[0,2,1],[2,7,8]]])
    aa = sample("x",dist.Categorical(logits= logits).to_event(1),obs=x)
    c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1), obs=x_class)
    return z,c,aa

def model4(x,obs_mask,x_class,class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    :return:
    """
    z = sample("z",dist.Normal(torch.zeros((2,5)),torch.ones((2,5))).to_event(2))
    logits =  torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                            [[1,2,7],[0,2,1],[2,7,8]]])
    aa = sample("x",dist.Categorical(logits= logits),obs=x,obs_mask=obs_mask) #partial observations is what i am looking for here
    c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).mask(class_mask), obs=x_class) #in the fully supervised approach no mask here, but in the semi-supervised i would need to mask fully some observations
    return z,c,aa

def model5(x,obs_mask,x_class,class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    :return:
    """
    with pyro.plate("plate_batch",dim=-1):
        z = sample("z",dist.Normal(torch.zeros((2,5)),torch.ones((2,5))).to_event(1))
        logits =  torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                                [[1,2,7],[0,2,1],[2,7,8]]])
        aa = sample("x",dist.Categorical(logits= logits),obs=x,obs_mask=obs_mask) #partial observations is what i am looking for here
        c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).mask(class_mask), obs=x_class)
    return z,c,aa

def guide(x,obs_mask,x_class,class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    """
    z = sample("z",dist.Normal(torch.zeros((2,5)),torch.ones((2,5))).to_event(2))

    return z


if __name__ == "__main__":
    pyro.enable_validation(False)

    x = tensor([[0,2,1],
                [0,1,1]])
    obs_mask = tensor([[1,0,0],[1,1,0]],dtype=bool) #Partial observations
    x_class = tensor([0,1])
    class_mask = tensor([True,False],dtype=bool) #keep/skip some observations

    models_dict = {"model1":model1,
                   "model2":model2,
                   "model3":model3,
                   "model4":model4,
                   "model5":model5,
                   }

    for model in models_dict.keys():
        print("Using {}".format(model))
        guide_tr = poutine.trace(guide).get_trace(x,obs_mask,x_class,class_mask)
        model_tr = poutine.trace(poutine.replay(models_dict[model], trace=guide_tr)).get_trace(x,obs_mask,x_class,class_mask)
        monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
        print("MC ELBO estimate: {}".format(monte_carlo_elbo))
        try:
            pyro.clear_param_store()
            svi = SVI(models_dict[model],guide,loss=Trace_ELBO(),optim=ClippedAdam(dict()))
            svi.step(x,obs_mask,x_class,class_mask)
            print("Test passed")
        except Exception as e:
            print("Test failed: {}".format(e))

By the way, I think I want something like model4 or model5. I accept suggestions on what to do about them, because I am not sure how to handle partial observations of "x" in the model; do I have to do something in the guide? What about semi-supervised approaches for the "c" variable?

@fehiepsi
Member

fehiepsi commented Mar 24, 2023

Your last example is different from the previous one. Now your logits have shape (N, num_classes), so your categorical distribution will have batch shape (N,) -> to_event will work.

Should I open another issue

I'm not sure if there is an issue here. to_event should work with 2 dimensional logits. Note that with 2D logits, the code

with mask(mask):
    sample(..., Categorical(logits).to_event(1), obs=...)

will give you a distribution/log_prob with batch_shape = mask.shape, and event_shape (N,). Hope that this clarifies the semantics of mask. If your data has shape (N,), log_likelihood of the above code will be the same as

(dist.Categorical(logits).log_prob(data).sum(-1) * mask).sum()

which is the same as

dist.Categorical(logits).log_prob(data).sum(-1) * mask.sum()

In other words, you are scaling the log likelihood by a factor mask.sum(). There is no "masking" applied here. It makes sense that you can get good results by scaling the likelihood.
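
A small check (a sketch in the spirit of the snippets above) makes the "scaling, not masking" point concrete: even with a False entry in the mask, every data point still contributes, and the total log likelihood is just rescaled by mask.sum():

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

logits = torch.tensor([[3., 4.], [1., 2.], [0., 5.]])  # shape (N=3, num_classes=2)
data = torch.tensor([0, 1, 1])
mask = torch.tensor([True, False, True])

def model():
    with poutine.mask(mask=mask):
        pyro.sample("c", dist.Categorical(logits=logits).to_event(1), obs=data)

tr = poutine.trace(model).get_trace()
lp_traced = tr.log_prob_sum()
lp_scaled = dist.Categorical(logits=logits).log_prob(data).sum(-1) * mask.sum()
print(torch.allclose(lp_traced, lp_scaled))  # True: the False entry did not drop a data point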

@LysSanzMoreta
Contributor Author

@fehiepsi Oh, OK, I see where the misunderstanding with .to_event() started; my first example was not a good one, sorry.

I need to have a fresh mind to reflect on the last part, because that would mean that I accidentally scaled up the likelihood and therefore made the training more efficient? That is so interesting.

Then, I want to do the same with the variable "x" hahaha (but keeping the partial observations)
