InverseBijector layer (#91)
Summary:
Adds a new InverseBijector layer class. Instances can be created by calling `bij.invert()` or `InverseBijector(bij)`. We make sure that `bij.invert().invert() is bij` returns `True`.
Suggested by #90

Pull Request resolved: #91

Reviewed By: vmoens

Differential Revision: D34122085

Pulled By: stefanwebb

fbshipit-source-id: 250e12d7d4a474377d04a9e0ca6f74f98c729cfa
vmoens authored and facebook-github-bot committed Apr 22, 2022
1 parent 4992731 commit 13cf226
Showing 6 changed files with 160 additions and 27 deletions.
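
Before the per-file diffs, here is a minimal usage sketch of the new layer. It mirrors the `test_invert` test added in tests/test_bijector.py below; the event shape, batch size, and variable names are illustrative assumptions rather than part of this commit.

import torch
import flowtorch.bijectors as bijectors
import flowtorch.parameters as params

# Lazily define a bijector, then wrap it in Invert to swap forward and inverse.
ar = bijectors.AffineAutoregressive(params.DenseAutoregressive())
inv = bijectors.Invert(ar)(shape=torch.Size([16]))  # instantiate with an event shape

x = torch.randn(5, 16)
y = inv.forward(x)                      # dispatches to the wrapped bijector's inverse
x_rt = inv.inverse(y)                   # dispatches to the wrapped bijector's forward
ladj = inv.log_abs_det_jacobian(x, y)   # arguments are swapped before delegating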
2 changes: 2 additions & 0 deletions flowtorch/bijectors/__init__.py
@@ -20,6 +20,7 @@
from flowtorch.bijectors.elu import ELU
from flowtorch.bijectors.exp import Exp
from flowtorch.bijectors.fixed import Fixed
from flowtorch.bijectors.invert import Invert
from flowtorch.bijectors.leaky_relu import LeakyReLU
from flowtorch.bijectors.permute import Permute
from flowtorch.bijectors.power import Power
@@ -52,6 +53,7 @@
("Fixed", Fixed),
("Bijector", Bijector),
("Compose", Compose),
("Invert", Invert),
("VolumePreserving", VolumePreserving),
]

30 changes: 15 additions & 15 deletions flowtorch/bijectors/base.py
@@ -1,11 +1,13 @@
# Copyright (c) Meta Platforms, Inc
from __future__ import annotations

import warnings
from typing import Optional, Sequence, Tuple, Union, Callable, Iterator
from typing import Callable, Optional, Sequence, Tuple, Union

import flowtorch.parameters
import torch
import torch.distributions
from flowtorch.bijectors.bijective_tensor import to_bijective_tensor, BijectiveTensor
from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor
from flowtorch.bijectors.utils import is_record_flow_graph_enabled
from flowtorch.parameters import Parameters
from torch.distributions import constraints
@@ -15,12 +17,9 @@
]


class Bijector(metaclass=flowtorch.LazyMeta):
class Bijector(torch.nn.Module, metaclass=flowtorch.LazyMeta):
codomain: constraints.Constraint = constraints.real
domain: constraints.Constraint = constraints.real
_shape: torch.Size
_context_shape: Optional[torch.Size]
_params_fn: Optional[Union[Parameters, torch.nn.ModuleList]] = None

def __init__(
self,
@@ -29,6 +28,8 @@ def __init__(
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
) -> None:
super().__init__()

# Prevent "meta bijectors" from being initialized
# NOTE: We define a "standard bijector" as one that inherits from a
# subclass of Bijector, hence why we need to test the length of the MRO
@@ -42,18 +43,13 @@ def __init__(
self._context_shape = context_shape

# Instantiate parameters (tensor, hypernets, etc.)
self._params_fn: Optional[Union[Parameters, torch.nn.ModuleList]] = None
if params_fn is not None:
param_shapes = self.param_shapes(shape)
self._params_fn = params_fn( # type: ignore
param_shapes, self._shape, self._context_shape
)

def parameters(self) -> Iterator[torch.Tensor]:
assert self._params_fn is not None
if hasattr(self._params_fn, "parameters"):
for param in self._params_fn.parameters():
yield param

def _check_bijective_x(
self, x: torch.Tensor, context: Optional[torch.Tensor]
) -> bool:
@@ -94,7 +90,9 @@ def _forward(
"""
Abstract method to compute forward transformation.
"""
raise NotImplementedError
raise NotImplementedError(
f"layer {self.__class__.__name__} does not have an `_forward` method"
)

def _check_bijective_y(
self, y: torch.Tensor, context: Optional[torch.Tensor]
@@ -138,7 +136,9 @@ def _inverse(
"""
Abstract method to compute inverse transformation.
"""
raise NotImplementedError
raise NotImplementedError(
f"layer {self.__class__.__name__} does not have an `_inverse` method"
)

def log_abs_det_jacobian(
self,
@@ -170,7 +170,7 @@ def log_abs_det_jacobian(
if ladj is None:
if is_record_flow_graph_enabled():
warnings.warn(
"Computing _log_abs_det_jacobian from values and not from cache."
"Computing _log_abs_det_jacobian from values and not " "from cache."
)
params = (
self._params_fn(x, context) if self._params_fn is not None else None
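
Since base.py now makes Bijector a torch.nn.Module subclass (and drops the hand-written parameters() iterator), trainable parameters can be collected through the standard nn.Module API. A brief sketch under that assumption; the bijector construction and learning rate below are illustrative, not part of this diff.

import torch
import flowtorch.bijectors as bijectors
import flowtorch.parameters as params

bij = bijectors.AffineAutoregressive(params.DenseAutoregressive())(shape=torch.Size([16]))

# parameters() is now inherited from torch.nn.Module rather than reimplemented.
opt = torch.optim.Adam(bij.parameters(), lr=1e-3)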
2 changes: 1 addition & 1 deletion flowtorch/bijectors/bijective_tensor.py
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc
from typing import Any, Optional, Iterator, Type, TYPE_CHECKING, Union
from typing import Any, Iterator, Optional, Type, TYPE_CHECKING, Union

if TYPE_CHECKING:
from flowtorch.bijectors.base import Bijector
28 changes: 17 additions & 11 deletions flowtorch/bijectors/compose.py
@@ -1,11 +1,12 @@
# Copyright (c) Meta Platforms, Inc
from typing import Optional, Sequence, Iterator
import warnings
from typing import Optional, Sequence

import flowtorch.parameters
import torch
import torch.distributions
from flowtorch.bijectors.base import Bijector
from flowtorch.bijectors.bijective_tensor import to_bijective_tensor, BijectiveTensor
from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor
from flowtorch.bijectors.utils import is_record_flow_graph_enabled, requires_log_detJ
from torch.distributions.utils import _sum_rightmost

@@ -19,25 +20,21 @@ def __init__(
context_shape: Optional[torch.Size] = None,
):
assert len(bijectors) > 0
super().__init__(None, shape=shape, context_shape=context_shape)

# Instantiate all bijectors, propagating shape information
self.bijectors = []
self.bijectors = torch.nn.ModuleList()
for bijector in bijectors:
assert issubclass(bijector.cls, Bijector)

self.bijectors.append(bijector(shape=shape))
self.bijectors.append(bijector(shape=shape)) # type: ignore
shape = self.bijectors[-1].forward_shape(shape) # type: ignore

self.domain = self.bijectors[0].domain # type: ignore
self.codomain = self.bijectors[-1].codomain # type: ignore

self._context_shape = context_shape

def parameters(self) -> Iterator[torch.Tensor]:
for b in self.bijectors:
for param in b.parameters(): # type: ignore
yield param

def forward(
self,
x: torch.Tensor,
@@ -77,7 +74,7 @@ def inverse(
) -> torch.Tensor:
log_detJ: Optional[torch.Tensor] = None
y_temp = y
for bijector in reversed(self.bijectors):
for bijector in reversed(self.bijectors._modules.values()): # type: ignore
x = bijector.inverse(y_temp, context) # type: ignore
if is_record_flow_graph_enabled() and requires_log_detJ():
if isinstance(y_temp, BijectiveTensor) and y_temp.from_forward():
@@ -124,7 +121,16 @@ def log_abs_det_jacobian(
else:
_use_cached_inverse = False

for bijector in reversed(self.bijectors):
if (
is_record_flow_graph_enabled()
and not _use_cached_inverse
and not isinstance(y, BijectiveTensor)
):
warnings.warn(
"Computing _log_abs_det_jacobian from values and not from cache."
)

for bijector in reversed(self.bijectors._modules.values()): # type: ignore
if not _use_cached_inverse:
y_inv = bijector.inverse(y, context) # type: ignore
else:
79 changes: 79 additions & 0 deletions flowtorch/bijectors/invert.py
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc
from typing import Optional, Sequence

import flowtorch
import torch
from flowtorch.bijectors.base import Bijector


class Invert(Bijector):
"""
Lazily inverts a bijector by swapping the forward and inverse operations.
`Invert` flips a bijector such that forward calls inverse and inverse
calls forward. The log-abs-det-Jacobian is adjusted accordingly.
Args:
bijector (Bijector): layer to be inverted
Examples:
"""

def __init__(
self,
bijector: flowtorch.Lazy,
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None
) -> None:
b = bijector(shape=shape)
super().__init__(None, shape=shape, context_shape=context_shape)
self.bijector = b

def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
y = self.bijector.inverse(x, context=context) # type: ignore
return y

def inverse(
self,
y: torch.Tensor,
x: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if x is not None:
raise RuntimeError("x must be None when calling InverseBijector.inverse")
x = self.bijector.forward(y, context=context) # type: ignore
return x

def log_abs_det_jacobian(
self,
x: torch.Tensor,
y: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.bijector.log_abs_det_jacobian(y, x, context) # type: ignore

def param_shapes(self, shape: torch.Size) -> Sequence[torch.Size]:
return self.bijector.param_shapes(shape) # type: ignore

def __repr__(self) -> str:
return self.bijector.__repr__() # type: ignore

def forward_shape(self, shape: torch.Size) -> torch.Size:
"""
Infers the shape of the forward computation, given the input shape.
Defaults to preserving shape.
"""
return self.bijector.forward_shape(shape) # type: ignore

def inverse_shape(self, shape: torch.Size) -> torch.Size:
"""
Infers the shapes of the inverse computation, given the output shape.
Defaults to preserving shape.
"""
return self.bijector.inverse_shape(shape) # type: ignore
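
For comparison with the implementation above, a second sketch inverts a parameter-free bijector. It assumes `Exp` (already exported from flowtorch.bijectors) implements the elementwise exponential, so the wrapped layer acts like a log transform; the shape and input values are illustrative only.

import torch
import flowtorch.bijectors as bijectors

# Forward of the wrapper calls Exp's inverse, and vice versa.
log_like = bijectors.Invert(bijectors.Exp())(shape=torch.Size([3]))

x = torch.rand(4, 3) + 0.1    # keep inputs strictly positive for the log-like forward
y = log_like.forward(x)       # roughly torch.log(x), via Exp.inverse
x_back = log_like.inverse(y)  # roughly torch.exp(y), via Exp.forward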
46 changes: 46 additions & 0 deletions tests/test_bijector.py
@@ -1,10 +1,14 @@
# Copyright (c) Meta Platforms, Inc
import warnings

import flowtorch.bijectors as bijectors
import flowtorch.parameters as params
import numpy as np
import pytest
import torch
import torch.distributions as dist
import torch.optim
from flowtorch.bijectors import AffineAutoregressive, Compose
from flowtorch.distributions import Flow

"""
@@ -121,6 +125,48 @@ def test_inverse(flow, epsilon=1e-5):
assert (J_1 - J_2).abs().max().item() < epsilon


def test_invert():
# Define a simple bijector to invert
ar = Compose(
[
AffineAutoregressive(params.DenseAutoregressive()),
AffineAutoregressive(params.DenseAutoregressive()),
]
)
shape = torch.Size(
[
16,
]
)

# Instantiate the bijector and its inverse
bij = ar(shape=shape)
inv_bij = bijectors.Invert(ar)(shape=shape)

# Make parameters the same for both
inv_bij.load_state_dict(bij.state_dict(prefix="bijector."))

# Test if inversion is correct
x = torch.randn(50, 16, requires_grad=True)
torch.testing.assert_allclose(inv_bij.forward(x), bij.inverse(x))

y = inv_bij.forward(x)

# checks that no warning is displayed, which can happen if no cache is used
with warnings.catch_warnings():
warnings.simplefilter("error")
inv_bij.log_abs_det_jacobian(x, y)

with pytest.warns(UserWarning):
y_det = y.detach_from_flow()
inv_bij.log_abs_det_jacobian(x, y_det)

y = y.detach_from_flow()
torch.testing.assert_allclose(
inv_bij.log_abs_det_jacobian(x, y), bij.log_abs_det_jacobian(y, x)
)


"""
# TODO
def _test_shape(self, base_shape, transform):
