[Question] BBB vs BBB w/ Local Reparameterization #14

Open
danielkelshaw opened this issue Apr 30, 2020 · 4 comments
Assignees: JavierAntoran
Labels: good first issue, question

@danielkelshaw

Hi @JavierAntoran @stratisMarkou,

First of all, thanks for making all of this code available - it's been great to look through!

I'm currently spending some time working through the Weight Uncertainty in Neural Networks paper in order to implement Bayes-by-Backprop. I'm struggling to understand the difference between your implementation of Bayes-by-Backprop and Bayes-by-Backprop with Local Reparameterization.

I was under the impression that the local reparameterization trick was the following:

# sample standard Gaussian noise with the same shape as the parameters
eps_W = Variable(self.W_mu.data.new(self.W_mu.size()).normal_())
eps_b = Variable(self.b_mu.data.new(self.b_mu.size()).normal_())
# compute standard deviations (softplus keeps them positive)
std_w = 1e-6 + F.softplus(self.W_p, beta=1, threshold=20)
std_b = 1e-6 + F.softplus(self.b_p, beta=1, threshold=20)
# sample parameters via the reparameterisation trick
W = self.W_mu + std_w * eps_W
b = self.b_mu + std_b * eps_b

However, this same approach appears to be used in both methods.

The main difference I see in your code is that the Local Reparameterization version computes the KL divergence in closed form, which is possible because of the Gaussian prior / posterior.

I was wondering if my understanding of the local reparameterization method was wrong, or if I had simply misunderstood the code?

Any guidance would be much appreciated!

@JavierAntoran JavierAntoran self-assigned this Apr 30, 2020
@JavierAntoran JavierAntoran added the question Further information is requested label Apr 30, 2020
@danielkelshaw
Author

Furthermore, your implementation of the closed-form KL divergence appears to match equation 10 of the Auto-Encoding Variational Bayes paper:

def KLD_cost(mu_p, sig_p, mu_q, sig_q):
    # closed-form KL(q || p) between two diagonal Gaussians, summed over all dimensions
    # cf. https://arxiv.org/abs/1312.6114: 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = 0.5 * (2 * torch.log(sig_p / sig_q) - 1 + (sig_q / sig_p).pow(2) + ((mu_p - mu_q) / sig_p).pow(2)).sum()
    return KLD
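
For reference, I assume it gets called along these lines for a zero-mean Gaussian prior (my guess at the usage, with a hypothetical `prior_sig`, not necessarily the repo's exact call):

kl = KLD_cost(mu_p=torch.zeros_like(self.W_mu),         # prior mean (assumed zero)
              sig_p=prior_sig * torch.ones_like(std_w),  # prior std (hypothetical name)
              mu_q=self.W_mu,                            # posterior mean
              sig_q=std_w)                               # posterior std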

I was wondering if you could provide any detail on how you arrived at the equation that you implemented in the code?

Thanks again!

@JavierAntoran
Owner

Hi @danielkelshaw,

Thanks for your question. Similarly to the regular reparametrisation trick, the local reparametrisation trick is used to estimate gradients with respect to parameters of a distribution. However, the local reparametrisation trick takes advantage of the fact that, for a fixed input and Gaussian distributions over the weights, the resulting distribution over activations is also Gaussian. Instead of sampling all the weights individually and then combining them with the inputs to compute a sample from the activations, we can directly sample from the distribution over activations. This results in a lower variance gradient estimator which in turn makes training faster and more stable. Using the local reparametrisation trick is always recommended if possible.
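
In symbols (the notation is mine, not from the repo): for an input $x$ and independent Gaussian weights $w_{ij} \sim \mathcal{N}(\mu_{ij}, \sigma_{ij}^2)$, the standard trick samples weights and then computes activations,

$$
w_{ij} = \mu_{ij} + \sigma_{ij}\,\epsilon_{ij}, \quad \epsilon_{ij} \sim \mathcal{N}(0, 1), \qquad a_j = \sum_i x_i w_{ij},
$$

whereas the local trick samples each activation directly from its induced Gaussian:

$$
a_j \sim \mathcal{N}\Big(\sum_i x_i\,\mu_{ij},\ \sum_i x_i^2\,\sigma_{ij}^2\Big).
$$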

The code for both gradient estimators is similar but not quite the same. In the code you referenced, if you look closely, you can see that we first sample the Gaussian weights:

        W = self.W_mu + std_w * eps_W
        b = self.b_mu + std_b * eps_b

And then pass the input through a linear layer with parameters that we just sampled:

        output = torch.mm(X, W) + b.unsqueeze(0).expand(X.shape[0], -1)  # (batch_size, n_output) 

On the other hand, for the local reparametrisation trick, we compute the parameters of the Gaussian over activations directly:

        act_W_mu = torch.mm(X, self.W_mu)  # mean of the Gaussian over activations
        act_W_std = torch.sqrt(torch.mm(X.pow(2), std_w.pow(2)))  # std of the Gaussian over activations

And then sample from the distribution over activations.

        act_W_out = act_W_mu + act_W_std * eps_W  # (batch_size, n_output)
        act_b_out = self.b_mu + std_b * eps_b

        output = act_W_out + act_b_out.unsqueeze(0).expand(X.shape[0], -1)
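
Putting those pieces together, a minimal self-contained sketch of such a layer could look like this (illustrative class name and initialisation, not the repo's exact API):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class LocalReparamLinear(nn.Module):
        # illustrative Bayesian linear layer using the local reparameterisation trick
        def __init__(self, n_in, n_out):
            super().__init__()
            self.W_mu = nn.Parameter(0.1 * torch.randn(n_in, n_out))
            self.W_p = nn.Parameter(torch.full((n_in, n_out), -3.0))  # pre-softplus std
            self.b_mu = nn.Parameter(torch.zeros(n_out))
            self.b_p = nn.Parameter(torch.full((n_out,), -3.0))

        def forward(self, X):
            std_w = 1e-6 + F.softplus(self.W_p)
            std_b = 1e-6 + F.softplus(self.b_p)
            # parameters of the Gaussian over activations
            act_W_mu = X @ self.W_mu
            act_W_std = torch.sqrt(X.pow(2) @ std_w.pow(2))
            # sample activations (and the bias) directly
            eps_W = torch.randn_like(act_W_mu)
            eps_b = torch.randn_like(self.b_mu)
            return act_W_mu + act_W_std * eps_W + self.b_mu + std_b * eps_b

Each forward pass then draws a fresh activation sample, so repeated calls on the same input give different outputs.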

With regard to the KL divergence, the form used in regular BayesByBackprop is more general but requires MC sampling to estimate it. It has the benefit of allowing for non-Gaussian priors and non-Gaussian approximate posteriors (note that our code implements the former but not the latter). We use the same weight samples to compute the model predictions and the KL divergence, saving compute and reducing variance through the use of common random numbers.
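
For illustration, a single-sample MC estimate that reuses the forward-pass weight sample could look like this minimal sketch (hypothetical helper, assuming a zero-mean Gaussian prior; not our exact code):

    import torch
    import torch.distributions as dist

    def mc_kl_estimate(W, W_mu, std_w, prior_std=1.0):
        # single-sample MC estimate of KL(q || p) = E_q[log q(w) - log p(w)],
        # reusing the weight sample W already drawn for the forward pass
        q = dist.Normal(W_mu, std_w)                     # approximate posterior
        p = dist.Normal(torch.zeros_like(W), prior_std)  # prior (need not be Gaussian)
        return (q.log_prob(W) - p.log_prob(W)).sum()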

When running the local reparametrisation trick, we sample activations instead of weights, so we don't have access to the weight samples needed for an MC estimate of the KL divergence. Because of this, we opted for the closed-form expression. It restricts us to a Gaussian prior but has lower variance and results in faster convergence.

With regard to your second question: the KL divergence between 2 Gaussians can be obtained in closed form by solving a Gaussian integral. See: https://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
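
For completeness, the standard result for two univariate Gaussians is

$$
\mathrm{KL}\big(\mathcal{N}(\mu_q, \sigma_q^2) \,\|\, \mathcal{N}(\mu_p, \sigma_p^2)\big)
= \log\frac{\sigma_p}{\sigma_q} + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{2\sigma_p^2} - \frac{1}{2},
$$

and summing this over the independent weight dimensions gives exactly the expression in KLD_cost above.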

@danielkelshaw
Author

@JavierAntoran - thank you for taking the time to help explain this, I really appreciate it!

I found your explanation of the local reparameterisation trick very intuitive and feel like I've got a much better grasp of that now.

I'm very interested in learning more about Bayesian Neural Networks, I was wondering if you had any recommended reading that would help get me up to speed with some more of the theory?

@JavierAntoran
Owner

For general ideas about re-casting learning as inference, I would check out chapter 41 of David MacKay's Information Theory, Inference, and Learning Algorithms. Yarin Gal's thesis is also a good source.

On the more practical side, the tutorial by the team at Papercup is quite nice.

Other than that, read the papers implemented in this repo and try to understand both the algorithm and implementation.

@JavierAntoran JavierAntoran added the good first issue Good for newcomers label May 1, 2020
@JavierAntoran JavierAntoran pinned this issue May 1, 2020
@stratisMarkou stratisMarkou unpinned this issue May 12, 2020
@stratisMarkou stratisMarkou pinned this issue May 12, 2020