NTK/NNGP behavior in the infinite regime when weights are drawn from Gaussians with high standard deviation #197

Open
tengandreaxu opened this issue Dec 29, 2023 · 7 comments
Labels
question Further information is requested

Comments

@tengandreaxu

Hi everyone, thank you so much for your exceptional work!

I'm encountering some numerical issues when weights are drawn from Gaussians with a high standard deviation. Please see the snippet below:

import numpy as np
from neural_tangents import stax
from jax import jit

W_stds = list(range(1, 17))
# W_stds.reverse()
layer_fn = []
for i in range(len(W_stds) - 1):
    layer_fn.append(stax.Dense(1, W_std=W_stds[i]))
    layer_fn.append(stax.Relu())

layer_fn.append(stax.Dense(1, 1.0, 0.0))
_, _, kernel_fn = stax.serial(*layer_fn)

kernel_fn = jit(kernel_fn, static_argnames="get")

x = np.random.rand(100, 100)

print(kernel_fn(x, x, "ntk"))

The result is:

[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]

With float64 precision enabled, the values are finite but blow up:

[[2.2293401e+18 9.3420067e+17 9.2034030e+17 ... 8.9008971e+17
  9.6801663e+17 9.6436509e+17]
 [9.3420067e+17 2.3730658e+18 9.4658846e+17 ... 9.6854199e+17
  9.6182735e+17 9.9944418e+17]
 [9.2034030e+17 9.4658846e+17 2.3106050e+18 ... 9.1702287e+17
  9.5415269e+17 9.9692925e+17]
 ...
 [8.9008971e+17 9.6854199e+17 9.1702300e+17 ... 2.2127619e+18
  9.2056034e+17 1.0147568e+18]
 [9.6801663e+17 9.6182728e+17 9.5415269e+17 ... 9.2056034e+17
  2.3979914e+18 9.9505658e+17]
 [9.6436488e+17 9.9944418e+17 9.9692925e+17 ... 1.0147568e+18
  9.9505658e+17 2.4954969e+18]]

Interestingly, the behavior appears to depend more on the depth than on the large standard deviations themselves. If the list of standard deviations is reversed (by uncommenting the line in the code), so that the first layer has W_std=16 (i.e. $w_{ij} \sim \mathcal{N}(0, 16^2)$), the next has W_std=15, and so on, the results remain unchanged.

Thank you in advance, and happy new year!

@zhangbububu

Hi, how can I enable float64 precision?

@romanngg
Contributor

Sorry for the late reply!

@zhangbububu see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
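For readers who do not want to follow the link: a minimal sketch of enabling 64-bit precision in JAX, assuming a recent JAX release where `jax.config.update` is available. The flag should be set at program startup, before any arrays are created.

```python
import jax

# Enable 64-bit floats; JAX defaults to 32-bit otherwise.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.ones((2, 2))
print(x.dtype)
```

Setting the environment variable `JAX_ENABLE_X64=True` before launching Python is an equivalent alternative.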

@tengandreaxu could you try using Relu(do_stabilize=True)? https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.Relu.html This parameter switches the nonlinearity kernel computation to a variant that helps prevent numerical overflow.
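To illustrate the kind of overflow such a stabilization guards against, here is a NumPy sketch of the arc-cosine (ReLU) kernel map evaluated in float32. It is not the library's implementation: the function names `relu_kernel` and `relu_kernel_rescaled` are hypothetical, and the rescaling trick shown is only one plausible approach, relying on the fact that the ReLU kernel map is homogeneous of degree 1 in the covariances.

```python
import numpy as np

def relu_kernel(k11, k12, k22):
    """Naive arc-cosine (ReLU) kernel map for covariances k11, k12, k22."""
    norm = np.sqrt(k11 * k22)  # this product can overflow in float32
    cos = np.clip(k12 / norm, -1.0, 1.0)
    theta = np.arccos(cos)
    return norm * (np.sin(theta) + (np.pi - theta) * cos) / (2 * np.pi)

def relu_kernel_rescaled(k11, k12, k22):
    """Same map, but rescale inputs to O(1) first and undo the scaling after."""
    # Assumes at least one entry is nonzero, so that scale > 0.
    scale = np.maximum(np.maximum(np.abs(k11), np.abs(k22)), np.abs(k12))
    return scale * relu_kernel(k11 / scale, k12 / scale, k22 / scale)

# Covariances of the magnitude seen in the outputs above, stored as float32.
k11 = np.array([1e20], dtype=np.float32)
k12 = np.array([5e19], dtype=np.float32)
k22 = np.array([1e20], dtype=np.float32)

with np.errstate(over="ignore"):
    naive = relu_kernel(k11, k12, k22)       # k11 * k22 exceeds the float32 range
rescaled = relu_kernel_rescaled(k11, k12, k22)
print(naive, rescaled)
```

The naive version overflows because the intermediate product is about 1e40, beyond the float32 maximum of roughly 3.4e38, even though the final answer fits comfortably.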

@romanngg romanngg added the question Further information is requested label Jan 28, 2024
@tengandreaxu
Author

Thank you so much, Roman. It's no problem at all!

import numpy as np
from neural_tangents import stax
from jax import jit

W_stds = list(range(1, 17))
# W_stds.reverse()
layer_fn = []
for i in range(len(W_stds) - 1):
    layer_fn.append(stax.Dense(1, W_std=W_stds[i]))
    layer_fn.append(stax.Relu(do_stabilize=True))

layer_fn.append(stax.Dense(1, 1.0, 0.0))
_, _, kernel_fn = stax.serial(*layer_fn)

kernel_fn = jit(kernel_fn, static_argnames="get")

x = np.random.rand(100, 100)

print(kernel_fn(x, x, "ntk"))

This results in:

[[2.61008562e+20 1.12163820e+20 1.23732785e+20 ... 1.08229372e+20
  1.05533967e+20 1.10687273e+20]
 [1.12163820e+20 2.92078984e+20 1.31143308e+20 ... 1.16449180e+20
  1.15616286e+20 1.19062657e+20]
 [1.23732785e+20 1.31143308e+20 3.36093753e+20 ... 1.28641726e+20
  1.19473708e+20 1.28997387e+20]
 ...
 [1.08229363e+20 1.16449180e+20 1.28641726e+20 ... 2.74442324e+20
  1.07858132e+20 1.20695995e+20]
 [1.05533967e+20 1.15616286e+20 1.19473708e+20 ... 1.07858132e+20
  2.69344883e+20 1.11830439e+20]
 [1.10687273e+20 1.19062657e+20 1.28997387e+20 ... 1.20695995e+20
  1.11830439e+20 2.83645061e+20]]

Do you think it makes little sense to draw weights with a larger standard deviation the deeper we go into the network, in the infinite-width limit?

@zhangbububu

@romanngg @tengandreaxu

Hi, I've run into a confusing problem:

import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Relu(do_stabilize=True),
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Relu(do_stabilize=True),
    stax.Dense(1, W_std=1.5, b_std=0.05)
)

s = 10
l = jnp.pi * -s
r = jnp.pi * s 
N_tr = 100
N_te = 5
train_xs = jnp.linspace(l, r , N_tr).reshape(-1, 1).astype(jnp.float64)
train_ys = jnp.sin(train_xs) + jnp.sin(2*train_xs).astype(jnp.float64)
test_xs = jnp.linspace(l, r, N_te).reshape(-1, 1).astype(jnp.float64)

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, train_xs,
                                                      train_ys, diag_reg=1e-4)
nkt_mean, nkt_covariance = predict_fn(x_test=test_xs, get='ntk',
                                        compute_cov=True)
print(nkt_mean)


If I increase the number of training samples (N_tr), nkt_mean comes out as all NaN.

@romanngg
Contributor

@tengandreaxu

Do you think it makes little sense to draw weights with a larger standard deviation the deeper we go into the network, in the infinite-width limit?

I think so. Ideally you would want the mean and variance of your network outputs to match those of your training labels, as a sensible prior. But even if your training labels have a large variance, it's common practice to just standardize them (together with the test labels) to have mean 0 and variance 1 for best numerical stability.
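A minimal sketch of the label standardization described above. The helper name `standardize_labels` is hypothetical, and computing the statistics from the training labels only (then applying them to the test labels) is an assumption about what "together with test labels" means in practice.

```python
import numpy as np

def standardize_labels(y_train, y_test):
    """Shift and scale labels using training-set statistics."""
    mean = y_train.mean(axis=0)
    std = y_train.std(axis=0)
    # Return mean and std so predictions can be mapped back to original units.
    return (y_train - mean) / std, (y_test - mean) / std, mean, std

y_train = np.array([[10.0], [20.0], [30.0], [40.0]])
y_test = np.array([[25.0], [35.0]])
y_train_s, y_test_s, mean, std = standardize_labels(y_train, y_test)

# Predictions made in standardized units are mapped back like this:
y_test_recovered = y_test_s * std + mean
```

Predictions from the kernel regression are then in standardized units and need the inverse transform shown on the last line.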

Then in a Relu network, to have mean zero / variance one outputs (given mean zero, variance one inputs), you would want to set W_std=2**0.5 for all intermediate layers preceding Relus, and W_std=1 for the top layer.
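The arithmetic behind this can be sketched in plain Python, assuming no bias terms: a Dense layer multiplies the NNGP diagonal by W_std**2, and the ReLU that follows halves it, so each Dense + Relu pair contributes a factor of W_std**2 / 2. The helper `diag_growth` is hypothetical and only tracks this diagonal scaling, not the full kernel.

```python
import math

def diag_growth(w_stds):
    """Factor by which a stack of Dense(W_std) + Relu pairs (no bias)
    scales the diagonal of the NNGP kernel."""
    factor = 1.0
    for w_std in w_stds:
        factor *= w_std**2 / 2.0  # Dense scales by W_std**2, Relu halves
    return factor

# W_std = 1, 2, ..., 15, as in the 15 Dense + Relu pairs of the original snippet.
growing = diag_growth(range(1, 16))

# W_std = sqrt(2) for every intermediate layer.
balanced = diag_growth([math.sqrt(2)] * 15)

print(growing, balanced)
```

With the increasing W_std values the factor equals (15!)**2 / 2**15, on the order of 1e19, while the sqrt(2) choice keeps it at 1 up to rounding.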

@zhangbububu replied in your separate thread, let's continue there.

@tengandreaxu
Author

Thank you for your prompt help Roman!
