Code snippets in README.md give NaNs #254

Open
OmerRochman opened this issue Jan 19, 2023 · 0 comments
@OmerRochman
When I put the snippets together and run the code, the positions become NaN after one or two time steps. This is possibly because the energies and forces under the Lennard-Jones (LJ) potential are very large.
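One plausible mechanism, sketched under stated assumptions: JAX uses 32-bit floats unless x64 mode is enabled, and the repulsive part of the LJ force grows like r⁻¹³, so at very small separations it overflows float32 to `inf`. A particle pushed by two such neighbours from opposite sides then receives `inf + (-inf)`, which is NaN. The sketch below uses plain NumPy in float32 with the bare 12-6 formula and σ = ε = 1 (assumed to match jax_md's defaults, with no cutoff); `lj_force_magnitude` is a hypothetical helper, not jax_md's implementation.

```python
import numpy as np

def lj_force_magnitude(r, sigma=1.0, epsilon=1.0):
    """Radial LJ force -dU/dr for U(r) = 4*eps*((sigma/r)**12 - (sigma/r)**6), in float32."""
    r = np.asarray(r, dtype=np.float32)
    sr6 = (sigma / r) ** 6
    # -dU/dr = (24 * eps / r) * (2 * (sigma/r)**12 - (sigma/r)**6)
    return (24.0 * epsilon / r) * (2.0 * sr6 * sr6 - sr6)

r = np.array([1.0, 0.1, 1e-2, 1e-3, 1e-4], dtype=np.float32)
with np.errstate(over='ignore', invalid='ignore'):
    f = lj_force_magnitude(r)  # the last two entries overflow float32 to inf
    # Two equally close neighbours on opposite sides: +inf and -inf give NaN.
    net = f[-1] + (-f[-1])

print(f)
print(net)
```

In double precision the same forces would stay finite for a while longer, but they would still be large enough that a step with `dt = 1e-3` throws particles far across the box.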

It's not a bug per se, but this code is likely the first thing newcomers will run, so it would be nice if it worked. Switching from the LJ potential to a soft-sphere potential, for example, makes it work.
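A likely reason the soft-sphere swap works is that its force stays bounded however much particles overlap. The sketch below assumes the harmonic soft-sphere form U(r) = ε/α · (1 − r/σ)^α for r < σ and 0 otherwise (α = 2, σ = ε = 1, which is how I understand jax_md's `energy.soft_sphere` to be defined). With α = 2 the force magnitude is at most ε/σ, even at r → 0. `soft_sphere_force_magnitude` is a hypothetical helper written in plain NumPy.

```python
import numpy as np

def soft_sphere_force_magnitude(r, sigma=1.0, epsilon=1.0, alpha=2.0):
    """|dU/dr| for U(r) = eps/alpha * (1 - r/sigma)**alpha if r < sigma, else 0."""
    r = np.asarray(r, dtype=np.float32)
    overlap = np.clip(1.0 - r / sigma, 0.0, None)  # zero once particles stop overlapping
    return (epsilon / sigma) * overlap ** (alpha - 1.0)

r = np.array([1e-4, 1e-2, 0.5, 1.0, 1.5], dtype=np.float32)
f = soft_sphere_force_magnitude(r)
print(f)  # finite everywhere, never larger than epsilon / sigma
```

In the snippet, the swap described here would be `energy_fn = energy.soft_sphere_pair(displacement_fn)` in place of the `lennard_jones_pair` line.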

Cheers.

```python
import jax.numpy as np

from jax import random
from jax_md import energy, quantity, simulate, space

# Periodic 2D box of side 25.
box_size = 25.0
displacement_fn, shift_fn = space.periodic(box_size)
N = 1000

spatial_dimension = 2
key = random.PRNGKey(0)
# All N particles start inside a 0.1 x 0.1 corner of the box, much closer
# together than the LJ length scale (sigma = 1 by default).
R = random.uniform(key, (N, spatial_dimension), minval=0.0, maxval=0.1)

# Lennard-Jones pair energy and the corresponding forces.
energy_fn = energy.lennard_jones_pair(displacement_fn)
print('E = {}'.format(energy_fn(R)))
force_fn = quantity.force(energy_fn)
print('Total Squared Force = {}'.format(np.sum(force_fn(R) ** 2)))

# NVT (Nose-Hoover) integration; per this report, the positions
# become NaN after one or two steps.
temperature = 1.0
dt = 1e-3
init, update = simulate.nvt_nose_hoover(energy_fn, shift_fn, dt, temperature)
state = init(key, R)
for _ in range(100):
  state = update(state)
R = state.position
```
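To check how crowded the starting configuration is, the sketch below measures the smallest periodic (minimum-image) pair distance for 1000 points drawn uniformly from a 0.1 × 0.1 square in a box of side 25, matching the snippet's parameters. It uses NumPy's random generator with an arbitrary seed, so the draws differ from `jax.random`, and `min_pair_distance` is a hypothetical helper, not a jax_md function. A result far below σ = 1 would mean the LJ repulsion is enormous from the very first step.

```python
import numpy as np

def min_pair_distance(positions, box_size):
    """Smallest pairwise distance in a periodic box (minimum-image convention)."""
    dr = positions[:, None, :] - positions[None, :, :]
    dr = dr - box_size * np.round(dr / box_size)  # wrap to the nearest periodic image
    dist = np.sqrt(np.sum(dr ** 2, axis=-1))
    np.fill_diagonal(dist, np.inf)  # exclude each particle's distance to itself
    return dist.min()

rng = np.random.default_rng(0)  # NumPy RNG: not the same draws as jax.random
N, spatial_dimension, box_size = 1000, 2, 25.0
R = rng.uniform(0.0, 0.1, size=(N, spatial_dimension))  # same range as the snippet

d_min = min_pair_distance(R, box_size)
print('smallest pair distance = {}'.format(d_min))
```

The all-pairs array is O(N²) in memory, which is fine for a one-off check at N = 1000; jax_md's neighbor lists are the scalable route for larger systems.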
