TruncatedNormal gives wrong results sometimes #1788

Open
Joshuaalbert opened this issue Feb 19, 2024 · 4 comments

Comments

@Joshuaalbert
Contributor

TruncatedNormal sometimes gives wrong values. It seems to happen when the scale is relatively small, but also in surprising situations where you'd expect it to work, like TruncatedNormal(1, 0.1, 0, 10).

MVCE

import jax
import jax.numpy as jnp
import pytest
import tensorflow_probability.substrates.jax as tfp

tfpd = tfp.distributions


@pytest.mark.parametrize("scale", [0.01, 0.1])
@pytest.mark.parametrize("low", [0.0, 0.])  # both entries equal 0.0, so each case runs twice
@pytest.mark.parametrize("high", [10, jnp.inf])
def test_truncated_normal(low, high, scale):
    dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
    u = jnp.linspace(0., 1., 100)

    samples = jax.vmap(dist.quantile)(u)
    assert jnp.all(samples >= low)
    assert jnp.all(samples <= high)
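One plausible mechanism, sketched with the textbook inverse-CDF formula for a truncated normal, x = loc + scale * ndtri(Phi(alpha) + u * (Phi(beta) - Phi(alpha))) with alpha = (low - loc) / scale and beta = (high - loc) / scale. This is an assumption for illustration only: naive_truncnorm_quantile is a hypothetical pure-JAX helper, not TFP's implementation. In float32, Phi(-100) underflows to exactly 0 and Phi(900) rounds to exactly 1, so u = 0 and u = 1 land on ndtri(0) = -inf and ndtri(1) = +inf. With scale = 0.1 the standardized lower bound is only -10, and Phi(-10) (roughly 7.6e-24) is still representable, so only the upper end blows up.

```python
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri


def naive_truncnorm_quantile(u, loc, scale, low, high):
    """Hypothetical helper: textbook truncated-normal inverse CDF in float32."""
    u = jnp.asarray(u, dtype=jnp.float32)
    alpha = jnp.float32((low - loc) / scale)  # standardized lower bound
    beta = jnp.float32((high - loc) / scale)  # standardized upper bound
    # Map u into [Phi(alpha), Phi(beta)], then invert the standard normal CDF.
    p = ndtr(alpha) + u * (ndtr(beta) - ndtr(alpha))
    return loc + scale * ndtri(p)


endpoints = jnp.array([0.0, 1.0])
for scale in [0.01, 0.1]:
    # scale=0.01: alpha=-100, Phi(alpha) underflows to 0, so ndtri(0) = -inf at u=0.
    # Both scales: Phi(beta) rounds to 1, so ndtri(1) = +inf at u=1, even with high=10.
    print(scale, naive_truncnorm_quantile(endpoints, 1.0, scale, 0.0, 10.0))
```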
@Joshuaalbert
Contributor Author

6 out of 8 tests fail

========================= 6 failed, 2 passed in 1.72s ==========================
FAILED                     [ 12%]
debug/error.py:8 (test_truncated_normal[10-0.00-0.01])
Array([      -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,
       0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,
       0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,
       0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,
       0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,
       0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,
       0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,
       0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,
       0.997571  , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,
       0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,
       1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,
       1.0013971 , 1.0016533 , 1.0019106 , 1.002169  , 1.002429  ,
       1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758  ,
       1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157  ,
       1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,
       1.0069853 , 1.0073122 , 1.007647  , 1.0079908 , 1.0083443 ,
       1.0087085 , 1.0090846 , 1.009474  , 1.0098784 , 1.0102996 ,
       1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,
       1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,
       1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 ,        inf],      dtype=float32) != 0.0

low = 0.0, high = 10, scale = 0.01

    @pytest.mark.parametrize("scale", [0.01, 0.1])
    @pytest.mark.parametrize("low", [0.0, 0.])
    @pytest.mark.parametrize("high", [10, jnp.inf])
    def test_truncated_normal(low, high, scale):
        dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
        u = jnp.linspace(0., 1., 100)
    
        samples = jax.vmap(dist.quantile)(u)
>       assert jnp.all(samples >= low)
E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([      -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,\n       0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,\n       0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,\n       0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,\n       0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,\n       0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,\n       0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,\n       0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,\n       0.997571  , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,\n       0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,\n       1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,\n       1.0013971 , 1.0016533 , 1.0019106 , 1.002169  , 1.002429  ,\n       1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758  ,\n       1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157  ,\n       1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,\n       1.0069853 , 1.0073122 , 1.007647  , 1.0079908 , 1.0083443 ,\n       1.0087085 , 1.0090846 , 1.009474  , 1.0098784 , 1.0102996 ,\n       1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,\n       1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,\n       1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 ,        inf],      dtype=float32) >= 0.0)
E        +    where <function all at 0x7f09ad5271a0> = jnp.all

error.py:17: AssertionError
FAILED                      [ 25%]
debug/error.py:8 (test_truncated_normal[10-0.00-0.1])
Array([0.        , 0.7677425 , 0.79504055, 0.8123641 , 0.8253983 ,
       0.83600235, 0.8450294 , 0.8529455 , 0.86003435, 0.8664822 ,
       0.8724183 , 0.87793595, 0.8831051 , 0.8879795 , 0.89260125,
       0.89700437, 0.90121627, 0.90525985, 0.90915424, 0.91291547,
       0.9165573 , 0.92009175, 0.923529  , 0.9268783 , 0.9301474 ,
       0.9333436 , 0.936473  , 0.93954146, 0.942554  , 0.9455153 ,
       0.9484295 , 0.9513006 , 0.9541321 , 0.9569273 , 0.9596892 ,
       0.9624207 , 0.9651244 , 0.9678029 , 0.9704585 , 0.9730934 ,
       0.97570974, 0.9783096 , 0.9808948 , 0.98346734, 0.98602897,
       0.9885815 , 0.99112654, 0.9936659 , 0.99620116, 0.998734  ,
       1.001266  , 1.0037988 , 1.0063341 , 1.0088735 , 1.0114186 ,
       1.0139711 , 1.0165327 , 1.0191052 , 1.0216905 , 1.0242903 ,
       1.0269066 , 1.0295415 , 1.0321971 , 1.0348755 , 1.0375793 ,
       1.0403109 , 1.0430727 , 1.0458679 , 1.0486994 , 1.0515704 ,
       1.0544847 , 1.057446  , 1.0604585 , 1.063527  , 1.0666565 ,
       1.0698526 , 1.0731218 , 1.076471  , 1.0799083 , 1.0834427 ,
       1.0870845 , 1.0908458 , 1.0947402 , 1.0987837 , 1.1029956 ,
       1.1073987 , 1.1120205 , 1.116895  , 1.122064  , 1.1275817 ,
       1.1335177 , 1.1399657 , 1.1470546 , 1.1549706 , 1.1639977 ,
       1.1746017 , 1.1876359 , 1.2049594 , 1.2322574 ,        inf],      dtype=float32) != 10

low = 0.0, high = 10, scale = 0.1

    @pytest.mark.parametrize("scale", [0.01, 0.1])
    @pytest.mark.parametrize("low", [0.0, 0.])
    @pytest.mark.parametrize("high", [10, jnp.inf])
    def test_truncated_normal(low, high, scale):
        dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
        u = jnp.linspace(0., 1., 100)
    
        samples = jax.vmap(dist.quantile)(u)
        assert jnp.all(samples >= low)
>       assert jnp.all(samples <= high)
E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([0.        , 0.7677425 , 0.79504055, 0.8123641 , 0.8253983 ,\n       0.83600235, 0.8450294 , 0.8529455 , 0.86003435, 0.8664822 ,\n       0.8724183 , 0.87793595, 0.8831051 , 0.8879795 , 0.89260125,\n       0.89700437, 0.90121627, 0.90525985, 0.90915424, 0.91291547,\n       0.9165573 , 0.92009175, 0.923529  , 0.9268783 , 0.9301474 ,\n       0.9333436 , 0.936473  , 0.93954146, 0.942554  , 0.9455153 ,\n       0.9484295 , 0.9513006 , 0.9541321 , 0.9569273 , 0.9596892 ,\n       0.9624207 , 0.9651244 , 0.9678029 , 0.9704585 , 0.9730934 ,\n       0.97570974, 0.9783096 , 0.9808948 , 0.98346734, 0.98602897,\n       0.9885815 , 0.99112654, 0.9936659 , 0.99620116, 0.998734  ,\n       1.001266  , 1.0037988 , 1.0063341 , 1.0088735 , 1.0114186 ,\n       1.0139711 , 1.0165327 , 1.0191052 , 1.0216905 , 1.0242903 ,\n       1.0269066 , 1.0295415 , 1.0321971 , 1.0348755 , 1.0375793 ,\n       1.0403109 , 1.0430727 , 1.0458679 , 1.0486994 , 1.0515704 ,\n       1.0544847 , 1.057446  , 1.0604585 , 1.063527  , 1.0666565 ,\n       1.0698526 , 1.0731218 , 1.076471  , 1.0799083 , 1.0834427 ,\n       1.0870845 , 1.0908458 , 1.0947402 , 1.0987837 , 1.1029956 ,\n       1.1073987 , 1.1120205 , 1.116895  , 1.122064  , 1.1275817 ,\n       1.1335177 , 1.1399657 , 1.1470546 , 1.1549706 , 1.1639977 ,\n       1.1746017 , 1.1876359 , 1.2049594 , 1.2322574 ,        inf],      dtype=float32) <= 10)
E        +    where <function all at 0x7f09ad5271a0> = jnp.all

error.py:18: AssertionError
FAILED                     [ 37%]
debug/error.py:8 (test_truncated_normal[10-0.01-0.01])
Array([      -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,
       0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,
       0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,
       0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,
       0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,
       0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,
       0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,
       0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,
       0.997571  , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,
       0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,
       1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,
       1.0013971 , 1.0016533 , 1.0019106 , 1.002169  , 1.002429  ,
       1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758  ,
       1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157  ,
       1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,
       1.0069853 , 1.0073122 , 1.007647  , 1.0079908 , 1.0083443 ,
       1.0087085 , 1.0090846 , 1.009474  , 1.0098784 , 1.0102996 ,
       1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,
       1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,
       1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 ,        inf],      dtype=float32) != 0.0

low = 0.0, high = 10, scale = 0.01

    @pytest.mark.parametrize("scale", [0.01, 0.1])
    @pytest.mark.parametrize("low", [0.0, 0.])
    @pytest.mark.parametrize("high", [10, jnp.inf])
    def test_truncated_normal(low, high, scale):
        dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
        u = jnp.linspace(0., 1., 100)
    
        samples = jax.vmap(dist.quantile)(u)
>       assert jnp.all(samples >= low)
E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([      -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,\n       0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,\n       0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,\n       0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,\n       0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,\n       0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,\n       0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,\n       0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,\n       0.997571  , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,\n       0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,\n       1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,\n       1.0013971 , 1.0016533 , 1.0019106 , 1.002169  , 1.002429  ,\n       1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758  ,\n       1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157  ,\n       1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,\n       1.0069853 , 1.0073122 , 1.007647  , 1.0079908 , 1.0083443 ,\n       1.0087085 , 1.0090846 , 1.009474  , 1.0098784 , 1.0102996 ,\n       1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,\n       1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,\n       1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 ,        inf],      dtype=float32) >= 0.0)
E        +    where <function all at 0x7f09ad5271a0> = jnp.all

error.py:17: AssertionError
FAILED                      [ 50%]
debug/error.py:8 (test_truncated_normal[10-0.01-0.1])
Array([0.        , 0.7677425 , 0.79504055, 0.8123641 , 0.8253983 ,
       0.83600235, 0.8450294 , 0.8529455 , 0.86003435, 0.8664822 ,
       0.8724183 , 0.87793595, 0.8831051 , 0.8879795 , 0.89260125,
       0.89700437, 0.90121627, 0.90525985, 0.90915424, 0.91291547,
       0.9165573 , 0.92009175, 0.923529  , 0.9268783 , 0.9301474 ,
       0.9333436 , 0.936473  , 0.93954146, 0.942554  , 0.9455153 ,
       0.9484295 , 0.9513006 , 0.9541321 , 0.9569273 , 0.9596892 ,
       0.9624207 , 0.9651244 , 0.9678029 , 0.9704585 , 0.9730934 ,
       0.97570974, 0.9783096 , 0.9808948 , 0.98346734, 0.98602897,
       0.9885815 , 0.99112654, 0.9936659 , 0.99620116, 0.998734  ,
       1.001266  , 1.0037988 , 1.0063341 , 1.0088735 , 1.0114186 ,
       1.0139711 , 1.0165327 , 1.0191052 , 1.0216905 , 1.0242903 ,
       1.0269066 , 1.0295415 , 1.0321971 , 1.0348755 , 1.0375793 ,
       1.0403109 , 1.0430727 , 1.0458679 , 1.0486994 , 1.0515704 ,
       1.0544847 , 1.057446  , 1.0604585 , 1.063527  , 1.0666565 ,
       1.0698526 , 1.0731218 , 1.076471  , 1.0799083 , 1.0834427 ,
       1.0870845 , 1.0908458 , 1.0947402 , 1.0987837 , 1.1029956 ,
       1.1073987 , 1.1120205 , 1.116895  , 1.122064  , 1.1275817 ,
       1.1335177 , 1.1399657 , 1.1470546 , 1.1549706 , 1.1639977 ,
       1.1746017 , 1.1876359 , 1.2049594 , 1.2322574 ,        inf],      dtype=float32) != 10

low = 0.0, high = 10, scale = 0.1

    @pytest.mark.parametrize("scale", [0.01, 0.1])
    @pytest.mark.parametrize("low", [0.0, 0.])
    @pytest.mark.parametrize("high", [10, jnp.inf])
    def test_truncated_normal(low, high, scale):
        dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
        u = jnp.linspace(0., 1., 100)
    
        samples = jax.vmap(dist.quantile)(u)
        assert jnp.all(samples >= low)
>       assert jnp.all(samples <= high)
E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([0.        , 0.7677425 , 0.79504055, 0.8123641 , 0.8253983 ,\n       0.83600235, 0.8450294 , 0.8529455 , 0.86003435, 0.8664822 ,\n       0.8724183 , 0.87793595, 0.8831051 , 0.8879795 , 0.89260125,\n       0.89700437, 0.90121627, 0.90525985, 0.90915424, 0.91291547,\n       0.9165573 , 0.92009175, 0.923529  , 0.9268783 , 0.9301474 ,\n       0.9333436 , 0.936473  , 0.93954146, 0.942554  , 0.9455153 ,\n       0.9484295 , 0.9513006 , 0.9541321 , 0.9569273 , 0.9596892 ,\n       0.9624207 , 0.9651244 , 0.9678029 , 0.9704585 , 0.9730934 ,\n       0.97570974, 0.9783096 , 0.9808948 , 0.98346734, 0.98602897,\n       0.9885815 , 0.99112654, 0.9936659 , 0.99620116, 0.998734  ,\n       1.001266  , 1.0037988 , 1.0063341 , 1.0088735 , 1.0114186 ,\n       1.0139711 , 1.0165327 , 1.0191052 , 1.0216905 , 1.0242903 ,\n       1.0269066 , 1.0295415 , 1.0321971 , 1.0348755 , 1.0375793 ,\n       1.0403109 , 1.0430727 , 1.0458679 , 1.0486994 , 1.0515704 ,\n       1.0544847 , 1.057446  , 1.0604585 , 1.063527  , 1.0666565 ,\n       1.0698526 , 1.0731218 , 1.076471  , 1.0799083 , 1.0834427 ,\n       1.0870845 , 1.0908458 , 1.0947402 , 1.0987837 , 1.1029956 ,\n       1.1073987 , 1.1120205 , 1.116895  , 1.122064  , 1.1275817 ,\n       1.1335177 , 1.1399657 , 1.1470546 , 1.1549706 , 1.1639977 ,\n       1.1746017 , 1.1876359 , 1.2049594 , 1.2322574 ,        inf],      dtype=float32) <= 10)
E        +    where <function all at 0x7f09ad5271a0> = jnp.all

error.py:18: AssertionError
FAILED                    [ 62%]
debug/error.py:8 (test_truncated_normal[inf-0.00-0.01])
Array([      -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,
       0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,
       0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,
       0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,
       0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,
       0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,
       0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,
       0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,
       0.997571  , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,
       0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,
       1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,
       1.0013971 , 1.0016533 , 1.0019106 , 1.002169  , 1.002429  ,
       1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758  ,
       1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157  ,
       1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,
       1.0069853 , 1.0073122 , 1.007647  , 1.0079908 , 1.0083443 ,
       1.0087085 , 1.0090846 , 1.009474  , 1.0098784 , 1.0102996 ,
       1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,
       1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,
       1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 ,        inf],      dtype=float32) != 0.0

low = 0.0, high = inf, scale = 0.01

    @pytest.mark.parametrize("scale", [0.01, 0.1])
    @pytest.mark.parametrize("low", [0.0, 0.])
    @pytest.mark.parametrize("high", [10, jnp.inf])
    def test_truncated_normal(low, high, scale):
        dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
        u = jnp.linspace(0., 1., 100)
    
        samples = jax.vmap(dist.quantile)(u)
>       assert jnp.all(samples >= low)
E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([      -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,\n       0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,\n       0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,\n       0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,\n       0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,\n       0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,\n       0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,\n       0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,\n       0.997571  , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,\n       0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,\n       1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,\n       1.0013971 , 1.0016533 , 1.0019106 , 1.002169  , 1.002429  ,\n       1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758  ,\n       1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157  ,\n       1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,\n       1.0069853 , 1.0073122 , 1.007647  , 1.0079908 , 1.0083443 ,\n       1.0087085 , 1.0090846 , 1.009474  , 1.0098784 , 1.0102996 ,\n       1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,\n       1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,\n       1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 ,        inf],      dtype=float32) >= 0.0)
E        +    where <function all at 0x7f09ad5271a0> = jnp.all

error.py:17: AssertionError
PASSED                     [ 75%]
FAILED                    [ 87%]
debug/error.py:8 (test_truncated_normal[inf-0.01-0.01])
Array([      -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,
       0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,
       0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,
       0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,
       0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,
       0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,
       0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,
       0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,
       0.997571  , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,
       0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,
       1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,
       1.0013971 , 1.0016533 , 1.0019106 , 1.002169  , 1.002429  ,
       1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758  ,
       1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157  ,
       1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,
       1.0069853 , 1.0073122 , 1.007647  , 1.0079908 , 1.0083443 ,
       1.0087085 , 1.0090846 , 1.009474  , 1.0098784 , 1.0102996 ,
       1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,
       1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,
       1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 ,        inf],      dtype=float32) != 0.0

low = 0.0, high = inf, scale = 0.01

    @pytest.mark.parametrize("scale", [0.01, 0.1])
    @pytest.mark.parametrize("low", [0.0, 0.])
    @pytest.mark.parametrize("high", [10, jnp.inf])
    def test_truncated_normal(low, high, scale):
        dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
        u = jnp.linspace(0., 1., 100)
    
        samples = jax.vmap(dist.quantile)(u)
>       assert jnp.all(samples >= low)
E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([      -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,\n       0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,\n       0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,\n       0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,\n       0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,\n       0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,\n       0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,\n       0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,\n       0.997571  , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,\n       0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,\n       1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,\n       1.0013971 , 1.0016533 , 1.0019106 , 1.002169  , 1.002429  ,\n       1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758  ,\n       1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157  ,\n       1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,\n       1.0069853 , 1.0073122 , 1.007647  , 1.0079908 , 1.0083443 ,\n       1.0087085 , 1.0090846 , 1.009474  , 1.0098784 , 1.0102996 ,\n       1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,\n       1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,\n       1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 ,        inf],      dtype=float32) >= 0.0)
E        +    where <function all at 0x7f09ad5271a0> = jnp.all

error.py:17: AssertionError
PASSED                     [100%]

@ColCarroll
Contributor

Hey! Thanks for opening this issue. It looks like the problem is only at the boundaries (u = 0 and u = 1), as we might expect; the interior quantiles agree with SciPy's truncnorm:

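import jax.numpy as jnp
import numpy.testing as npt
import scipy.stats as st
import tensorflow_probability.substrates.jax as tfp

tfpd = tfp.distributions

low = 0.0
u = jnp.linspace(0., 1., 100)
for scale in [0.01, 0.1]:
  for high in [10, jnp.inf]:
    # SciPy's truncnorm takes the bounds in standardized units: (bound - loc) / scale
    rv = st.truncnorm((low - 1.) / scale, (high - 1.) / scale, 1.0, scale)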
    dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
    print(scale, low, high)
    print(dist.quantile(jnp.array([0, 1.])), rv.ppf(jnp.array([0, 1.])))
    npt.assert_allclose(dist.quantile(u[1:-1]), rv.ppf(u[1:-1]), atol=1e-7)

Outputs

0.01 0.0 10
[-inf  inf] [ 0. 10.]
0.01 0.0 inf
[-inf  inf] [ 0. inf]
0.1 0.0 10
[ 0. inf] [ 0. 10.]
0.1 0.0 inf
[ 0. inf] [ 0. inf]
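The pattern in these outputs can be checked directly on the standardized bounds. A small sketch, assuming JAX's float32 default and using jax.scipy.special.ndtr and log_ndtr (not TFP internals): Phi(alpha) is exactly 0 only in the scale = 0.01 rows, and Phi(beta) is exactly 1 in every row, which lines up with -inf appearing at the lower end only for scale = 0.01 and +inf appearing at the upper end everywhere. log_ndtr(alpha) stays finite even where Phi(alpha) underflows.

```python
import jax.numpy as jnp
from jax.scipy.special import log_ndtr, ndtr

loc, low = 1.0, 0.0
for scale in [0.01, 0.1]:
    for high in [10.0, float("inf")]:
        alpha = jnp.float32((low - loc) / scale)  # standardized lower bound
        beta = jnp.float32((high - loc) / scale)  # standardized upper bound
        # Phi(alpha) == 0 forces ndtri(0) = -inf at u = 0;
        # Phi(beta) == 1 forces ndtri(1) = +inf at u = 1.
        print(scale, high, float(ndtr(alpha)), float(ndtr(beta)), float(log_ndtr(alpha)))
```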

@Joshuaalbert
Contributor Author

Joshuaalbert commented Feb 23, 2024

What's interesting is that if you go to log space, the argument to ndtri(...) in the quantile is finite at both ends. It's just fairly close to infinite. I think following up with a few steps of bisection would solve this, because ndtr is more stable than ndtri. Make sense? WDYT?
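A minimal sketch of the bisection idea, assuming pure JAX in float32 and jax.scipy.special.log_ndtr rather than TFP internals. The target (1 - u) * Phi(alpha) + u * Phi(beta) is formed with logaddexp so it never leaves log space. z is then bracketed in [alpha, beta] and narrowed by bisection on the monotone log_ndtr. The helper bisection_truncnorm_quantile, the clamp bound = 1e4 for infinite bounds and num_steps = 64 are hypothetical choices.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import log_ndtr


def bisection_truncnorm_quantile(u, loc, scale, low, high, num_steps=64, bound=1e4):
    """Hypothetical helper: truncated-normal quantile via bisection on log_ndtr."""
    u = jnp.asarray(u, dtype=jnp.float32)
    # Standardized bounds, clamped so an infinite low/high still gives a finite bracket.
    alpha = jnp.clip(jnp.float32((low - loc) / scale), -bound, bound)
    beta = jnp.clip(jnp.float32((high - loc) / scale), -bound, bound)
    # log((1 - u) * Phi(alpha) + u * Phi(beta)), computed without leaving log space.
    target = jnp.logaddexp(jnp.log1p(-u) + log_ndtr(alpha), jnp.log(u) + log_ndtr(beta))

    def step(_, bracket):
        lo, hi = bracket
        mid = 0.5 * (lo + hi)
        # log_ndtr is increasing, so keep the half of the bracket that contains the target.
        go_right = log_ndtr(mid) < target
        return jnp.where(go_right, mid, lo), jnp.where(go_right, hi, mid)

    lo, hi = jax.lax.fori_loop(0, num_steps, step, (alpha, beta))
    z = 0.5 * (lo + hi)
    # z never leaves [alpha, beta]; the clip only absorbs rounding in loc + scale * z.
    return jnp.clip(loc + scale * z, low, high)
```

Because the bracket starts at [alpha, beta], every result stays inside [low, high]. One caveat: far in the upper tail, log_ndtr saturates at 0 in float32, so at u = 1 the sketch returns a point inside the support rather than exactly high.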

@Joshuaalbert
Contributor Author

Or, thinking about this again, perhaps the best option would be to clip the output of the quantile to [low, high], and then define a safe custom gradient rule.
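A minimal sketch of clip-plus-custom-gradient in pure JAX, assuming gradients are only needed with respect to u and that loc, scale, low and high are fixed Python floats. make_clipped_quantile is a hypothetical helper. It reuses the textbook inverse-CDF formula rather than TFP's quantile. Setting dx/du = 0 wherever the truncated density underflows to 0 is one possible choice of "safe" rule, not an established one.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri


def make_clipped_quantile(loc, scale, low, high):
    """Hypothetical helper: clipped truncated-normal quantile with a guarded derivative in u."""
    phi_alpha = ndtr(jnp.float32((low - loc) / scale))
    phi_beta = ndtr(jnp.float32((high - loc) / scale))
    mass = phi_beta - phi_alpha  # probability mass of the untruncated normal inside [low, high]

    @jax.custom_jvp
    def quantile(u):
        p = phi_alpha + u * mass
        # Clip the possibly infinite inverse-CDF result back into [low, high].
        return jnp.clip(loc + scale * ndtri(p), low, high)

    @quantile.defjvp
    def quantile_jvp(primals, tangents):
        (u,), (u_dot,) = primals, tangents
        x = quantile(u)
        z = (x - loc) / scale
        # Truncated-normal density at x: phi(z) / (scale * mass); dx/du = 1 / density.
        density = jnp.exp(-0.5 * z**2) / jnp.sqrt(2.0 * jnp.pi) / (scale * mass)
        ok = density > 0
        # Double where: never divide by 0, so the tangent is never inf or NaN.
        safe_density = jnp.where(ok, density, 1.0)
        dx_du = jnp.where(ok, 1.0 / safe_density, 0.0)
        return x, dx_du * u_dot

    return quantile
```

For example, with make_clipped_quantile(1.0, 0.01, 0.0, 10.0) the endpoints come back as 0 and 10 instead of -inf and inf, and jax.grad gives a finite derivative at every u in [0, 1].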
