
How to compute the empirical after kernel? #189

Open
VMS-6511 opened this issue Sep 10, 2023 · 1 comment
Labels
question Further information is requested

Comments

@VMS-6511

I'm looking to use the library to compute the after kernel for a model trained with the FLAX library. I followed this Colab: https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_resnet.ipynb.

Instead of these lines:

  params = model.init(random.PRNGKey(0), x1)
  return params, (jacobian_contraction, ntvp, str_derivatives, auto)

params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)

I used the params from the following TrainState of the FLAX model:

state = TrainState.create(
    apply_fn = model.apply,
    params = variables['params'],
    batch_stats = variables['batch_stats'],
    tx = tx)

I was wondering whether this is the correct way to do it. Thanks!

@romanngg
Contributor

Hi Vinith, yes, I think this is correct; if something isn't working as expected, let me know!
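Editor's note: the pattern discussed above can be sketched in plain JAX, without depending on neural-tangents or Flax. The idea is that the empirical ("after") NTK only differentiates with respect to `params`, so any non-trainable state (like `batch_stats` from the TrainState) is closed over in the apply function rather than passed to the differentiation. The `init_mlp`/`apply_fn` names and the toy MLP below are illustrative stand-ins, not the model from the Colab; the kernel is computed by explicit Jacobian contraction, which is what the `jacobian_contraction` implementation in the notebook does under the hood.

```python
# Minimal sketch: empirical (after-training) NTK via explicit Jacobian
# contraction. Uses only JAX; the toy MLP stands in for a Flax model.
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Toy MLP params as a pytree (stand-in for Flax variables['params'])."""
    params = []
    for din, dout in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (din, dout)) / jnp.sqrt(din)
        params.append({'w': w, 'b': jnp.zeros(dout)})
    return params


def apply_fn(params, x):
    """Forward pass. In the Flax setting this would be
    model.apply({'params': params, 'batch_stats': batch_stats}, x, train=False)
    with batch_stats closed over, so only params are differentiated."""
    for layer in params[:-1]:
        x = jnp.tanh(x @ layer['w'] + layer['b'])
    last = params[-1]
    return x @ last['w'] + last['b']


def empirical_ntk(apply_fn, params, x1, x2):
    """K[i, j] = sum over parameters of J1[i] . J2[j], traced over outputs."""
    j1 = jax.jacobian(lambda p: apply_fn(p, x1))(params)  # leaves: [N1, O, *param]
    j2 = jax.jacobian(lambda p: apply_fn(p, x2))(params)  # leaves: [N2, O, *param]

    def contract(a, b):
        a = a.reshape(a.shape[0], a.shape[1], -1)
        b = b.reshape(b.shape[0], b.shape[1], -1)
        return jnp.einsum('iok,jok->ij', a, b)

    leaves = jax.tree_util.tree_map(contract, j1, j2)
    return sum(jax.tree_util.tree_leaves(leaves))


params = init_mlp(jax.random.PRNGKey(0), [3, 8, 1])
x1 = jax.random.normal(jax.random.PRNGKey(1), (4, 3))
k = empirical_ntk(apply_fn, params, x1, x1)
print(k.shape)  # (4, 4): symmetric PSD kernel of x1 against itself
```

With a trained Flax `TrainState` the same function applies: pass `state.params` as `params`, and wrap `state.apply_fn` so that `batch_stats` is supplied inside the closure, exactly as in the question above.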

@romanngg romanngg added the question Further information is requested label Sep 11, 2023