predict_y and predict_log_density support for full_cov or full_output_cov #1461

Open
mohitrajpal1 opened this issue May 10, 2020 · 7 comments · May be fixed by #1597

Comments

mohitrajpal1 commented May 10, 2020

With the latest version of GPflow (2.0.2), calling predict_y on an SVGP model with full_cov=True returns a covariance matrix in which the likelihood noise variance is added to every element, rather than only to the diagonal as expected.

From looking at the code, this happens because SVGP does not implement predict_y but inherits it from the base class, and the base-class implementation does not appear to handle full_cov=True correctly.

To reproduce

import gpflow
import numpy as np

rng = np.random.RandomState(123)
N = 100  # Number of training observations
X = rng.rand(N, 1) * 2 - 1  # X values
M = 50  # Number of inducing locations
kernel = gpflow.kernels.SquaredExponential()
Z = X[:M, :].copy()  # Initialize inducing locations to the first M inputs in the dataset
m = gpflow.models.SVGP(kernel, gpflow.likelihoods.Gaussian(), Z, num_data=N)

pX = np.linspace(-1, 1, 100)[:, None]  # Test locations
_, pYv = m.predict_y(pX, full_cov=True)  # Covariance of Y at test locations, shape [1, 100, 100]
_, pFv = m.predict_f(pX, full_cov=True)  # Covariance of f at test locations, shape [1, 100, 100]

# Mask that zeroes the diagonal; it broadcasts over the leading output dimension
off_diag_mask = 1.0 - np.eye(pX.shape[0])
pYv_offdiag = pYv.numpy() * off_diag_mask
pFv_offdiag = pFv.numpy() * off_diag_mask

print(np.linalg.norm(pYv_offdiag - pFv_offdiag))  # should be 0

Expected behavior

The off-diagonal elements of the predict_y covariance are expected to match those of predict_f, but they do not. The last line of the snippet above should print 0.

System information

  • GPflow version: 2.0.2
  • GPflow installed from: pip
  • TensorFlow version: 2.1
  • Python version: 3.6
  • Operating system: Ubuntu 16.04

vdutor (Contributor) commented May 11, 2020

Hi @mohitrajpal1

Thanks for your excellent bug report. You are completely right: the likelihood variance is incorrectly added to all elements of the covariance matrix, whereas it should only be added to the diagonal. This is caused, for example, by unintended broadcasting in the _predict_mean_and_var method of the Gaussian likelihood.
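The broadcasting issue can be reproduced with plain NumPy. In this minimal sketch, Fvar is a made-up full covariance of shape [P, N, N] with P=1 and N=2, and noise_variance is an arbitrary value; neither comes from GPflow.

```python
import numpy as np

Fvar = np.array([[[1.0, 0.5],
                  [0.5, 1.0]]])  # full covariance of f, shape [P=1, N=2, N=2]
noise_variance = 0.1

# Adding a scalar broadcasts to every entry, including the off-diagonals
Yvar_broadcast = Fvar + noise_variance
# Adding a scaled identity changes only the diagonal
Yvar_diag = Fvar + noise_variance * np.eye(Fvar.shape[-1])

print(Yvar_broadcast[0])  # off-diagonals have also been shifted by the noise
print(Yvar_diag[0])       # off-diagonals are unchanged
```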

As GPflow is an open-source project, we would very much appreciate it if you could help us fix this bug by submitting a PR. My initial plan would be to pass a keyword argument full_cov: bool = False to predict_mean_and_var, which would allow the likelihood variance to be added correctly depending on its value. If you start a PR, I'm happy to review it and make sure it gets merged. Would you be up for that?

Feel free to join our Slack workspace (see the README for details on how to join) if you want to discuss this in more detail.
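A minimal sketch of what the proposed keyword could look like, assuming a single scalar noise variance. GaussianLikelihoodSketch is a hypothetical stand-in written with NumPy, not GPflow's Gaussian likelihood, and the shapes assumed for Fmu and Fvar are noted in the comments.

```python
import numpy as np


class GaussianLikelihoodSketch:
    """Hypothetical stand-in for a Gaussian likelihood with a scalar noise variance.

    Not GPflow's class: it only illustrates a full_cov keyword on
    predict_mean_and_var.
    """

    def __init__(self, variance=1.0):
        self.variance = variance

    def predict_mean_and_var(self, Fmu, Fvar, full_cov=False):
        # Fmu has shape [N, P]; Fvar is [N, P] (marginals) or [P, N, N] (full_cov=True)
        if full_cov:
            N = Fvar.shape[-1]
            # Observation noise is independent across inputs, so it only
            # appears on the diagonal of each [N, N] covariance block
            return Fmu, Fvar + self.variance * np.eye(N)
        # Marginal variances: the noise is added to each entry
        return Fmu, Fvar + self.variance


lik = GaussianLikelihoodSketch(variance=0.5)
Fmu = np.zeros((3, 1))
_, Yvar_full = lik.predict_mean_and_var(Fmu, np.ones((1, 3, 3)), full_cov=True)
_, Yvar_marg = lik.predict_mean_and_var(Fmu, np.ones((3, 1)))
print(Yvar_full[0])     # 1.5 on the diagonal, 1.0 off the diagonal
print(Yvar_marg[:, 0])  # 1.5 everywhere
```

Defaulting full_cov to False keeps callers that pass marginal variances unaffected; only callers that explicitly request a full covariance take the identity-matrix path.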

mohitrajpal1 (Author) commented

Hi, OK, I can fix this. I'll try to find time this week, but that may not be possible with looming deadlines. I'll also add a test that verifies that predict_y works for all combinations of full_cov and full_output_cov (for multi-output GPs).

I'll reach out to you on Slack if there's a significant delay. Let me know if there's a rough date by which you need this.
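One way to express what such a test could check across all four combinations is sketched below. add_gaussian_noise is a hypothetical helper, not part of GPflow; it assumes a single scalar noise variance shared by all outputs, and it assumes the covariance shapes that predict_f returns for each combination of full_cov and full_output_cov.

```python
import numpy as np


def add_gaussian_noise(Fvar, variance, full_cov=False, full_output_cov=False):
    """Hypothetical helper: add i.i.d. Gaussian observation noise to a predictive
    covariance, touching only entries that pair an (input, output) with itself.

    Assumed shape conventions (those of GPflow's predict_f):
      full_cov=False, full_output_cov=False -> [N, P]
      full_cov=True,  full_output_cov=False -> [P, N, N]
      full_cov=False, full_output_cov=True  -> [N, P, P]
      full_cov=True,  full_output_cov=True  -> [N, P, N, P]
    """
    if full_cov and full_output_cov:
        N, P = Fvar.shape[0], Fvar.shape[1]
        # eye(N*P) reshaped so that entry [n, p, n2, p2] is 1 iff n == n2 and p == p2
        return Fvar + variance * np.eye(N * P).reshape(N, P, N, P)
    if full_cov or full_output_cov:
        # The last two axes form a square block ([N, N] or [P, P]); noise goes on its diagonal
        return Fvar + variance * np.eye(Fvar.shape[-1])
    # Marginal variances only: add the noise elementwise
    return Fvar + variance


N, P, variance = 4, 2, 0.3
shapes = {
    (False, False): (N, P),
    (True, False): (P, N, N),
    (False, True): (N, P, P),
    (True, True): (N, P, N, P),
}
for (full_cov, full_output_cov), shape in shapes.items():
    # Starting from a zero covariance, the result should be pure noise
    Yvar = add_gaussian_noise(np.zeros(shape), variance, full_cov, full_output_cov)
    if full_cov and full_output_cov:
        # Flattening the joint covariance must give variance * identity
        assert np.allclose(Yvar.reshape(N * P, N * P), variance * np.eye(N * P))
    elif full_cov or full_output_cov:
        assert np.allclose(Yvar, variance * np.eye(shape[-1]))
    else:
        assert np.allclose(Yvar, variance)
    print(full_cov, full_output_cov, Yvar.shape)
```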

Thanks

vdutor (Contributor) commented May 12, 2020

Looking forward to your PR - thanks. There is no deadline for this, but I'll get in touch if someone else wants to pick up this ticket before you start your work. The minimal example you wrote above would be an ideal candidate to add to your unit tests.

st-- added this to To do in Open bugs via automation May 28, 2020

st-- (Member) commented Oct 2, 2020

This is related to #1569, which will be closed by #1582; this issue is about actually implementing the full_(output_)cov cases.

st-- (Member) commented Oct 2, 2020

@mohitrajpal1, would you still be up for looking at how to implement this?

mohitrajpal1 (Author) commented

@st-- Hi, sorry. Yes, I can do this shortly; things got lost due to cascading deadlines. Expect a PR next week.

st-- (Member) commented Oct 3, 2020

Great! Looking forward to it.

st-- changed the title from "SVGP predict_y appears to incorrectly add noise variance for full_cov" to "predict_y and predict_log_density do not support full_cov or full_output_cov" Oct 5, 2020
st-- added the feature label Oct 5, 2020
st-- changed the title from "predict_y and predict_log_density do not support full_cov or full_output_cov" to "predict_y and predict_log_density support for full_cov or full_output_cov" Oct 5, 2020
st-- removed the bug label Oct 5, 2020
st-- removed this from To do in Open bugs Oct 5, 2020
st-- mentioned this issue Oct 7, 2020
mohitrajpal1 pushed a commit to mohitrajpal1/GPflow that referenced this issue Oct 11, 2020
mohitrajpal1 added a commit to mohitrajpal1/GPflow that referenced this issue Oct 11, 2020
st-- added this to TODO in GPflow Features Nov 3, 2020
3 participants