Fix VNNGP with batches #2375

Open · LuhuanWu wants to merge 4 commits into master

Conversation

LuhuanWu (Contributor) commented on Jul 9, 2023:

  • Fix VNNGP in batch settings (issue #2300, "VNNGP with Batches"); see the usage sketch after this list.
  • In addition, set the default jitter value to 1e-3 and the initial variational stddev to 1e-2.
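For context, here is a minimal sketch of the batched VNNGP setup this PR targets, following the model pattern from the GPyTorch VNNGP tutorial. The class name, data shapes, and hyperparameter values (k, training_batch_size, the leading batch size) are illustrative assumptions, not code from this PR; with the fix here, constructing and training a model like this with a leading batch dimension is the scenario being exercised.

import torch
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import MeanFieldVariationalDistribution, NNVariationalStrategy

class BatchedVNNGP(ApproximateGP):
    # Illustrative batched VNNGP: a leading batch of independent outputs over M inducing
    # points in d dimensions. In VNNGP the inducing points are the training inputs themselves.
    def __init__(self, inducing_points, k=8, training_batch_size=16):
        batch_shape = inducing_points.shape[:-2]  # e.g. torch.Size([2])
        M = inducing_points.size(-2)
        variational_distribution = MeanFieldVariationalDistribution(M, batch_shape=batch_shape)
        # NNVariationalStrategy builds a nearest-neighbor index at construction time
        # (requires gpytorch's nearest-neighbor backend, e.g. faiss, to be installed)
        variational_strategy = NNVariationalStrategy(
            self, inducing_points, variational_distribution, k=k, training_batch_size=training_batch_size
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ZeroMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.MaternKernel(nu=2.5, batch_shape=batch_shape)

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

    def __call__(self, x, prior=False, **kwargs):
        # VNNGP routes calls through the variational strategy; calling with x=None during
        # training makes the strategy draw its own minibatch of inducing points
        return self.variational_strategy(x=x, prior=prior, **kwargs)

# usage sketch: a two-batch model over 100 three-dimensional training points
train_x = torch.randn(2, 100, 3)
model = BatchedVNNGP(inducing_points=train_x, k=8, training_batch_size=16)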

@LuhuanWu mentioned this pull request on Jul 9, 2023
@@ -87,8 +90,6 @@ def __init__(
         super().__init__(
             model, inducing_points, variational_distribution, learn_inducing_locations=False, jitter_val=jitter_val
         )
-        # Make sure we don't try to initialize variational parameters - because of minibatching
-        self.variational_params_initialized.fill_(1)

Member commented:

Why did this line get deleted?

         )
+        # initialize with a small variational stddev for quicker conv. of kl divergence
+        self._variational_distribution._variational_stddev.data.copy_(torch.tensor(1e-2))
+        self.variational_params_initialized.fill_(1)

Member commented:

Ahhh okay, I see you've added a new initialization scheme.

LuhuanWu (Contributor, author) commented on Jul 17, 2023:

Yes. In practice I found that the variational standard deviation tends to shrink towards 0 by the end of training, since the inducing points coincide with the data points. If it is initialized with ones, the KL term is far larger than the log-likelihood term, so the ELBO takes a long time to converge. Initializing with a smaller value speeds up convergence.
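As a usage note: because the constructor marks the variational parameters as initialized (the fill_(1) line in the diff above), the strategy should not overwrite them later, so a user who prefers a different starting scale can set it manually right after building the model. A minimal sketch, assuming a VNNGP model exposing the same attributes touched in this diff (e.g. the illustrative model from the PR description above):

# overwrite the 1e-2 default chosen in this PR with another starting stddev, e.g. 0.1;
# this mirrors the initialization line added in the diff above
strategy = model.variational_strategy
strategy._variational_distribution._variational_stddev.data.fill_(0.1)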

Member commented:

For the future, it's good to list all of these little changes in the PR description as well :)

@@ -266,78 +298,121 @@ def _firstk_kl_helper(self) -> Tensor:
         variational_inducing_covar = DiagLinearOperator(variational_covar_fisrtk)
 
         variational_distribution = MultivariateNormal(inducing_values, variational_inducing_covar)
         kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist)  # model_batch_shape
+        with settings.max_preconditioner_size(0):

Member commented:

Why is this setting necessary? Can you add a comment in code?

LuhuanWu (Contributor, author) commented on Jul 17, 2023:

I was following the KL computation in _variational_strategy, see this line. What comment do you think is suitable here? Or do you suggest removing this line?

Member commented:

Let's get rid of it for now; I don't think it's necessary. We might eventually want to remove it from _variational_strategy as well, but that's not needed for this PR.

@@ -359,5 +434,7 @@ def _compute_nn(self) -> "NNVariationalStrategy":
         with torch.no_grad():
             inducing_points_fl = self.inducing_points.data.float()
             self.nn_util.set_nn_idx(inducing_points_fl)
-            self.nn_xinduce_idx = self.nn_util.build_sequential_nn_idx(inducing_points_fl)
+            if self.k < self.M:
+                self.nn_xinduce_idx = self.nn_util.build_sequential_nn_idx(inducing_points_fl)

Member commented:

Do we have a test in code for the k < M case? If not, can you add one?

LuhuanWu (Contributor, author) commented:

Yeah, I think k = 3 in the test code, which is smaller than M. I could add a k = M case later; that case is currently missing.
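For reference, one possible shape for the missing k = M case, as a hedged sketch: the test class and setup below are illustrative, not the repo's actual test scaffolding. The idea is simply to build a VNNGP model with k equal to the number of inducing points, so that _compute_nn skips build_sequential_nn_idx, and check that a training iteration still runs.

import unittest
import torch
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import MeanFieldVariationalDistribution, NNVariationalStrategy

class TestVNNGPKEqualsM(unittest.TestCase):  # hypothetical test class, not in the repo
    def test_training_iteration_k_equals_m(self):
        M, d = 16, 2
        train_x = torch.randn(M, d)
        train_y = torch.randn(M)

        class VNNGPModel(ApproximateGP):
            def __init__(self, inducing_points, k, training_batch_size):
                dist = MeanFieldVariationalDistribution(inducing_points.size(-2))
                strategy = NNVariationalStrategy(
                    self, inducing_points, dist, k=k, training_batch_size=training_batch_size
                )
                super().__init__(strategy)
                self.mean_module = gpytorch.means.ZeroMean()
                self.covar_module = gpytorch.kernels.MaternKernel(nu=2.5)

            def forward(self, x):
                return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

            def __call__(self, x, prior=False, **kwargs):
                return self.variational_strategy(x=x, prior=prior, **kwargs)

        model = VNNGPModel(train_x, k=M, training_batch_size=M)  # k == M: the currently untested branch
        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=M)

        model.train()
        output = model(x=None)  # VNNGP draws its own training minibatch of inducing points
        idx = model.variational_strategy.current_training_indices
        loss = -mll(output, train_y[idx])
        loss.backward()  # should run without raising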

@@ -115,7 +115,7 @@ def _training_iter(
         return output, loss
 
     def _eval_iter(self, model, cuda=False):
-        inducing_batch_shape = model.variational_strategy.inducing_points.shape[:-2]
+        inducing_batch_shape = model.variational_strategy._inducing_batch_shape

Member commented:

Can you add a unit test that runs VNNGP with batches? AFAIK, this unit test still only runs for non-batched VNNGP.

LuhuanWu (Contributor, author) commented:

I think test_training_iteration_batch_model at line 194 already tests the batch model.

Member commented:

@LuhuanWu why wasn't test_training_iteration_batch_model failing before this PR then?
Ideally we want to have a test case that (1) would've failed before this PR was added (capturing the behavior described in #2300) but (2) does not fail with the new code in this PR.

LuhuanWu (Contributor, author) commented:

Good point. I'll look into that.
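One possible shape for that regression test, again as a hedged sketch with illustrative names: it just pushes a batched VNNGP model (batched inducing points and variational distribution) through a training iteration, which per this thread is expected to fail on master before this PR and to pass with it.

import unittest
import torch
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import MeanFieldVariationalDistribution, NNVariationalStrategy

class TestVNNGPBatchRegression(unittest.TestCase):  # hypothetical test class, not in the repo
    def test_batched_training_iteration(self):
        b, M, d, k = 2, 32, 3, 4
        train_x = torch.randn(b, M, d)
        train_y = torch.randn(b, M)

        class BatchVNNGP(ApproximateGP):
            def __init__(self, inducing_points):
                batch_shape = inducing_points.shape[:-2]
                dist = MeanFieldVariationalDistribution(inducing_points.size(-2), batch_shape=batch_shape)
                strategy = NNVariationalStrategy(self, inducing_points, dist, k=k, training_batch_size=8)
                super().__init__(strategy)
                self.mean_module = gpytorch.means.ZeroMean(batch_shape=batch_shape)
                self.covar_module = gpytorch.kernels.MaternKernel(nu=2.5, batch_shape=batch_shape)

            def forward(self, x):
                return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

            def __call__(self, x, prior=False, **kwargs):
                return self.variational_strategy(x=x, prior=prior, **kwargs)

        model = BatchVNNGP(train_x)
        likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=torch.Size([b]))
        mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=M)

        model.train()
        output = model(x=None)
        # assumes the same minibatch indices are shared across the batch dimension
        idx = model.variational_strategy.current_training_indices
        loss = -mll(output, train_y[..., idx]).sum()  # the ELBO carries a batch dimension here
        loss.backward()  # should fail before this PR and pass with it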
