Firstly, jax-md is awesome, thanks for building it!

I'm trying to take the derivative of the forces with respect to the Lennard-Jones parameters, and I'm getting NaNs. I found a fix (see below), but it doesn't seem ideal.
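For example:

```python
from jax import grad
import jax.numpy as jnp
from jax_md import energy, space, quantity

def total_force(R, sigma):
    # Free (non-periodic) space
    displacement_fn, shift = space.free()
    # Lennard-Jones pair energy using the sigma we differentiate against
    energy_fn = energy.lennard_jones_pair(displacement_fn, sigma=sigma)
    # Forces are -dE/dR
    force_fn = quantity.force(energy_fn)
    F = force_fn(R)
    # Reduce the forces to a scalar so that grad can be applied
    return jnp.sum(F ** 2.0)

def main():
    R = jnp.array([[1.0, 2.0, 3.0], [2.0, 3.0, 5.0]])
    print(total_force(R, jnp.float32(2.0)))
    # Derivative of the summed squared forces with respect to sigma
    total_force_grad = grad(total_force, argnums=1)
    print(total_force_grad(R, 2.0))

main()
```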
Prints:

```
2.7977674
nan
```
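A likely explanation, offered as a hypothesis rather than something confirmed in this report: without a neighbor list, a pair potential is evaluated for every (i, j) pair including i == j, where dr = 0, and the self-pair terms are masked out afterwards. Masking keeps the forward value finite, but in reverse-mode autodiff the masked-out branch still receives a cotangent of 0, and 0 × inf = NaN, so the NaN leaks into the gradient. Having sigma in the numerator does not help: d(sigma/dr)/d(sigma) = 1/dr, which is itself infinite at dr = 0. The toy below uses plain JAX (no jax_md); the functions `naive` and `safe` are hypothetical. It reproduces the effect on a 1-D array of distances and shows the usual "double `where`" guard, which replaces bad inputs before the division instead of masking the result after it.

```python
import jax
import jax.numpy as jnp

def naive(sigma, dr):
    # sigma / 0 -> inf, and inf - inf -> nan in the masked-out entry
    idr = sigma / dr
    e = 4.0 * (idr**12 - idr**6)
    # The forward value is fine: where() discards the nan entry...
    return jnp.sum(jnp.where(dr > 0, e, 0.0))

def safe(sigma, dr):
    mask = dr > 0
    # ...but for gradients the bad input must be replaced *before* dividing
    safe_dr = jnp.where(mask, dr, 1.0)
    idr = sigma / safe_dr
    e = 4.0 * (idr**12 - idr**6)
    return jnp.sum(jnp.where(mask, e, 0.0))

dr = jnp.array([0.0, 1.5, 2.0])  # first entry plays the role of a self pair
print(jax.grad(naive)(2.0, dr))  # nan
print(jax.grad(safe)(2.0, dr))   # finite
```

The value substituted for the masked entries (1.0 here) only needs to be finite and nonzero; it never reaches the result, it just keeps every intermediate finite so the zero cotangent stays zero.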
It can be fixed by replacing

```python
idr = (sigma / dr)
```

with

```python
eps = f32(1e-32)
idr = (sigma / (dr + eps))
```

in `jax_md/energy.py`.
But I don't know whether there is a better fix or whether I'm doing something wrong. I also don't understand why I'm getting NaNs: the derivative is with respect to sigma, which is in the numerator, yet the forces themselves come out fine.
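On the "better fix" question, here is a sketch of the same workflow with a hand-written pairwise Lennard-Jones energy in plain JAX, so the guard can be applied without patching `jax_md/energy.py`. `lj_energy` is a hypothetical stand-in, not jax_md's implementation: it sums over all N×N pairs, has no cutoff, and masks the diagonal with the double-`where` pattern so dr is never 0 inside the square root or the division. `total_force` mirrors the function in the report.

```python
import jax
import jax.numpy as jnp

def lj_energy(R, sigma, epsilon=1.0):
    # Hypothetical all-pairs LJ energy (no cutoff), not jax_md's code
    dR = R[:, None, :] - R[None, :, :]      # (N, N, 3) pair displacements
    off_diag = ~jnp.eye(R.shape[0], dtype=bool)
    dr2 = jnp.sum(dR**2, axis=-1)
    # Replace the zero self-distances before sqrt and division
    dr2 = jnp.where(off_diag, dr2, 1.0)
    dr = jnp.sqrt(dr2)
    idr6 = (sigma / dr) ** 6
    e = 4.0 * epsilon * (idr6**2 - idr6)
    # Each pair appears twice in the N x N sum, hence the factor 0.5
    return 0.5 * jnp.sum(jnp.where(off_diag, e, 0.0))

def total_force(R, sigma):
    F = -jax.grad(lj_energy)(R, sigma)      # forces = -dE/dR
    return jnp.sum(F**2)

R = jnp.array([[1.0, 2.0, 3.0], [2.0, 3.0, 5.0]])
print(total_force(R, 2.0))
print(jax.grad(total_force, argnums=1)(R, 2.0))  # finite, no NaN
```

For this configuration r = √6 and (σ/r)² = 2/3, so the plain LJ value is 2·(dE/dr)² = 495616/177147 ≈ 2.7978, in line with the 2.7977674 printed in the report.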
Thanks very much
Tim