[Bug] Erroneous detaching with (custom?) mean #2521

Open
villetan opened this issue May 2, 2024 · 0 comments
villetan commented May 2, 2024

🐛 Bug

Hi,

First of all, thank you for developing such a versatile and efficient library! I suspect that I came across a bug when working with BoTorch, but I believe it is originating from GPyTorch.

When working with a custom mean function $m_\theta$ that depends on parameters $\theta$ we wish to optimize (e.g., a neural-network feature extractor), the derivative of the predictive mean with respect to $\theta$ is wrong.

The predictive mean $\mu(\cdot)$ (given a custom mean function $m_\theta$) at $x^*$ is given by

$$ \mu(x^* \mid X, y) = m_\theta(x^*) + K_{*}^T(K + \sigma I)^{-1} \big(y - m_\theta(X)\big)\text{.} $$

An easy test for the derivative is to predict at the observed data points $X$. When the observational noise is small ($\sigma \approx 0$), we have $K_{*}^T(K + \sigma I)^{-1} \approx I$, so

$$ \mu(X \mid X, y) \approx m_\theta(X) + \big(y - m_\theta(X)\big) = y\text{,} $$

whose derivative w.r.t. the mean module's parameters $\theta$ should be zero.
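
Spelled out, the dependence on $\theta$ cancels, so the gradient should vanish:

$$ \frac{\partial}{\partial \theta}\, \mu(X \mid X, y) \approx \frac{\partial}{\partial \theta}\Big[ m_\theta(X) + \big(y - m_\theta(X)\big) \Big] = \frac{\partial m_\theta(X)}{\partial \theta} - \frac{\partial m_\theta(X)}{\partial \theta} = 0\text{.} $$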

At least in the specific case below this does not happen, and the issue appears to be related to an incorrect detach somewhere in the prediction path (the detach_bug flag in the snippet marks a hypothetical location that reproduces the observed behavior).

To reproduce

**Code snippet to reproduce**

from botorch.models import SingleTaskGP
import torch
torch.manual_seed(123)

#define a model for the mean
class LinearModel(torch.nn.Module):
    def __init__(self, D=1, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.beta = torch.nn.Parameter(torch.randn(1,D))

    def forward(self, X):
        return (X * self.beta).sum(-1)

lm = LinearModel(1)
opt = torch.optim.Adam(lm.parameters(), lr = 0.424242)#for zeroing grads

#generate some data
N=50
x_data = torch.linspace(0, 1, N).view(-1, 1)
y_data = 2*torch.sin(10*x_data)  + 0.01*torch.randn_like(x_data)


#gp prediction manually
def gp_pred(xstar, obs_X, obs_Y, prior_mean, gp, detach_bug=False):
    samples_prior_mean = prior_mean(xstar)
    obs_prior_mean = prior_mean(obs_X)
    gp_y = obs_Y - obs_prior_mean.unsqueeze(-1)

    #gp pred
    K = gp.covar_module(obs_X, obs_X).to_dense().detach() 
    likelihood_additive_noise = gp.likelihood.noise_covar.raw_noise.detach()
    KplusNoise = K + likelihood_additive_noise * torch.eye(K.shape[0])
    Kstar = gp.covar_module(xstar, obs_X).to_dense().detach()
    KpNinv = torch.linalg.inv(KplusNoise)
    #pred_mean = Kstar @ KpNinv @ gp_y
    pred_mean = Kstar @ torch.linalg.solve(KplusNoise, gp_y)
    if detach_bug: #the bug is here
        pred_mean = pred_mean.detach()
    pred_mean_og_scale = pred_mean.squeeze() + samples_prior_mean
    
    pred_var = gp.covar_module.outputscale.detach() - (Kstar @ KpNinv @ Kstar.T).diag()
    return pred_mean_og_scale, pred_var


#define the Gpytorch model
import gpytorch
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, mean_module, covar_module):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = mean_module
        self.covar_module = covar_module

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
    

#botorch model
gp_botorch = SingleTaskGP(x_data, y_data, mean_module=lm)
gp_botorch.likelihood.noise_covar.noise = 0.0001 #to simulate near noiseless prediction
gp_botorch.eval()
gp_botorch.likelihood.eval()

#gpytorch model
gp_gpytorch = ExactGPModel(x_data, y_data.squeeze(), gp_botorch.likelihood, lm, gp_botorch.covar_module)
gp_gpytorch.eval()

#test that manual, botorch and gpytorch actually output same predictions
x_star = torch.linspace(0, 1, 100).view(-1, 1)
pred_mean, pred_var = gp_pred(x_star, x_data, y_data, lm, gp_botorch)
pred_botorch = gp_botorch(x_star)
pred_gpytorch = gp_gpytorch(x_star)
#check botorch and gpytorch are equal
print("botorch vs. gpytorch")
print(torch.abs(pred_botorch.mean - pred_gpytorch.mean).max())
print(torch.abs(pred_botorch.variance - pred_gpytorch.variance).max())
#check predictions botorch and manual are approximately equal
print("botorch vs manual")
print(torch.abs(pred_mean - gp_botorch(x_star).mean).max())
print(torch.abs(pred_var - gp_botorch(x_star).variance).max())

#plotting
# import matplotlib.pyplot as plt
# plt.scatter(x_data, y_data, label="data")
# plt.plot(x_star, pred_mean.detach(), label="manual")
# plt.plot(x_star, pred_botorch.mean.detach(), label="botorch")
# plt.plot(x_star, pred_gpytorch.mean.detach(), label="gpytorch")
# plt.legend()
# plt.show()


#prediction at the training points
opt.zero_grad()
pred_mean, pred_var = gp_pred(x_data, x_data, y_data, lm, gp_botorch)
pred_mean.mean().backward()
print(lm.beta.grad) #correct: ≈ 0

opt.zero_grad()
pred_mean, pred_var = gp_pred(x_data, x_data, y_data, lm, gp_botorch, detach_bug=True)
pred_mean.mean().backward()
print(lm.beta.grad) #incorrect: ≠ 0

opt.zero_grad()
pred_botorch = gp_botorch(x_data)
pred_botorch.mean.mean().backward()
print(lm.beta.grad) #incorrect: ≠ 0

opt.zero_grad()
pred_gpytorch = gp_gpytorch(x_data)
pred_gpytorch.mean.mean().backward()
print(lm.beta.grad) #incorrect: ≠ 0

The last three predictions agree in both outputs and gradients (which I believe are incorrect), while the first one produces the same outputs but the correct (zero) gradient.
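
As a further diagnostic, it may help to repeat the gradient check with GPyTorch's test-cache detaching disabled. This is only a sketch: I am assuming that gpytorch.settings.detach_test_caches is the setting controlling whether the exact-GP prediction caches are detached, and the model is re-created so that the caches are rebuilt while the setting is active.

opt.zero_grad()
#assumption: detach_test_caches controls the detaching of the exact-GP prediction caches;
#the model is re-created so the caches are rebuilt under this setting
gp_check = ExactGPModel(x_data, y_data.squeeze(), gp_botorch.likelihood, lm, gp_botorch.covar_module)
gp_check.eval()
with gpytorch.settings.detach_test_caches(False):
    pred_check = gp_check(x_data)
    pred_check.mean.mean().backward()
print(lm.beta.grad) #should be ≈ 0 if the detached cache is the culprit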

Expected Behavior

The gradient of the predictive mean w.r.t. the mean module's parameters, evaluated at the observation locations $X$, should be $0$. See the line marked #correct in the snippet above.
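
A simple check of this expectation, using the names from the snippet above (the tolerance is an arbitrary choice):

#after the first (non-detached) gradient check, the mean module's gradient should be numerically zero
assert torch.allclose(lm.beta.grad, torch.zeros_like(lm.beta.grad), atol=1e-4)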

System information


  • GPyTorch 1.11
  • PyTorch 2.1.0
  • macOS Sonoma 14.4.1

I could not locate the bug in the GPyTorch code myself, but hopefully this report helps you track it down.

@villetan villetan added the bug label May 2, 2024
@villetan villetan changed the title [Bug] Erroneous detaching with custom(?) mean [Bug] Erroneous detaching with (custom?) mean May 2, 2024