test_nve_2d_neighbor_list_multi_atom_species in rigid_body_test.py is broken #277

Open
MarcBerneman opened this issue Jul 23, 2023 · 1 comment

@MarcBerneman
Contributor

It is broken because of a strange dependency on another test: it only fails if it runs after test_nve_2d_neighbor_list. When it is run alone, the problem does not occur. The problem also disappears if step_fn is not JIT-compiled (both in the test and in simulate.py).
My guess is that the bug lies not in this project but in JAX itself.
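One quick way to probe the JIT hypothesis is to run the same function once compiled and once under `jax.disable_jit()`, then compare the results. The sketch below is a hypothetical diagnostic, not a fix: `compare_jit_and_eager` and `toy_step` are illustrative names, and `toy_step` is only a stand-in for the real `step_fn` from simulate.py.

```python
import jax
import jax.numpy as jnp


def compare_jit_and_eager(fn, *args, atol=1e-6):
    """Run `fn(*args)` normally and again with JIT disabled.

    Returns True if both runs agree within tolerance. A mismatch, or an
    error raised in only one mode, suggests compilation is involved.
    """
    jitted = fn(*args)
    # Inside this context, even @jax.jit-decorated functions run eagerly.
    with jax.disable_jit():
        eager = fn(*args)
    return bool(jnp.allclose(jitted, eager, atol=atol))


@jax.jit
def toy_step(x):
    # Hypothetical stand-in for a simulation step function.
    return x + 0.1 * jnp.sin(x)


print(compare_jit_and_eager(toy_step, jnp.linspace(0.0, 1.0, 5)))
```

Swapping `toy_step` for the real step function (with its state and neighbor list as arguments) would show whether the failure is tied to the compiled path.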

@MarcBerneman
Contributor Author

MarcBerneman commented Feb 16, 2024

This error has surfaced again, as seen in #299 and #207. It is very strange because it only occurs when two independent functions are run one after the other. In the example below, the error occurs in the second call to nve_2d_neighbor_list.

Minimal working example:

```python
import jax
from jax import random
from jax import config as jax_config
import jax.numpy as jnp
from jax_md import quantity
from jax_md import simulate
from jax_md import space
from jax_md import energy
from jax_md import util
from jax_md import rigid_body

# Parse JAX flags from the command line via absl.
jax_config.parse_flags_with_absl()

f32 = util.f32

PARTICLE_COUNT = 40


def nve_2d_neighbor_list(dtype):
    N = PARTICLE_COUNT
    # Box side length for N particles at number density 0.1 in 2D.
    box_size = quantity.box_size_at_number_density(N, 0.1, 2)

    displacement, shift = space.periodic(box_size)

    key = random.PRNGKey(0)

    key, pos_key, angle_key = random.split(key, 3)

    # Random rigid-body positions and orientations.
    R = box_size * random.uniform(pos_key, (N, 2), dtype=dtype)
    angle = random.uniform(angle_key, (N,), dtype=dtype) * jnp.pi * 2

    body = rigid_body.RigidBody(R, angle)
    shape = rigid_body.square

    # Soft-sphere interactions between the points of each square, using a neighbor list.
    neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(displacement,
                                                              box_size)
    neighbor_fn, energy_fn = rigid_body.point_energy_neighbor_list(energy_fn,
                                                                   neighbor_fn,
                                                                   shape)
    init_fn, step_fn = simulate.nve(energy_fn, shift)

    nbrs = neighbor_fn.allocate(body)
    init_fn(key, body, 1e-3, mass=shape.mass(), neighbor=nbrs)


# The two calls are independent, but the error occurs in the second one.
nve_2d_neighbor_list(f32)
nve_2d_neighbor_list(f32)
```
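Since the failure only appears on the second of two independent calls, one hypothetical check is to clear JAX's caches between the calls and see whether the error goes away. This sketch assumes a JAX version that provides `jax.clear_caches()`; `run_twice` and `toy_setup` are illustrative names, with `toy_setup` standing in for `nve_2d_neighbor_list`.

```python
import jax
import jax.numpy as jnp


def run_twice(fn, *args, clear_caches_between=False):
    """Call `fn(*args)` twice, optionally clearing JAX caches in between.

    If the second call fails only when caches are kept, state left behind
    by the first call (e.g. a cached compiled function) is a likely suspect.
    """
    first = fn(*args)
    if clear_caches_between:
        # Drops JAX's tracing and compilation caches (assumes recent JAX).
        jax.clear_caches()
    second = fn(*args)
    return first, second


@jax.jit
def toy_setup(x):
    # Hypothetical stand-in for nve_2d_neighbor_list.
    return jnp.sum(x ** 2)


a, b = run_twice(toy_setup, jnp.arange(4.0), clear_caches_between=True)
```

If the real reproduction passes with `clear_caches_between=True` but fails without it, that would point toward cached state from the first call; it would not by itself locate the bug.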
