Invalid value when calling log_prob after sample #7

Open
louisabraham opened this issue Dec 23, 2022 · 5 comments

Comments

@louisabraham

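My code looks like this:

    from TruncatedNormal import TruncatedNormal  # assumed import path for this repository's module

    def action_log_prob(loc, scale):  # hypothetical wrapper around the original function body
        m = TruncatedNormal(loc, scale, 0, 1)
        action_pt = m.sample()
        return m.log_prob(action_pt)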

It looks like action_pt can take the value 1.0, which makes log_prob raise an error:

    111     def log_prob(self, value):
    112         if self._validate_args:
--> 113             self._validate_sample(value)
    114         return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5
    115 

~/.pyenv/versions/3.8.8/lib/python3.8/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    291         valid = support.check(value)
    292         if not valid.all():
--> 293             raise ValueError(

I don't know whether the bug is:

  1. that the value 1.0 should never be sampled, or
  2. that 1.0 lies within the interval and shouldn't be flagged as outside the support.
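For context on option 2: in PyTorch, a support defined with torch.distributions.constraints.interval(lower_bound, upper_bound) is checked as the closed interval lower_bound <= value <= upper_bound, so a sample of exactly 1.0 would pass a [0, 1] check. Because tensors print with 4 decimal places by default, a value displayed as 1.0 may actually be slightly above the bound. The stdlib-only sketch below (float32 is emulated with struct; the helper names are illustrative, not part of this repository) shows the smallest float32 above 1.0 printing as 1.0000 while failing a closed-interval check.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (IEEE double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def in_closed_interval(value: float, lower: float, upper: float) -> bool:
    """Closed-interval membership, mirroring lower <= value <= upper."""
    return lower <= value <= upper


# Smallest float32 strictly greater than 1.0 (float32 has a 23-bit fraction).
just_above_one = to_float32(1.0 + 2.0**-23)

print(f"{just_above_one:.4f}")                       # 1.0000
print(in_closed_interval(1.0, 0.0, 1.0))             # True: exactly 1.0 is inside [0, 1]
print(in_closed_interval(just_above_one, 0.0, 1.0))  # False: the next float32 up is not
```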
@toshas
Owner

toshas commented Dec 23, 2022

The exception is raised by the support check, meaning that 1.0 does not fall inside the support interval. Since loc and scale aren't given in the snippet, it is hard to say whether this is a precision issue or an incorrect use of the parameters. The interface was designed to follow the conventions of the similar scipy function (scipy.stats.truncnorm).
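For reference, scipy.stats.truncnorm expresses its clip values a and b in standard deviations relative to loc, not in data units, which is an easy place for parameters to get mixed up when comparing against scipy. A small hypothetical helper (not part of this repository) that converts data-space bounds [lo, hi] into scipy's (a, b):

```python
from typing import Tuple


def to_scipy_truncnorm_bounds(loc: float, scale: float, lo: float, hi: float) -> Tuple[float, float]:
    """Standardize data-space bounds [lo, hi] the way scipy.stats.truncnorm expects.

    truncnorm(a, b, loc=loc, scale=scale) truncates to [loc + a*scale, loc + b*scale],
    so a = (lo - loc) / scale and b = (hi - loc) / scale.
    """
    if scale <= 0:
        raise ValueError("scale must be positive")
    if not lo < hi:
        raise ValueError("lower bound must be below upper bound")
    return (lo - loc) / scale, (hi - loc) / scale


# Parameters from the reproducible example in this thread: both bounds lie far in the left tail.
a, b = to_scipy_truncnorm_bounds(2.0, 0.2, 0.0, 1.0)
print(a, b)  # -10.0 -5.0
```

With scipy installed, scipy.stats.truncnorm(a, b, loc=2.0, scale=0.2) then describes a normal with mean 2 and standard deviation 0.2 truncated to [0, 1].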

@louisabraham
Author

Here is a reproducible example:

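    import torch
    from TruncatedNormal import TruncatedNormal  # assumed import path for this repository's module
    m = TruncatedNormal(torch.full((1000,), 2.), torch.full((1000,), .2), 0, 1)
    m.log_prob(m.sample())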

Some values are more than 1 and some are -inf:

    tensor([0.9667, 0.9819, 0.9819, 0.9160, 0.9819, 0.9974, 0.9411, 0.9929, 0.9667,
            0.9819, 0.9667, 0.9160, 0.9667,   -inf, 0.9560,   -inf, 0.9560, 0.9160,
            0.9160, 0.9751, 0.9160, 0.9411, 1.0015,...
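One plausible precision explanation for these parameters: standardized, the bounds are (0 - 2)/0.2 = -10 and (1 - 2)/0.2 = -5, so all of the mass sits in the far left tail. Assuming the standard normal CDF is evaluated as 0.5 * (1 + erf(x / sqrt(2))) in float32 (an assumption about the implementation, emulated here in pure Python with struct), the 1 + erf(...) cancellation leaves only a handful of float32 steps to represent the truncated mass Z = Phi(-5) - Phi(-10). The sketch compares that against an erfc-based float64 reference; the function names are illustrative.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (IEEE double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def phi_cdf_erf(x: float, float32: bool) -> float:
    """Phi(x) = 0.5 * (1 + erf(x / sqrt(2))), optionally rounding each step to float32."""
    r = to_float32 if float32 else float
    e = r(math.erf(r(x / math.sqrt(2.0))))
    return r(0.5 * r(1.0 + e))


def phi_cdf_erfc(x: float) -> float:
    """Phi(x) = 0.5 * erfc(-x / sqrt(2)); avoids the 1 + erf cancellation in the left tail."""
    return 0.5 * math.erfc(-x / math.sqrt(2.0))


alpha, beta = (0.0 - 2.0) / 0.2, (1.0 - 2.0) / 0.2  # standardized bounds -10 and -5

z_ref = phi_cdf_erfc(beta) - phi_cdf_erfc(alpha)  # float64, tail-accurate
z_f32 = phi_cdf_erf(beta, float32=True) - phi_cdf_erf(alpha, float32=True)

print(f"Z reference (erfc, float64): {z_ref:.6e}")
print(f"Z via 1 + erf in float32:    {z_f32:.6e}")
print(f"relative error: {abs(z_f32 - z_ref) / z_ref:.2%}")
```

In this emulation Phi(-10) collapses to exactly 0 and Z is off by a few percent, which would skew both sampling and the mean; whether the library hits exactly this path is not confirmed in the thread.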

@toshas
Owner

toshas commented Dec 24, 2022

One thing I'd try first is to plug these values into the unit test here https://github.com/toshas/torch_truncnorm/blob/main/tests/test.py#L97 and see whether it passes the check against scipy. If it doesn't, there is a bug.
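The mean checked against scipy can also be computed independently. For a normal N(loc, scale^2) truncated to [lo, hi], with alpha = (lo - loc)/scale, beta = (hi - loc)/scale and Z = Phi(beta) - Phi(alpha), the mean is loc + scale * (phi(alpha) - phi(beta)) / Z. A float64, stdlib-only sketch (it uses erfc so the left-tail CDF values keep their precision; the function names are illustrative):

```python
import math


def std_normal_pdf(x: float) -> float:
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def std_normal_cdf(x: float) -> float:
    # erfc form: accurate far into the left tail, unlike 0.5 * (1 + erf(x / sqrt(2)))
    return 0.5 * math.erfc(-x / math.sqrt(2.0))


def truncated_normal_mean(loc: float, scale: float, lo: float, hi: float) -> float:
    """Mean of N(loc, scale**2) truncated to [lo, hi]."""
    alpha = (lo - loc) / scale
    beta = (hi - loc) / scale
    z = std_normal_cdf(beta) - std_normal_cdf(alpha)
    if z <= 0.0:
        raise ValueError("truncation interval has (numerically) zero mass")
    return loc + scale * (std_normal_pdf(alpha) - std_normal_pdf(beta)) / z


mean = truncated_normal_mean(2.0, 0.2, 0.0, 1.0)
print(f"{mean:.8f}")
```

For the reproducible example's parameters this gives about 0.9627; any mean of a distribution supported on [0, 1] must lie inside that interval.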

@louisabraham
Author

I added the line self._test_numerical(2.0, 0.2, 0.0, 1.0) to the test.

It gives:

    ======================================================================
    FAIL: test_simple (__main__.Tests)
    ----------------------------------------------------------------------
    Traceback (most recent call last):
      File "testa.py", line 103, in test_simple
        self._test_numerical(2.0, 0.2, 0.0, 1.0)
      File "testa.py", line 74, in _test_numerical
        self.assertRelativelyEqual(mean_sc, mean_pt)
      File "testa.py", line 66, in assertRelativelyEqual
        raise self.failureException(msg)
    AssertionError: array(0.96269921) != array(1.0022793, dtype=float32) within tol=1e-06 abs=1e-05 (rel=0.03949006605978869 diff=0.039580075041381724)

    ======================================================================
    FAIL: test_support (__main__.Tests)
    ----------------------------------------------------------------------
    Traceback (most recent call last):
      File "testa.py", line 131, in test_support
        self.assertEqual(
    AssertionError: 'Expected value argument (Tensor of shape [157 chars]10.0' != 'The value argument must be within the support'
    + The value argument must be within the support
    - Expected value argument (Tensor of shape ()) to be within the support (Interval(lower_bound=-1.0, upper_bound=2.0)) of the distribution TruncatedNormal(a: -1.0, b: 2.0), but found invalid values:
    - -10.0

The second failure is not caused by the line I added; you might want to fix it in a separate issue. The first one IS a bug.
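On the second failure: test_support compares the exception text exactly, and the message raised (Expected value argument ... to be within the support ...) differs from the one the test expects. One way to make such a test tolerant of message wording is unittest's assertRaisesRegex matching only a stable keyword. The sketch below uses a hypothetical stand-in for the distribution's validation call, not this repository's code.

```python
import unittest


def validate_in_support(value: float, lower: float, upper: float) -> None:
    """Hypothetical stand-in for a distribution's sample validation."""
    if not lower <= value <= upper:
        raise ValueError(
            f"Expected value argument ({value}) to be within the support "
            f"(Interval(lower_bound={lower}, upper_bound={upper}))"
        )


class SupportMessageTest(unittest.TestCase):
    def test_out_of_support_raises(self):
        # Match a stable keyword rather than the full, version-dependent message.
        with self.assertRaisesRegex(ValueError, "support"):
            validate_in_support(-10.0, -1.0, 2.0)

    def test_in_support_passes(self):
        validate_in_support(0.5, -1.0, 2.0)
```

Saved to a file, it can be run with python -m unittest followed by the module name.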

@Wu-Chenyang

Wu-Chenyang commented Apr 15, 2023

The cause seems to be that extreme values passed to the icdf function should be clamped. For comparison, see the torchrl version: https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py
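A minimal sketch of the clamping idea in pure Python float64 rather than torch, with hypothetical names: map a uniform draw into [Phi(alpha), Phi(beta)], clamp it strictly inside (0, 1) before inverting the CDF, then clamp the result into [lo, hi] so rounding can never produce a value outside the support. It uses statistics.NormalDist.inv_cdf for the inverse and an erfc-based CDF for tail accuracy; neither is claimed to match what torchrl or this repository do.

```python
import math
import random
import sys
from statistics import NormalDist

_STD_NORMAL = NormalDist(0.0, 1.0)
_P_MIN = sys.float_info.min            # smallest positive normal double
_P_MAX = 1.0 - sys.float_info.epsilon  # largest probability allowed below 1


def std_normal_cdf(x: float) -> float:
    # erfc form keeps precision in the left tail
    return 0.5 * math.erfc(-x / math.sqrt(2.0))


def sample_truncated_normal(loc: float, scale: float, lo: float, hi: float,
                            rng: random.Random) -> float:
    """Inverse-CDF sample from N(loc, scale**2) truncated to [lo, hi], with clamping."""
    cdf_lo = std_normal_cdf((lo - loc) / scale)
    cdf_hi = std_normal_cdf((hi - loc) / scale)
    p = cdf_lo + rng.random() * (cdf_hi - cdf_lo)
    # inv_cdf raises StatisticsError for p <= 0 or p >= 1, so keep p strictly inside (0, 1)
    p = min(max(p, _P_MIN), _P_MAX)
    x = loc + scale * _STD_NORMAL.inv_cdf(p)
    # rounding in inv_cdf or in loc + scale * z can still step outside; clamp to the support
    return min(max(x, lo), hi)


rng = random.Random(0)
samples = [sample_truncated_normal(2.0, 0.2, 0.0, 1.0, rng) for _ in range(10_000)]
print(min(samples), max(samples), sum(samples) / len(samples))
```

Clamping the final value trades a tiny bias at the bounds for never failing the support check; for the parameters above, the sample mean should land near the analytic value of about 0.963.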
