Question regarding the cookbook #200

Open

XinyiYS commented Feb 26, 2024

Hi there, thanks for the great repo!

I was working through the Neural Tangents Cookbook and am a bit confused by the loss_fn (reproduced below):

import jax.numpy as jnp

def loss_fn(predict_fn, ys, t, xs=None):
  # Posterior mean and covariance of the prediction at training time t.
  mean, cov = predict_fn(t=t, get='ntk', x_test=xs, compute_cov=True)
  mean = jnp.reshape(mean, mean.shape[:1] + (-1,))
  # Per-point predictive variances from the diagonal of the covariance.
  var = jnp.diagonal(cov, axis1=1, axis2=2)
  ys = jnp.reshape(ys, (1, -1))

  # 0.5 * average over points of ys**2 - 2 * mean * ys + var + mean**2.
  mean_predictions = 0.5 * jnp.mean(ys ** 2 - 2 * mean * ys + var + mean ** 2,
                                    axis=1)

  return mean_predictions

It looks like this function is later used to compute the training and test losses for plotting. What confuses me is that the calculation of mean_predictions for each test point contains var, making it effectively the sum of the squared error (between a prediction and a label) and the predictive variance. It does make sense to include the variance as part of the performance (or loss), but why this specific form? For example, why $+ 1 \times \text{var}$ rather than $2 \times \text{var}$, why the variance rather than the standard deviation, and why the factor of $0.5$ in front? Perhaps you could point me to a reference that I missed somewhere?
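
For concreteness, here is a small self-contained check (with made-up values standing in for mean, var, and ys) of the identity I am referring to, i.e. that the expression inside jnp.mean is exactly the squared error plus the variance:

import jax.numpy as jnp

# Made-up values standing in for the quantities in loss_fn.
mean = jnp.array([[0.3, -1.2, 0.7]])   # predictive means
var = jnp.array([[0.05, 0.10, 0.02]])  # predictive variances
ys = jnp.array([[0.5, -1.0, 1.0]])     # labels

# The expression from loss_fn vs. squared error plus variance:
# ys**2 - 2*mean*ys + mean**2 == (ys - mean)**2, so the two agree.
lhs = 0.5 * jnp.mean(ys ** 2 - 2 * mean * ys + var + mean ** 2, axis=1)
rhs = 0.5 * jnp.mean((ys - mean) ** 2 + var, axis=1)
print(jnp.allclose(lhs, rhs))  # True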

Thanks again!
