Skip to content

Commit

Permalink
Add fastgp to tfp.experimental.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 628478279
  • Loading branch information
ThomasColthurst authored and tensorflower-gardener committed Apr 26, 2024
1 parent f9ab36e commit 4e1d631
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 46 deletions.
2 changes: 2 additions & 0 deletions tensorflow_probability/python/experimental/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ multi_substrate_py_library(
"//tensorflow_probability/python/experimental/bijectors",
"//tensorflow_probability/python/experimental/distribute",
"//tensorflow_probability/python/experimental/distributions",
"//tensorflow_probability/python/experimental/fastgp",
"//tensorflow_probability/python/experimental/joint_distribution_layers",
"//tensorflow_probability/python/experimental/linalg",
"//tensorflow_probability/python/experimental/marginalize",
Expand All @@ -70,5 +71,6 @@ multi_substrate_py_library(
"//tensorflow_probability/python/experimental/vi",
"//tensorflow_probability/python/internal:all_util",
"//tensorflow_probability/python/internal:auto_composite_tensor",
"//tensorflow_probability/python/internal:lazy_loader",
],
)
7 changes: 7 additions & 0 deletions tensorflow_probability/python/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,15 @@
from tensorflow_probability.python.experimental.util.composite_tensor import as_composite
from tensorflow_probability.python.experimental.util.composite_tensor import register_composite
from tensorflow_probability.python.internal import all_util
from tensorflow_probability.python.internal import lazy_loader
from tensorflow_probability.python.internal.auto_composite_tensor import auto_composite_tensor
from tensorflow_probability.python.internal.auto_composite_tensor import AutoCompositeTensor

# TODO(thomaswc): Figure out why fastgp needs to be lazy_loaded.
globals()['fastgp'] = lazy_loader.LazyLoader(
'fastgp', globals(), 'tensorflow_probability.python.experimental.fastgp'
)


_allowed_symbols = [
'auto_batching',
Expand All @@ -63,6 +69,7 @@
'bijectors',
'distribute',
'distributions',
'fastgp',
'joint_distribution_layers',
'linalg',
'marginalize',
Expand Down
13 changes: 0 additions & 13 deletions tensorflow_probability/python/experimental/fastgp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ py_library(
name = "mbcg",
srcs = ["mbcg.py"],
deps = [
# jax dep,
],
)

Expand All @@ -84,7 +83,6 @@ py_library(
":fast_log_det",
":mbcg",
":preconditioners",
# jax dep,
"//tensorflow_probability/python/bijectors:softplus.jax",
"//tensorflow_probability/python/distributions:distribution.jax",
"//tensorflow_probability/python/distributions:gaussian_process_regression_model.jax",
Expand Down Expand Up @@ -121,7 +119,6 @@ py_library(
":linear_operator_sum",
":mbcg",
":preconditioners",
# jax dep,
"//tensorflow_probability/python/distributions:distribution.jax",
"//tensorflow_probability/python/distributions/internal:stochastic_process_util.jax",
"//tensorflow_probability/python/experimental/psd_kernels:multitask_kernel.jax",
Expand Down Expand Up @@ -153,7 +150,6 @@ py_library(
":fast_gp",
":preconditioners",
":schur_complement",
# jax dep,
"//tensorflow_probability/python/bijectors:softplus.jax",
"//tensorflow_probability/python/distributions/internal:stochastic_process_util.jax",
"//tensorflow_probability/python/internal:dtype_util.jax",
Expand All @@ -180,9 +176,7 @@ py_library(
srcs = ["linalg.py"],
deps = [
":partial_lanczos",
# jax dep,
# jax:experimental_sparse dep,
# jaxtyping dep,
"//tensorflow_probability/python/internal/backend/jax",
],
)
Expand All @@ -205,7 +199,6 @@ py_library(
srcs = ["partial_lanczos.py"],
deps = [
":mbcg",
# jax dep,
# scipy dep,
"//tensorflow_probability/python/internal/backend/jax",
],
Expand All @@ -231,8 +224,6 @@ py_library(
":mbcg",
":partial_lanczos",
":preconditioners",
# jax dep,
# jaxtyping dep,
# numpy dep,
# scipy dep,
],
Expand Down Expand Up @@ -266,9 +257,6 @@ py_library(
deps = [
":linalg",
":linear_operator_sum",
# jax dep,
# jax:experimental_sparse dep,
# jaxtyping dep,
"//tensorflow_probability/python/internal/backend/jax",
"//tensorflow_probability/python/math:linalg.jax",
],
Expand All @@ -292,7 +280,6 @@ py_library(
srcs = ["schur_complement.py"],
deps = [
":preconditioners",
# jax dep,
"//tensorflow_probability/python/bijectors:softplus.jax",
"//tensorflow_probability/python/internal:distribution_util.jax",
"//tensorflow_probability/python/internal:dtype_util.jax",
Expand Down
16 changes: 16 additions & 0 deletions tensorflow_probability/python/experimental/fastgp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,25 @@
from tensorflow_probability.python.experimental.fastgp import partial_lanczos
from tensorflow_probability.python.experimental.fastgp import preconditioners
from tensorflow_probability.python.experimental.fastgp import schur_complement
from tensorflow_probability.python.experimental.fastgp.fast_gp import GaussianProcess
from tensorflow_probability.python.experimental.fastgp.fast_gp import GaussianProcessConfig
from tensorflow_probability.python.experimental.fastgp.fast_gprm import GaussianProcessRegressionModel
from tensorflow_probability.python.experimental.fastgp.fast_log_det import get_log_det_algorithm
from tensorflow_probability.python.experimental.fastgp.fast_log_det import ProbeVectorType
from tensorflow_probability.python.experimental.fastgp.fast_mtgp import MultiTaskGaussianProcess
from tensorflow_probability.python.experimental.fastgp.preconditioners import get_preconditioner
from tensorflow_probability.python.experimental.fastgp.schur_complement import SchurComplement
from tensorflow_probability.python.internal import all_util

_allowed_symbols = [
'GaussianProcessConfig',
'GaussianProcess',
'GaussianProcessRegressionModel',
'ProbeVectorType',
'get_log_det_algorithm',
'MultiTaskGaussianProcess',
'get_preconditioner',
'SchurComplement',
'fast_log_det',
'fast_gp',
'fast_gprm',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class GaussianProcessConfig:
probe_vector_type: str = 'rademacher'
# The number of probe vectors to use when estimating the log det.
num_probe_vectors: int = 35
# One of 'slq' (for stochastic Lancos quadrature) or
# One of 'slq' (for stochastic Lanczos quadrature) or
# 'r1', 'r2', 'r3', 'r4', 'r5', or 'r6' for the rational function
# approximation of the given order.
log_det_algorithm: str = 'r3'
Expand Down
33 changes: 16 additions & 17 deletions tensorflow_probability/python/experimental/fastgp/fast_log_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import jax
import jax.numpy as jnp
from jaxtyping import Float
import numpy as np
from tensorflow_probability.python.experimental.fastgp import mbcg
from tensorflow_probability.python.experimental.fastgp import partial_lanczos
Expand Down Expand Up @@ -159,7 +158,7 @@ def _log_det_rational_approx_with_hutchinson(
probe_vectors: Array,
key: jax.Array,
num_iters: int,
) -> Float:
):
"""Approximate log det using a rational function.
We calculate log det M as the trace of log M, and we approximate the
Expand Down Expand Up @@ -295,7 +294,7 @@ def _r1(
probe_vectors: Array,
key: jax.Array,
num_iters: int,
) -> Float:
):
"""Approximate log det using a 1st order rational function."""
return _log_det_rational_approx_with_hutchinson(
jnp.asarray(R1_SHIFTS, dtype=probe_vectors.dtype),
Expand All @@ -315,7 +314,7 @@ def _r2(
probe_vectors: Array,
key: jax.Array,
num_iters: int,
) -> Float:
):
"""Approximate log det using a 2nd order rational function."""
return _log_det_rational_approx_with_hutchinson(
jnp.asarray(R2_SHIFTS, dtype=probe_vectors.dtype),
Expand All @@ -335,7 +334,7 @@ def _r3(
probe_vectors: Array,
key: jax.Array,
num_iters: int,
) -> Float:
):
"""Approximate log det using a 4th order rational function."""
return _log_det_rational_approx_with_hutchinson(
jnp.asarray(R3_SHIFTS, dtype=probe_vectors.dtype),
Expand All @@ -355,7 +354,7 @@ def _r4(
probe_vectors: Array,
key: jax.Array,
num_iters: int,
) -> Float:
):
"""Approximate log det using a 4th order rational function."""
return _log_det_rational_approx_with_hutchinson(
jnp.asarray(R4_SHIFTS, dtype=probe_vectors.dtype),
Expand All @@ -375,7 +374,7 @@ def _r5(
probe_vectors: Array,
key: jax.Array,
num_iters: int,
) -> Float:
):
"""Approximate log det using a 4th order rational function."""
return _log_det_rational_approx_with_hutchinson(
jnp.asarray(R5_SHIFTS, dtype=probe_vectors.dtype),
Expand All @@ -395,7 +394,7 @@ def _r6(
probe_vectors: Array,
key: jax.Array,
num_iters: int,
) -> Float:
):
"""Approximate log det using a 4th order rational function."""
return _log_det_rational_approx_with_hutchinson(
jnp.asarray(R6_SHIFTS, dtype=probe_vectors.dtype),
Expand Down Expand Up @@ -453,7 +452,7 @@ def r1(
probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER,
num_iters: int = 20,
**unused_kwargs,
) -> Float:
):
"""Approximate log det using a 2nd order rational function."""
n = M.shape[-1]
key1, key2 = jax.random.split(key)
Expand All @@ -473,7 +472,7 @@ def r2(
probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER,
num_iters: int = 20,
**unused_kwargs,
) -> Float:
):
"""Approximate log det using a 2nd order rational function."""
n = M.shape[-1]
key1, key2 = jax.random.split(key)
Expand All @@ -493,7 +492,7 @@ def r3(
probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER,
num_iters: int = 20,
**unused_kwargs,
) -> Float:
):
"""Approximate log det using a 3rd order rational function."""
n = M.shape[-1]
key1, key2 = jax.random.split(key)
Expand All @@ -513,7 +512,7 @@ def r4(
probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER,
num_iters: int = 20,
**unused_kwargs,
) -> Float:
):
"""Approximate log det using a 4th order rational function."""
n = M.shape[-1]
key1, key2 = jax.random.split(key)
Expand All @@ -533,7 +532,7 @@ def r5(
probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER,
num_iters: int = 20,
**unused_kwargs,
) -> Float:
):
"""Approximate log det using a 5th order rational function."""
n = M.shape[-1]
key1, key2 = jax.random.split(key)
Expand All @@ -553,7 +552,7 @@ def r6(
probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER,
num_iters: int = 20,
**unused_kwargs,
) -> Float:
):
"""Approximate log det using a 6th order rational function."""
n = M.shape[-1]
key1, key2 = jax.random.split(key)
Expand Down Expand Up @@ -597,7 +596,7 @@ def _stochastic_lanczos_quadrature_log_det(
unused_key,
probe_vectors_are_rademacher: bool,
num_iters: int,
) -> Float:
):
"""Fast log det using the alg. from https://arxiv.org/pdf/1809.11165.pdf ."""
n = M.shape[-1]

Expand Down Expand Up @@ -639,7 +638,7 @@ def stochastic_lanczos_quadrature_log_det(
probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER,
num_iters: int = 25,
**unused_kwargs,
) -> Float:
):
"""Fast log det using the alg. from https://arxiv.org/pdf/1809.11165.pdf ."""
n = M.shape[-1]
num_iters = min(n, num_iters)
Expand Down Expand Up @@ -703,7 +702,7 @@ def log_det_taylor_series_with_hutchinson(
num_probe_vectors: int,
key: jax.Array,
num_taylor_series_iterations: int = 10,
) -> Float:
):
"""Return an approximation of log det M."""
# TODO(thomaswc): Consider having this support a batch of LinearOperators.
n = M.shape[0]
Expand Down
5 changes: 2 additions & 3 deletions tensorflow_probability/python/experimental/fastgp/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import jax
import jax.experimental.sparse
import jax.numpy as jnp
from jaxtyping import Float
import numpy as np
from tensorflow_probability.python.experimental.fastgp import partial_lanczos
from tensorflow_probability.python.internal.backend import jax as tf2jax
Expand All @@ -36,7 +35,7 @@ def _matvec(M, x) -> jax.Array:

def largest_eigenvector(
M: tf2jax.linalg.LinearOperator, key: jax.Array, num_iters: int = 10
) -> tuple[Float, Array]:
):
"""Returns the largest (eigenvalue, eigenvector) of M."""
n = M.shape[-1]
v = jax.random.uniform(key, shape=(n,), dtype=M.dtype)
Expand All @@ -55,7 +54,7 @@ def make_randomized_truncated_svd(
rank: int = 20,
oversampling: int = 10,
num_iters: int = 4,
) -> tuple[Float, Array]:
):
"""Returns approximate SVD for symmetric `M`."""
# This is based on:
# N. Halko, P.G. Martinsson, J. A. Tropp
Expand Down

0 comments on commit 4e1d631

Please sign in to comment.