Fix VNNGP with batches #2375
base: master
Conversation
LuhuanWu commented on Jul 9, 2023
- Fix VNNGP in batch settings (VNNGP with Batches #2300)
- In addition, set the default jitter value to 1e-3 and the variational stddev init value to 1e-2 (a usage sketch follows below).
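Below is a minimal sketch of how a batched VNNGP model can be set up and trained on this branch. The model class, data shapes, kernel choice, and hyperparameter values are illustrative assumptions only, not code from this PR; the k / training_batch_size arguments and the current_training_indices attribute follow the existing VNNGP interface.

```python
import torch
import gpytorch
from gpytorch.variational import MeanFieldVariationalDistribution, NNVariationalStrategy


class VNNGPModel(gpytorch.models.ApproximateGP):
    # Illustrative model class (not part of this PR): inducing points are the
    # training inputs, with an optional leading batch dimension.
    def __init__(self, inducing_points, k=32, training_batch_size=64):
        m = inducing_points.shape[-2]
        batch_shape = inducing_points.shape[:-2]
        variational_distribution = MeanFieldVariationalDistribution(m, batch_shape=batch_shape)
        variational_strategy = NNVariationalStrategy(
            self, inducing_points, variational_distribution, k=k, training_batch_size=training_batch_size
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ZeroMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.MaternKernel(nu=2.5, batch_shape=batch_shape)

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


# Batched toy data: 3 independent tasks, 500 points each, 2 input dims (made up).
train_x = torch.randn(3, 500, 2)
train_y = torch.randn(3, 500)

model = VNNGPModel(train_x)
likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=torch.Size([3]))
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(-1))
optimizer = torch.optim.Adam([*model.parameters(), *likelihood.parameters()], lr=0.01)

model.train()
likelihood.train()
for _ in range(10):
    optimizer.zero_grad()
    # VNNGP draws its own minibatch of training/inducing points when called with x=None.
    output = model(x=None)
    batch_idx = model.variational_strategy.current_training_indices
    loss = -mll(output, train_y[..., batch_idx]).sum()
    loss.backward()
    optimizer.step()
```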
@@ -87,8 +90,6 @@ def __init__(
        super().__init__(
            model, inducing_points, variational_distribution, learn_inducing_locations=False, jitter_val=jitter_val
        )
        # Make sure we don't try to initialize variational parameters - because of minibatching
        self.variational_params_initialized.fill_(1)
Why did this line get deleted?
        )
        # initialize with a small variational stddev for quicker conv. of kl divergence
        self._variational_distribution._variational_stddev.data.copy_(torch.tensor(1e-2))
        self.variational_params_initialized.fill_(1)
Ahhh okay, I see you've added a new initialization scheme.
Yes. In practice I found that the variational standard deviation tends to shrink towards 0 in the end, given that the inducing points are the data points. If it is initialized with ones, the KL term is much larger than the log likelihood term, resulting in a long time to converge. Initializing with a smaller value speeds up the convergence.
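To make the effect concrete, here is a tiny standalone sketch of the same initialization applied directly to a MeanFieldVariationalDistribution (the size M = 100 is an arbitrary example; the 1e-2 value is the new default from this PR):

```python
import torch
from gpytorch.variational import MeanFieldVariationalDistribution

M = 100  # number of inducing points (arbitrary example size)
variational_distribution = MeanFieldVariationalDistribution(M)

# The variational stddev parameter starts at ones; this PR initializes it to 1e-2
# inside the strategy, mirroring the line shown in the diff above.
variational_distribution._variational_stddev.data.copy_(torch.tensor(1e-2))
print(variational_distribution._variational_stddev.data.mean())  # tensor(0.0100)
```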
For the future, it's good to list all of these little changes in the PR description as well :)
@@ -266,78 +298,121 @@ def _firstk_kl_helper(self) -> Tensor:
        variational_inducing_covar = DiagLinearOperator(variational_covar_fisrtk)

        variational_distribution = MultivariateNormal(inducing_values, variational_inducing_covar)
        kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist)  # model_batch_shape
        with settings.max_preconditioner_size(0):
Why is this setting necessary? Can you add a comment in code?
I was following the KL computation in _variational_strategy (see this line). What comment do you think is suitable here? Or do you suggest removing this line?
Let's get rid of it for now; I don't think it's necessary. We might also want to remove it from _variational_strategy eventually, but that's not needed for this PR.
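For reference, a self-contained sketch of the KL term being discussed, written without the max_preconditioner_size context manager. The prior covariance, sizes, and values below are made up for illustration; only the DiagLinearOperator / MultivariateNormal / kl_divergence pattern is taken from the diff, and the import path assumes a recent gpytorch with the linear_operator package.

```python
import torch
from linear_operator.operators import DiagLinearOperator
from gpytorch.distributions import MultivariateNormal

k = 5  # number of "first k" inducing points (illustrative)

# Prior over the first k inducing values: any PSD matrix stands in for K_kk here.
prior_mean = torch.zeros(k)
prior_covar = torch.eye(k) + 0.1 * torch.ones(k, k)
prior_dist = MultivariateNormal(prior_mean, prior_covar)

# Mean-field variational distribution: diagonal covariance built from the variational stddev.
inducing_values = torch.zeros(k)
variational_stddev = 1e-2 * torch.ones(k)
variational_distribution = MultivariateNormal(inducing_values, DiagLinearOperator(variational_stddev**2))

kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist)
print(kl)
```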
@@ -359,5 +434,7 @@ def _compute_nn(self) -> "NNVariationalStrategy":
        with torch.no_grad():
            inducing_points_fl = self.inducing_points.data.float()
            self.nn_util.set_nn_idx(inducing_points_fl)
            self.nn_xinduce_idx = self.nn_util.build_sequential_nn_idx(inducing_points_fl)
            if self.k < self.M:
                self.nn_xinduce_idx = self.nn_util.build_sequential_nn_idx(inducing_points_fl)
Do we have a test in code for the k < M case? If not, can you add one?
Yeah, I think k = 3 in the test code, which is smaller than M. I could add a case where k = M later, which is currently missing.
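In case it helps, a rough sketch of what the missing k = M case might look like as a standalone smoke test. The model class and all sizes below are assumptions for illustration, not the actual test-suite code; the real test would presumably go through the existing VNNGP test base class.

```python
import torch
import gpytorch
from gpytorch.variational import MeanFieldVariationalDistribution, NNVariationalStrategy


class TinyVNNGP(gpytorch.models.ApproximateGP):
    # Minimal illustrative model; inducing points are the training inputs.
    def __init__(self, inducing_points, k, training_batch_size):
        m = inducing_points.shape[-2]
        variational_distribution = MeanFieldVariationalDistribution(m)
        strategy = NNVariationalStrategy(
            self, inducing_points, variational_distribution, k=k, training_batch_size=training_batch_size
        )
        super().__init__(strategy)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.RBFKernel()

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


def test_k_equals_m():
    train_x = torch.randn(16, 2)
    train_y = torch.randn(16)
    model = TinyVNNGP(train_x, k=16, training_batch_size=16)  # k == M
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))

    model.train()
    output = model(x=None)
    y_batch = train_y[model.variational_strategy.current_training_indices]
    loss = -mll(output, y_batch)
    loss.backward()  # smoke test: forward/backward run when k == M
```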
@@ -115,7 +115,7 @@ def _training_iter(
        return output, loss

    def _eval_iter(self, model, cuda=False):
        inducing_batch_shape = model.variational_strategy.inducing_points.shape[:-2]
        inducing_batch_shape = model.variational_strategy._inducing_batch_shape
Can you add a unit test that runs VNNGP with batches? AFAIK, this unit test is still only running for non-batched VNNGP.
I think test_training_iteration_batch_model on line 194 does the batch model testing.
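For context on what that batched path exercises, a tiny illustration of the shape bookkeeping involved. The tensors and sizes are made up; only the batch_shape + (M, d) convention and the shape[:-2] slice come from the diff above.

```python
import torch

# Batched VNNGP: inducing points have shape batch_shape + (M, d).
inducing_points = torch.randn(3, 128, 2)  # batch of 3 models, M=128 points, d=2 dims

# The old test code derived the batch shape by slicing off the last two dims...
inducing_batch_shape = inducing_points.shape[:-2]
print(inducing_batch_shape)  # torch.Size([3])

# ...while the updated _eval_iter reads model.variational_strategy._inducing_batch_shape,
# which should agree with the slice above. Eval inputs keep the same leading batch shape:
test_x = torch.randn(*inducing_batch_shape, 64, 2)
print(test_x.shape)  # torch.Size([3, 64, 2])
```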
Good point. I'll look into that.