v0.4.0
WARNING:
Our next major release (v0.5.0) will include significant refactoring and could break your code if you use internal APIs such as nt.utils.typing, nt.utils.utils, nt.utils.Kernel, etc. (the public API will remain unchanged). This should be easy to fix by updating the imports, e.g. nt.utils -> nt._src.utils.
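One way to keep code working across both layouts is to try the new import path first and fall back to the old one. The sketch below uses a hypothetical helper, `import_first`, built only on the standard library's `importlib`. The `neural_tangents._src.utils.utils` path in the usage comment is an assumption extrapolated from the nt.utils -> nt._src.utils example above, so check it against the actual v0.5.0 layout.

```python
import importlib
from types import ModuleType


def import_first(*module_names: str) -> ModuleType:
    """Return the first module in `module_names` that imports successfully.

    Hypothetical helper (not part of neural_tangents): tries each dotted
    module path in order and raises ImportError if none can be imported.
    """
    errors = []
    for name in module_names:
        try:
            return importlib.import_module(name)
        except ImportError as e:  # ModuleNotFoundError is a subclass of ImportError
            errors.append(f"{name}: {e}")
    raise ImportError("None of the modules could be imported:\n" + "\n".join(errors))


# Usage (assumed paths): prefer the v0.5.0 internal location and fall back
# to the v0.4.0 one.
# nt_utils = import_first("neural_tangents._src.utils.utils",
#                         "neural_tangents.utils.utils")
```

Since the helper only touches `importlib`, it does not depend on which neural_tangents version is installed. Pinning a single version and importing directly is simpler if you do not need to support both.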
This release (v0.4.0):
New feature:
Improvements:
- Various internal refactoring and tighter tests.
Bugfixes:
- Fix values and gradients of non-differentiable kernel_fn at zero inputs to be consistent with finite-width kernels, and with how JAX defines the gradient of a non-differentiable function (as the mean sub-gradient); see also #123.
- Fix wrong treatment of b_std=None in the infinite-width limit with parameterization='standard'; see also #123.
- Fix a bug in nt.batch when x2 = None and inputs are PyTrees.
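To illustrate the JAX convention referred to in the first fix, the sketch below differentiates ReLU written as `jnp.maximum(x, 0.0)`, which is non-differentiable at zero. It does not call neural_tangents, and it only checks `jnp.maximum`; other JAX functions may define their derivative at zero differently. At `x = 0` the sub-gradients form the interval [0, 1], and JAX's rule for `maximum` splits the gradient evenly between tied arguments, giving their mean, 0.5.

```python
import jax
import jax.numpy as jnp


def relu_via_max(x):
    # ReLU written as max(x, 0); non-differentiable at x = 0.
    return jnp.maximum(x, 0.0)


grad_relu = jax.grad(relu_via_max)

# Away from zero the derivative is unambiguous.
print(grad_relu(2.0))   # 1.0
print(grad_relu(-2.0))  # 0.0
# At zero both arguments of `maximum` tie, and JAX splits the gradient
# evenly between them: the mean of the one-sided derivatives 0 and 1.
print(grad_relu(0.0))   # 0.5
```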
Breaking changes:
- Bump requirements to jax==0.3 and frozendict==2.3.