Skip to content

Commit

Permalink
random.keyArray has been removed from JAX
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcBerneman committed Feb 16, 2024
1 parent a4851b8 commit 5d82965
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions jax_md/rigid_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
PyTree = Any
f64 = util.f64
f32 = util.f32
KeyArray = random.KeyArray
NeighborListFns = partition.NeighborListFns
ShiftFn = space.ShiftFn

Expand Down Expand Up @@ -152,7 +151,7 @@ def _quaternion_rotate_bwd(res, g: Array) -> Tuple[Array, Array]:
_quaternion_rotate.defvjp(_quaternion_rotate_fwd, _quaternion_rotate_bwd)


def _random_quaternion(key: KeyArray, dtype: DType) -> Array:
def _random_quaternion(key: Array, dtype: DType) -> Array:
"""Generate a random quaternion of a given dtype."""
rnd = random.uniform(key, (3,), minval=0.0, maxval=1.0, dtype=dtype)

Expand Down Expand Up @@ -214,7 +213,7 @@ def quaternion_rotate(q: Quaternion, v: Array) -> Array:
return jnp.vectorize(_quaternion_rotate, signature='(q),(d)->(d)')(q.vec, v)


def random_quaternion(key: KeyArray, dtype: DType) -> Quaternion:
def random_quaternion(key: Array, dtype: DType) -> Quaternion:
"""Generate a random quaternion of a given dtype."""
rand_quat = partial(_random_quaternion, dtype=dtype)
rand_quat = jnp.vectorize(rand_quat, signature='(k)->(q)')
Expand Down

0 comments on commit 5d82965

Please sign in to comment.