Skip to content

Commit

Permalink
Use JAX type annotation for random keys. Fix pypi links to tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 430463106
  • Loading branch information
romanngg committed Feb 23, 2022
1 parent 6d49d59 commit 8863a8b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 14 deletions.
8 changes: 4 additions & 4 deletions neural_tangents/_src/monte_carlo.py
Expand Up @@ -37,7 +37,7 @@
from jax.tree_util import tree_map
from jax.tree_util import tree_multimap
from .utils import utils
from .utils.typing import ApplyFn, Axes, EmpiricalGetKernelFn, Get, InitFn, MonteCarloKernelFn, NTTree, PRNGKey, PyTree, VMapAxes
from .utils.typing import ApplyFn, Axes, EmpiricalGetKernelFn, Get, InitFn, MonteCarloKernelFn, NTTree, PyTree, VMapAxes


def _sample_once_kernel_fn(kernel_fn: EmpiricalGetKernelFn,
Expand All @@ -52,7 +52,7 @@ def _sample_once_kernel_fn(kernel_fn: EmpiricalGetKernelFn,
def kernel_fn_sample_once(
x1: NTTree[np.ndarray],
x2: Optional[NTTree[np.ndarray]],
key: PRNGKey,
key: random.KeyArray,
get: Get,
**apply_fn_kwargs):
init_key, dropout_key = random.split(key, 2)
Expand All @@ -64,7 +64,7 @@ def kernel_fn_sample_once(

def _sample_many_kernel_fn(
kernel_fn_sample_once,
key: PRNGKey,
key: random.KeyArray,
n_samples: Set[int],
get_generator: bool):
def normalize(sample: PyTree, n: int) -> PyTree:
Expand Down Expand Up @@ -115,7 +115,7 @@ def get_sampled_kernel(
def monte_carlo_kernel_fn(
init_fn: InitFn,
apply_fn: ApplyFn,
key: PRNGKey,
key: random.KeyArray,
n_samples: Union[int, Iterable[int]],
batch_size: int = 0,
device_count: int = -1,
Expand Down
11 changes: 2 additions & 9 deletions neural_tangents/_src/utils/typing.py
Expand Up @@ -17,6 +17,7 @@
from typing import Any, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union

import jax.numpy as np
from jax import random
from .kernel import Kernel
from typing_extensions import Protocol

Expand All @@ -29,14 +30,6 @@
PyTree = Any


"""A type alias for PRNGKeys.
See https://jax.readthedocs.io/en/latest/jax.random.html#jax.random.PRNGKey
for details.
"""
PRNGKey = np.ndarray


"""A type alias for axes specification.
Axes can be specified as integers (`axis=-1`) or sequences (`axis=(1, 3)`).
Expand Down Expand Up @@ -81,7 +74,7 @@ class InitFn(Protocol):

def __call__(
self,
rng: PRNGKey,
rng: random.KeyArray,
input_shape: Shapes,
**kwargs
) -> Tuple[Shapes, PyTree]:
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Expand Up @@ -80,7 +80,10 @@ def _get_version() -> str:
'Bug Tracker': 'https://github.com/google/neural-tangents/issues',
'Release Notes': 'https://github.com/google/neural-tangents/releases',
'PyPi': 'https://pypi.org/project/neural-tangents/',
'Tests': 'https://travis-ci.org/github/google/neural-tangents',
'Linux Tests': 'https://github.com/google/neural-tangents/actions/workflows/linux.yml',
'macOS Tests': 'https://github.com/google/neural-tangents/actions/workflows/macos.yml',
'Pytype': 'https://github.com/google/neural-tangents/actions/workflows/pytype.yml',
'Coverage': 'https://app.codecov.io/gh/google/neural-tangents'
},
packages=setuptools.find_packages(exclude=('presentation',)),
long_description=long_description,
Expand Down

0 comments on commit 8863a8b

Please sign in to comment.