Maybe add epsilon to RelaxedOneHotCategorical to prevent underflow #3304

Open
mtvector opened this issue Dec 7, 2023 · 3 comments

mtvector commented Dec 7, 2023

I've noticed that pyro.distributions.RelaxedOneHotCategorical tends to underflow pretty dramatically if you decrease the temperature below about 0.3 with many categories. I've been adding a slight modification to the rsample function of the ExpRelaxedCategorical class it's built on. Just wanted to post it in case you want to consider this (maybe hacky) fix to make the distribution play nicely with Pyro's support constraints.
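
For illustration, here's a minimal sketch of the failure mode (a toy reproduction, not my actual model): with ~1000 categories and temperature 0.1, coordinates of the sample can underflow to exactly zero, and scoring the sample is then no longer finite.

import torch
from torch.distributions import RelaxedOneHotCategorical

d = RelaxedOneHotCategorical(temperature=torch.tensor(0.1), logits=torch.zeros(1000))
x = d.rsample((100,))
print((x == 0).any())                        # frequently True: coordinates underflow to exactly 0
print(torch.isfinite(d.log_prob(x)).all())   # frequently False once that happens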

Modified from https://github.com/pytorch/pytorch/blob/main/torch/distributions/relaxed_categorical.py:

import torch
from torch.distributions import Categorical, Distribution, constraints
from torch.distributions.utils import broadcast_all, clamp_probs


class ExpRelaxedCategorical(Distribution):
    r"""
    Creates an ExpRelaxedCategorical parameterized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
    Returns the log of a point in the simplex. Based on the interface to
    :class:`OneHotCategorical`.

    Implementation based on [1].

    See also: :func:`torch.distributions.OneHotCategorical`

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
    (Maddison et al, 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al, 2017)
    """
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = (
        constraints.real_vector
    )  # The true support is actually a submanifold of this.
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        self._categorical = Categorical(probs, logits)
        self.temperature = temperature
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        new._categorical = self._categorical.expand(batch_shape)
        super(ExpRelaxedCategorical, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def probs(self):
        return self._categorical.probs

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        uniforms = clamp_probs(
            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        # --- add a floor to prevent underflow (could also clamp_probs) ---
        outs = scores - scores.logsumexp(dim=-1, keepdim=True)
        outs = outs.exp()
        outs = outs + 1e-10  # floor so no coordinate is exactly zero
        outs = (outs / outs.sum(-1, keepdim=True)).log()  # renormalize, return to log space
        return outs

    def log_prob(self, value):
        K = self._categorical._num_events
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        log_scale = torch.full_like(
            self.temperature, float(K)
        ).lgamma() - self.temperature.log().mul(-(K - 1))
        score = logits - value.mul(self.temperature)
        score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
        return score + log_scale
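
A quick sanity check of the patched class (just a sketch, assuming the imports above): even at a low temperature with many categories, the floored rsample keeps every log-coordinate finite.

d = ExpRelaxedCategorical(temperature=torch.tensor(0.1), logits=torch.zeros(1, 1000))
log_x = d.rsample()
assert torch.isfinite(log_x).all()  # no -inf after the epsilon floor
assert (log_x.exp() > 0).all()      # strictly positive point on the simplex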

mtvector commented Dec 18, 2023

FYI: I also took a stab at fixing the straight-through categorical. This could still use some work, but it works for me in a case where the previous RelaxedOneHotCategoricalStraightThrough would not train as part of a GMM-VAE:

# (plus the imports from the snippet above)
from torch.distributions.transforms import ExpTransform
from pyro.distributions import TransformedDistribution
from pyro.distributions.torch_distribution import TorchDistributionMixin


class RelaxedQuantizeCategorical(torch.autograd.Function):
    temperature = None  # Default temperature
    epsilon = 1e-10    # Default epsilon

    @staticmethod
    def set_temperature(new_temperature):
        RelaxedQuantizeCategorical.temperature = new_temperature

    @staticmethod
    def set_epsilon(new_epsilon):
        RelaxedQuantizeCategorical.epsilon = new_epsilon

    @staticmethod
    def forward(ctx, soft_value):
        temperature = float(RelaxedQuantizeCategorical.temperature)
        epsilon = RelaxedQuantizeCategorical.epsilon
        uniforms = clamp_probs(
            torch.rand(soft_value.shape, dtype=soft_value.dtype, device=soft_value.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (soft_value + gumbels) / temperature
        outs = scores - scores.logsumexp(dim=-1, keepdim=True)
        outs = outs.exp()
        outs = outs + epsilon  # floor so no coordinate is exactly zero
        hard_value = (outs / outs.sum(-1, keepdim=True)).log()  # renormalize, return to log space
        hard_value._unquantize = soft_value  # stash the soft value for log_prob
        return hard_value

    @staticmethod
    def backward(ctx, grad):
        return grad


class ExpRelaxedCategoricalStraightThrough(Distribution):
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = (
        constraints.real_vector
    )  # The true support is actually a submanifold of this.
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None, epsilon=1e-10):
        self._categorical = Categorical(probs, logits)
        self.temperature = temperature
        RelaxedQuantizeCategorical.set_temperature(temperature)
        RelaxedQuantizeCategorical.set_epsilon(epsilon)
        
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(ExpRelaxedCategoricalStraightThrough, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        new._categorical = self._categorical.expand(batch_shape)
        super(ExpRelaxedCategoricalStraightThrough, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def probs(self):
        return self._categorical.probs

    def rsample(self, sample_shape=torch.Size()):
        # Straight-through: the forward pass samples via the autograd Function;
        # the backward pass sends gradients straight through to the logits.
        return RelaxedQuantizeCategorical.apply(self.logits)

    def log_prob(self, value):
        value = getattr(value, "_unquantize", value)
        K = self._categorical._num_events
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        score = logits 
        score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
        return score 

class SafeAndRelaxedOneHotCategoricalStraightThrough(TransformedDistribution, TorchDistributionMixin):
    # Don't understand why these were broken (Pyro doesn't seem to call the straight-through rsample)?
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        base_dist = ExpRelaxedCategoricalStraightThrough(
            temperature, probs, logits, validate_args=validate_args
        )
        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(SafeAndRelaxedOneHotCategoricalStraightThrough, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def temperature(self):
        return self.base_dist.temperature

    @property
    def logits(self):
        return self.base_dist.logits

    @property
    def probs(self):
        return self.base_dist.probs
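
Rough usage, for context (a sketch, not a full training example; assumes the imports above): samples land on the simplex with no exact zeros, and log_prob stays finite even at low temperatures.

d = SafeAndRelaxedOneHotCategoricalStraightThrough(temperature=0.1 * torch.ones(1),
                                                   logits=torch.zeros(1, 1000))
x = d.rsample()
print(x.sum(-1))      # ~1.0: a valid point on the simplex
print(d.log_prob(x))  # finite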


fritzo commented Dec 30, 2023

Hi @mtvector, I think our general design principle with distributions is to make them hackable with decent defaults. In this case I'd lean towards letting users add their own epsilon in a custom distribution class. In my own projects I often have one or two custom distributions for each data science project. What do you think of a simple patched distribution, just for your project?

import torch
from torch.distributions.utils import clamp_probs
from pyro.distributions import ExpRelaxedCategorical

class SafeExpRelaxedCategorical(ExpRelaxedCategorical):
    epsilon = 1e-10

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        uniforms = clamp_probs(
            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        #could also clamp_probs
        outs = scores - scores.logsumexp(dim=-1, keepdim=True)
        outs = outs.exp()
        outs = outs + self.epsilon  # prevent underflow
        outs = (outs / outs.sum(-1, keepdim=True)).log()
        return outs

Actually I often find that (1) clamping is safer than adding, and (2) it's best to use torch.finfo(dtype).tiny rather than a hard-coded epsilon. So you might customize:

class SafeExpRelaxedCategorical2(ExpRelaxedCategorical):
    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        uniforms = clamp_probs(
            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        #could also clamp_probs
        outs = scores - scores.logsumexp(dim=-1, keepdim=True)
        outs = outs.exp()
        outs = outs.clamp(min=torch.finfo(outs.dtype).tiny)
        outs = (outs / outs.sum(-1, keepdim=True)).log()
        return outs
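
For reference, torch.finfo(dtype).tiny is the smallest positive normal value for that dtype, so the floor adapts to whatever precision you're running in:

import torch
print(torch.finfo(torch.float32).tiny)  # ~1.18e-38
print(torch.finfo(torch.float64).tiny)  # ~2.23e-308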

WDYT?


mtvector commented Jan 3, 2024

Hi @fritzo,
I agree in principle; you're right about the hackability as well as about using a proper epsilon or torch's tiny value (still working on my coding modularity :) ). I do think it's important to fix the default, though. I used Pyro for two years and thought RelaxedOneHotCategorical was totally unusable, because it fails in the following:

import pyro
import torch
import pyro.distributions as dist

def model(logits):
    pyro.sample('cat_sample', dist.RelaxedOneHotCategorical(temperature=0.1 * torch.ones(1),
                                                            logits=torch.zeros(1, 1000)))


def guide(logits):
    pyro.sample('cat_sample', dist.RelaxedOneHotCategorical(temperature=0.1 * torch.ones(1),
                                                            logits=logits))


pyro.clear_param_store()
optim = pyro.optim.Adam({"lr": 0.1})
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(model, guide, optim, loss=elbo)
for i in range(10):
    logits = torch.randn(1, 1000)
    loss = svi.step(logits)

Giving the error due to underflow:

 warn_if_nan(
.../pyro/lib/python3.11/site-packages/pyro/poutine/trace_struct.py:285: UserWarning: Encountered NaN: log_prob_sum at site 'cat_sample'

You're right about the fix; for instance, your first suggestion resolves the underflow issue more elegantly than what I proposed:

import pyro
import torch
from torch.distributions import constraints
from torch.distributions.relaxed_categorical import ExpRelaxedCategorical
from torch.distributions.transforms import ExpTransform
from torch.distributions.utils import clamp_probs
from pyro.distributions import TransformedDistribution
from pyro.distributions.torch_distribution import TorchDistributionMixin

class SafeExpRelaxedCategorical(ExpRelaxedCategorical):
    epsilon = 1e-10

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        uniforms = clamp_probs(
            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        #could also clamp_probs
        outs = scores - scores.logsumexp(dim=-1, keepdim=True)
        outs = outs.exp()
        outs = outs + self.epsilon  # prevent underflow
        outs = (outs / outs.sum(-1, keepdim=True)).log()
        return outs


class SafeRelaxedOneHotCategorical(TransformedDistribution, TorchDistributionMixin):
    r"""
    Creates a RelaxedOneHotCategorical distribution parametrized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits`.
    This is a relaxed version of the :class:`OneHotCategorical` distribution, so
    its samples are on simplex, and are reparametrizable.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
        ...                              torch.tensor([0.1, 0.2, 0.3, 0.4]))
        >>> m.sample()
        tensor([ 0.1294,  0.2324,  0.3859,  0.2523])

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event
    """
    arg_constraints = {'probs': constraints.simplex,
                       'logits': constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        base_dist = SafeExpRelaxedCategorical(temperature, probs, logits, validate_args=validate_args)
        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(SafeRelaxedOneHotCategorical, _instance)
        return super().expand(batch_shape, _instance=new)


    @property
    def temperature(self):
        return self.base_dist.temperature

    @property
    def logits(self):
        return self.base_dist.logits

    @property
    def probs(self):
        return self.base_dist.probs




def model(logits):
    pyro.sample('cat_sample', SafeRelaxedOneHotCategorical(temperature=0.1 * torch.ones(1),
                                                           logits=torch.ones(1, 1000)))


def guide(logits):
    pyro.sample('cat_sample', SafeRelaxedOneHotCategorical(temperature=0.1 * torch.ones(1),
                                                           logits=logits))


pyro.clear_param_store()
optim = pyro.optim.Adam({"lr": 0.1})
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(model, guide, optim, loss=elbo)
for i in range(10):
    logits = torch.randn(1, 1000)
    loss = svi.step(logits)

This runs with no errors, like my SafeAndRelaxedOneHotCategoricalStraightThrough above.

So, yeah, it seems like the default RelaxedOneHotCategorical should be built on one of the SafeExpRelaxedCategorical bases proposed here?
