The failure stems from a strange dependency on another test: it only fails when run after test_nve_2d_neighbor_list. When run alone, no such problem occurs. The problem also disappears if step_fn is not JITted (both in the test and in simulate.py).
My guess is that this bug lies not in this project but in JAX itself.
This error has resurfaced, as seen in #299 and #207. It is very strange because it only occurs when two independent functions are run in succession. In the example below, for instance, the error occurs in the second call to nve_2d_neighbor_list.
Minimal working example:
```python
import jax
from jax import random
from jax.config import config as jax_config
import jax.numpy as jnp
from jax_md import quantity
from jax_md import simulate
from jax_md import space
from jax_md import energy
from jax_md import util
from jax_md import rigid_body

jax_config.parse_flags_with_absl()

f32 = util.f32

PARTICLE_COUNT = 40


def nve_2d_neighbor_list(dtype):
    N = PARTICLE_COUNT
    box_size = quantity.box_size_at_number_density(N, 0.1, 2)
    displacement, shift = space.periodic(box_size)

    # Random 2D positions and orientations for N rigid square bodies.
    key = random.PRNGKey(0)
    key, pos_key, angle_key = random.split(key, 3)
    R = box_size * random.uniform(pos_key, (N, 2), dtype=dtype)
    angle = random.uniform(angle_key, (N,), dtype=dtype) * jnp.pi * 2
    body = rigid_body.RigidBody(R, angle)
    shape = rigid_body.square

    # Soft-sphere point energy, lifted to rigid bodies with a neighbor list.
    neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(displacement,
                                                              box_size)
    neighbor_fn, energy_fn = rigid_body.point_energy_neighbor_list(energy_fn,
                                                                   neighbor_fn,
                                                                   shape)

    init_fn, step_fn = simulate.nve(energy_fn, shift)
    nbrs = neighbor_fn.allocate(body)
    init_fn(key, body, 1e-3, mass=shape.mass(), neighbor=nbrs)


# The error occurs in the second of these two identical calls.
nve_2d_neighbor_list(f32)
nve_2d_neighbor_list(f32)
```
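A minimal sketch of one way to check the JIT dependency described in the report, assuming a recent JAX: `jax.disable_jit()` turns off compilation globally, including inside library code such as `simulate.py`, so the same back-to-back calls can be compared with and without JIT. The `step` function and `run_twice` helper are hypothetical stand-ins, not part of jax_md; to test the actual bug, one would wrap the two `nve_2d_neighbor_list(f32)` calls in the same `with jax.disable_jit():` block.

```python
import jax
import jax.numpy as jnp


def run_twice(fn, *args):
    # Hypothetical helper: call fn twice in succession, mirroring the two
    # back-to-back nve_2d_neighbor_list calls in the reproducer.
    return fn(*args), fn(*args)


@jax.jit
def step(x):
    # Hypothetical stand-in for a jitted step function (not jax_md's step_fn).
    return jnp.sin(x) * 2.0 + 1.0


x = jnp.linspace(0.0, 1.0, 5)

# Normal run: the first call compiles step, the second reuses the cached
# compiled version.
jit_results = run_twice(step, x)

# Same calls with JIT disabled globally: jitted functions run op by op.
with jax.disable_jit():
    eager_results = run_twice(step, x)

print(all(bool(jnp.allclose(a, b)) for a, b in zip(jit_results, eager_results)))
```

If the failure disappears under `disable_jit()` but reappears with compilation enabled, that is consistent with the observation that the problem depends on step_fn being JITted.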