
[Bug] Bug in GP Regression with KeOps Kernels #2480

Open
mrozkamil opened this issue Feb 23, 2024 · 1 comment
@mrozkamil

🐛 Bug

ExactMarginalLogLikelihood fails due to an incompatibility between KeOps LazyTensors and dense torch Tensors.
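
Note that two unrelated classes are both called LazyTensor here: pykeops.torch.LazyTensor (the symbolic KeOps tensor returned by the KeOps kernels) and the dense-backed operators that linear_operator knows how to convert. A minimal sketch of the mismatch, importing to_dense from the module named in the traceback below (runs on CPU; no KeOps reduction is triggered):

import torch
from pykeops.torch import LazyTensor
# to_dense lives in the module that raises in the traceback below
from linear_operator.operators._linear_operator import to_dense

x = torch.randn(5, 3)
x_i, y_j = LazyTensor(x[:, None, :]), LazyTensor(x[None, :, :])
K = (-((x_i - y_j) ** 2).sum(-1) / 2).exp()  # symbolic KeOps kernel matrix

to_dense(torch.eye(5))  # fine: already a dense torch.Tensor
to_dense(K)             # TypeError: object of class LazyTensor cannot be made into a Tensor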

To reproduce

Follow the regression task provided at https://docs.gpytorch.ai/en/v1.11/examples/02_Scalable_Exact_GPs/KeOps_GP_Regression.html
**Code snippet to reproduce**

import math
import torch
import gpytorch
from matplotlib import pyplot as plt

# Checks if GPUs are available for PyTorch to use.

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# Prints which device PyTorch will use.
print(f"Using {device} device")
n_devices = torch.cuda.device_count()
print('Found {} GPUs.'.format(n_devices))


import urllib.request
import os.path
from scipy.io import loadmat
from math import floor

if not os.path.isfile('../3droad.mat'):
    print('Downloading \'3droad\' UCI dataset...')
    urllib.request.urlretrieve('https://www.dropbox.com/s/f6ow1i59oqx05pl/3droad.mat?dl=1', '../3droad.mat')

data = torch.Tensor(loadmat('../3droad.mat')['data'])

import numpy as np

N = data.shape[0]
# make train/val/test
n_train = int(0.8 * N)
train_x, train_y = data[:n_train, :-1], data[:n_train, -1]
test_x, test_y = data[n_train:, :-1], data[n_train:, -1]

# normalize features
mean = train_x.mean(dim=-2, keepdim=True)
std = train_x.std(dim=-2, keepdim=True) + 1e-6 # prevent dividing by 0
train_x = (train_x - mean) / std
test_x = (test_x - mean) / std

# normalize labels
mean, std = train_y.mean(),train_y.std()
train_y = (train_y - mean) / std
test_y = (test_y - mean) / std

# make contiguous
train_x, train_y = train_x.contiguous(), train_y.contiguous()
test_x, test_y = test_x.contiguous(), test_y.contiguous()

output_device = torch.device('cuda:0')

train_x, train_y = train_x.to(output_device), train_y.to(output_device)
test_x, test_y = test_x.to(output_device), test_y.to(output_device)


# We will use the simplest form of GP model, exact inference
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()

        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.keops.MaternKernel(nu=2.5))

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()
model = ExactGPModel(train_x, train_y, likelihood).cuda()


# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

import time
training_iter = 50
for i in range(training_iter):
    start_time = time.time()
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(train_x)
    # Calc loss and backprop gradients
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
        i + 1, training_iter, loss.item(),
        model.covar_module.base_kernel.lengthscale.item(),
        model.likelihood.noise.item()
    ))
    optimizer.step()
    print(time.time() - start_time)


# Get into evaluation (predictive posterior) mode
model.eval()
likelihood.eval()

# Make predictions on the held-out test set by feeding the model output through the likelihood
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    observed_pred = likelihood(model(test_x))

# Test-set RMSE
torch.sqrt(torch.mean(torch.pow(observed_pred.mean - test_y, 2)))

Stack trace/error message

Using cuda device
Found 1 GPUs.
Traceback (most recent call last):
  File "/home/users/km357/Scripts/Python3/MachineLearning/pyTorch/GP_scattering/test_KeOps_GPyTorch.py", line 99, in <module>
    loss = -mll(output, train_y)
            ^^^^^^^^^^^^^^^^^^^^
  File "/home/users/km357/miniforge3/envs/pytorch/lib/python3.12/site-packages/gpytorch/module.py", line 31, in __call__
    outputs = self.forward(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/km357/miniforge3/envs/pytorch/lib/python3.12/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py", line 64, in forward
    res = output.log_prob(target)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/km357/miniforge3/envs/pytorch/lib/python3.12/site-packages/gpytorch/distributions/multivariate_normal.py", line 193, in log_prob
    inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/km357/miniforge3/envs/pytorch/lib/python3.12/site-packages/linear_operator/operators/_linear_operator.py", line 1748, in inv_quad_logdet
    preconditioner, precond_lt, logdet_p = self._preconditioner()
                                           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/km357/miniforge3/envs/pytorch/lib/python3.12/site-packages/linear_operator/operators/added_diag_linear_operator.py", line 126, in _preconditioner
    self._piv_chol_self = self._linear_op.pivoted_cholesky(rank=max_iter)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/km357/miniforge3/envs/pytorch/lib/python3.12/site-packages/linear_operator/operators/_linear_operator.py", line 1965, in pivoted_cholesky
    res, pivots = func(self.representation_tree(), rank, error_tol, *self.representation())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/km357/miniforge3/envs/pytorch/lib/python3.12/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/km357/miniforge3/envs/pytorch/lib/python3.12/site-packages/linear_operator/functions/_pivoted_cholesky.py", line 24, in forward
    matrix_diag = matrix._approx_diagonal()
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/km357/miniforge3/envs/pytorch/lib/python3.12/site-packages/linear_operator/operators/constant_mul_linear_operator.py", line 74, in _approx_diagonal
    res = self.base_linear_op._approx_diagonal()
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/km357/miniforge3/envs/pytorch/lib/python3.12/site-packages/linear_operator/operators/_linear_operator.py", line 492, in _approx_diagonal
    return self._diagonal()
           ^^^^^^^^^^^^^^^^
  File "/home/users/km357/miniforge3/envs/pytorch/lib/python3.12/site-packages/linear_operator/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/km357/miniforge3/envs/pytorch/lib/python3.12/site-packages/linear_operator/operators/kernel_linear_operator.py", line 233, in _diagonal
    diag_mat = to_dense(self.covar_func(x1, x2, **tensor_params, **self.nontensor_params))
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/km357/miniforge3/envs/pytorch/lib/python3.12/site-packages/linear_operator/operators/_linear_operator.py", line 2987, in to_dense
    raise TypeError("object of class {} cannot be made into a Tensor".format(obj.__class__.__name__))
TypeError: object of class LazyTensor cannot be made into a Tensor

Expected Behavior

The ExactMarginalLogLikelihood computation should complete successfully.

System information

gpytorch 1.11 0 gpytorch
pytorch 2.2.0 py3.12_cuda12.1_cudnn8.9.2_0 pytorch
pytorch-cuda 12.1 ha16c6d3_5 pytorch
pytorch-mutex 1.0 cuda pytorch
keopscore 2.2.2 pypi_0 pypi
cuda-cudart 12.1.105 0 nvidia
cuda-cupti 12.1.105 0 nvidia
cuda-libraries 12.1.0 0 nvidia
cuda-nvrtc 12.1.105 0 nvidia
cuda-nvtx 12.1.105 0 nvidia
cuda-opencl 12.3.101 0 nvidia
cuda-runtime 12.1.0 0 nvidia

Operating System: CentOS Linux 7 (Core)
Kernel: Linux 3.10.0-1160.105.1.el7.x86_64
Architecture: x86-64

Additional context

The regression task given at https://www.kernel-operations.io/keops/_auto_tutorials/backends/plot_gpytorch.html works well.

#####################################################################
# Setup
# -----------------
# Standard imports, including `gpytorch <https://gpytorch.ai/>`_:

import gpytorch
import math
import torch
from matplotlib import pyplot as plt

use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

#####################################################################
# We generate a toy dataset: some regularly spaced samples on the unit interval,
# and a sinusoid signal corrupted by a small Gaussian noise.

N = 1000 if use_cuda else 100
train_x = torch.linspace(0, 1, N).type(dtype)
train_y = torch.sin(train_x * (2 * math.pi)) + 0.2 * torch.randn(train_x.size()).type(
    dtype
)

#####################################################################
# Defining a new KeOps RBF kernel
# ---------------------------------
#
# Internally, GPytorch relies on `LazyTensors  <https://gpytorch.readthedocs.io/en/latest/lazy.html>`_
# parameterized by explicit **torch Tensors** - and **nothing** else.
# To let GPytorch use our KeOps CUDA routines, we should thus create
# a new class of :mod:`gpytorch.lazy.LazyTensor`, encoding an implicit
# kernel matrix built from raw point clouds **x_i** and **y_j**.
#
# .. note::
#   Ideally, we'd like to be able to **export KeOps LazyTensors** directly as
#   GPytorch objects, but the reliance of the latter's internal engine on
#   explicit **torch.Tensor** variables is a hurdle that we could not bypass
#   easily. Working on this problem with the GPytorch team,
#   we hope to provide a simpler interface in future releases.


from pykeops.torch import LazyTensor


class KeOpsRBFLazyTensor(gpytorch.lazy.LazyTensor):
    def __init__(self, x_i, y_j):
        """Creates a symbolic Gaussian RBF kernel out of two point clouds `x_i` and `y_j`."""
        super().__init__(
            x_i, y_j
        )  # GPytorch will remember that self was built from x_i and y_j

        self.x_i, self.y_j = x_i, y_j  # Useful to define a symbolic transpose

        with torch.autograd.enable_grad():  # N.B.: gpytorch operates in no_grad mode
            x_i, y_j = (
                LazyTensor(self.x_i[:, None, :]),
                LazyTensor(self.y_j[None, :, :]),
            )
            K_xy = (
                -((x_i - y_j) ** 2).sum(-1) / 2
            ).exp()  # Compute the kernel matrix symbolically...

        self.K = K_xy  # ... and store it for later use

    def _matmul(self, M):
        """Kernel-Matrix multiplication."""
        return self.K @ M

    def _size(self):
        """Shape attribute."""
        return torch.Size(self.K.shape)

    def _transpose_nonbatch(self):
        """Symbolic transpose operation."""
        return KeOpsRBFLazyTensor(self.y_j, self.x_i)

    def _get_indices(self, row_index, col_index, *batch_indices):
        """Returns a (small) explicit sub-matrix, used e.g. for Nystroem approximation."""
        X_i = self.x_i[row_index]
        Y_j = self.y_j[col_index]
        return (-((X_i - Y_j) ** 2).sum(-1) / 2).exp()  # Genuine torch.Tensor

    def _quad_form_derivative(self, *args, **kwargs):
        """As of gpytorch v0.3.2, the default implementation returns a list instead of a tuple..."""
        return tuple(super()._quad_form_derivative(*args, **kwargs))  # Bugfix!


#####################################################################
# We can now create a new GPytorch **Kernel** object, wrapped around
# our KeOps+GPytorch LazyTensor:


class KeOpsRBFKernel(gpytorch.kernels.Kernel):
    """Simple KeOps re-implementation of 'gpytorch.kernels.RBFKernel'."""

    has_lengthscale = True

    def forward(self, x1, x2, diag=False, **params):
        if diag:  # A Gaussian RBF kernel only has "ones" on the diagonal
            return torch.ones(len(x1)).type_as(x1)
        else:
            if x1.dim() == 1:
                x1 = x1.view(-1, 1)
            if x2.dim() == 1:
                x2 = x2.view(-1, 1)
            # Rescale the input data...
            x_i, y_j = x1.div(self.lengthscale), x2.div(self.lengthscale)
            return KeOpsRBFLazyTensor(
                x_i, y_j
        )  # ... and return it as a gpytorch.lazy.LazyTensor


#####################################################################
# And use it to define a new Gaussian Process model:


# We will use the simplest form of GP model, exact inference
class KeOpsGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(KeOpsRBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


##########################################################
# **N.B., for the sake of comparison:** the GPytorch documentation went with
# the code below, using the standard :meth:`gpytorch.kernels.RBFKernel()`
# instead of our custom :meth:`KeOpsRBFKernel()`:


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


##########################################################
# **That's it!** We can now initialize our likelihood and model, as recommended by the documentation:

if use_cuda:
    likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()
    model = KeOpsGPModel(train_x, train_y, likelihood).cuda()
else:
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = KeOpsGPModel(train_x, train_y, likelihood)

#####################################################################
# GP training
# -----------------
# The code below is now a direct copy-paste from the
# `GPytorch 101 tutorial <https://docs.gpytorch.ai/en/v1.1.1/examples/01_Exact_GPs/Simple_GP_Regression.html>`_:

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(
    [
        {"params": model.parameters()},
    ],
    lr=0.1,  # Includes GaussianLikelihood parameters
)

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

training_iter = 50
for i in range(training_iter):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(train_x)
    # Calc loss and backprop gradients
    loss = -mll(output, train_y)
    loss.backward()
    if i % 10 == 0 or i == training_iter - 1:
        print(
            "Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f"
            % (
                i + 1,
                training_iter,
                loss.item(),
                model.covar_module.base_kernel.lengthscale.item(),
                model.likelihood.noise.item(),
            )
        )
    optimizer.step()

#####################################################################
# Prediction and display
# -------------------------
# Get into evaluation (predictive posterior) mode
#

model.eval()
likelihood.eval()

#####################################################################
# Test points are regularly spaced along [0,1].
# We make predictions by feeding our ``model`` through the ``likelihood``:

with torch.no_grad(), gpytorch.settings.fast_pred_var():
    test_x = torch.linspace(0, 1, 51).type(dtype)
    observed_pred = likelihood(model(test_x))

#####################################################################
# Display:
#

with torch.no_grad():
    # Initialize plot
    f, ax = plt.subplots(1, 1, figsize=(12, 9))

    # Get upper and lower confidence bounds
    lower, upper = observed_pred.confidence_region()
    # Plot training data as black stars
    ax.plot(train_x.cpu().numpy(), train_y.cpu().numpy(), "k*")
    # Plot predictive means as blue line
    ax.plot(test_x.cpu().numpy(), observed_pred.mean.cpu().numpy(), "b")
    # Shade between the lower and upper confidence bounds
    ax.fill_between(
        test_x.cpu().numpy(), lower.cpu().numpy(), upper.cpu().numpy(), alpha=0.5
    )
    ax.set_ylim([-3, 3])
    ax.legend(["Observed Data", "Mean", "Confidence"])

plt.axis([0, 1, -2, 2])
plt.tight_layout()
plt.show()

Output

[KeOps] Generating code for Sum_Reduction reduction (with parameters 0) of formula Exp(-1/2*(a-b)**2)*c with a=Var(0,1,0), b=Var(1,1,1), c=Var(2,11,1) ... OK
[KeOps] Generating code for Sum_Reduction reduction (with parameters 0) of formula -(((d|c)*(a-b))*Exp(-1/2*(a-b)**2)) with a=Var(0,1,0), b=Var(1,1,1), c=Var(2,11,1), d=Var(3,11,0) ... OK
[KeOps] Generating code for Sum_Reduction reduction (with parameters 1) of formula ((d|c)*(a-b))*Exp(-1/2*(a-b)**2) with a=Var(0,1,0), b=Var(1,1,1), c=Var(2,11,1), d=Var(3,11,0) ... OK
Iter 1/50 - Loss: 0.858   lengthscale: 0.693   noise: 0.693
Iter 11/50 - Loss: 0.412   lengthscale: 0.335   noise: 0.311
Iter 21/50 - Loss: 0.051   lengthscale: 0.234   noise: 0.124
Iter 31/50 - Loss: -0.168   lengthscale: 0.219   noise: 0.051
Iter 41/50 - Loss: -0.163   lengthscale: 0.243   noise: 0.030
Iter 50/50 - Loss: -0.176   lengthscale: 0.284   noise: 0.032
[KeOps] Generating code for Sum_Reduction reduction (with parameters 0) of formula c*Exp(-1/2*(a-b)**2) with a=Var(0,1,0), b=Var(1,1,1), c=Var(2,1,1) ... OK
[KeOps] Generating code for Sum_Reduction reduction (with parameters 0) of formula Exp(-1/2*(a-b)**2)*c with a=Var(0,1,0), b=Var(1,1,1), c=Var(2,100,1) ... OK
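
Possibly useful for triage: the failure happens while linear_operator builds the pivoted-Cholesky preconditioner for the noise-added covariance (added_diag_linear_operator._preconditioner in the traceback). A sketch of an unverified workaround, assuming the rest of the MLL evaluation handles the KeOps operator correctly, would be to skip that code path through gpytorch's settings:

# Unverified assumption: disabling the preconditioner avoids the failing
# pivoted_cholesky / _diagonal path from the traceback above.
with gpytorch.settings.max_preconditioner_size(0):
    output = model(train_x)
    loss = -mll(output, train_y)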
mrozkamil added the bug label on Feb 23, 2024

aj8907 commented Mar 20, 2024

Reverting to version 1.10 fixed this problem for me.

pip install gpytorch==1.10

I ran this example on Google Colab (V100, high-RAM runtime).
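
For reference, a quick way to confirm which version is actually imported after the downgrade (plain Python, nothing GPyTorch-specific):

import gpytorch
print(gpytorch.__version__)  # expect something like '1.10' after the downgrade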
