How to break kernel_fn of serial into multiple kernel_fns of Dense and Activation #179

Answered by romanngg
FabianoVeglianti asked this question in Q&A

Thanks for the detailed repro! Here's a fix for the last cell in your Colab:

### Test Cell
...
kernel_0_output = layers_functions["kernel_fns"][0](x1.reshape(1, 16), None)
print(kernel_0_output)

kernel_1_output = layers_functions["kernel_fns"][1](kernel_0_output)
print(kernel_1_output.nngp)

The change lies in passing the whole Kernel dataclass as the input to the kernel_fn, as opposed to Kernel.nngp, which is just an array. A bare array lacks the necessary metadata and is actually interpreted as if you had passed a raw input (akin to x1); this raises an error when it reaches the ReLU function, whose infinite-width limit requires it to be preceded by a Gaussian linear …

Answer selected by FabianoVeglianti