NaNs for Lennard Jones potential gradients. #258

Open
timduignan opened this issue Mar 4, 2023 · 1 comment

Comments

@timduignan

Firstly, jax-md is awesome, thanks for building it!

I'm trying to take the derivative of the forces with respect to the Lennard-Jones parameters and I'm getting NaNs. I found a fix (see below), but it doesn't seem ideal.

For example:

from jax import grad
import jax.numpy as jnp
from jax_md import energy, space, quantity

def total_force(R, sigma):
    # Free (non-periodic) space; the shift function is unused here.
    displacement_fn, shift = space.free()
    energy_fn = energy.lennard_jones_pair(displacement_fn, sigma=sigma)
    force_fn = quantity.force(energy_fn)
    F = force_fn(R)
    # Reduce the forces to a scalar so that grad can be applied.
    return jnp.sum(F**2.0)

def main():
    R = jnp.array([[1.0, 2.0, 3.0], [2.0, 3.0, 5.0]])
    print(total_force(R, jnp.float32(2.0)))
    # Reverse-mode derivative with respect to sigma (argument 1).
    total_force_grad = grad(total_force, argnums=1)
    print(total_force_grad(R, 2.0))

main()

Prints:

2.7977674
nan

It can be fixed by replacing:

idr = (sigma / dr)

with

eps = f32(1e-32)
idr = (sigma / (dr + eps))

in jax_md/energy.py

But I don't know if there is a better fix or if I'm doing something wrong. I also don't understand why I'm getting NaNs: the derivative is with respect to sigma, which is in the numerator, whereas it can give the derivatives of the forces OK.

Thanks very much
Tim
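
A possible alternative to the eps shift is the "double where" pattern described in the JAX FAQ. The sketch below is a toy, not jax_md's code: `lj`, `masked_naive` and `masked_safe` are hypothetical names, and it assumes the NaN comes from a zero distance (for example a particle paired with itself) that is masked out only after the energy has been computed. I have not confirmed that this is the exact code path in jax_md/energy.py.

```python
import jax
import jax.numpy as jnp

def lj(sigma, dr):
    # Plain Lennard-Jones form with epsilon = 1 (toy version, not jax_md's).
    idr6 = (sigma / dr) ** 6
    return 4.0 * (idr6 ** 2 - idr6)

def masked_naive(sigma, dr):
    # Mask applied only to the output: the value is fine, but the backward
    # pass still differentiates through sigma / 0.
    return jnp.sum(jnp.where(dr > 0, lj(sigma, dr), 0.0))

def masked_safe(sigma, dr):
    # "Double where": replace the bad input first, then mask the output.
    mask = dr > 0
    safe_dr = jnp.where(mask, dr, 1.0)
    return jnp.sum(jnp.where(mask, lj(sigma, safe_dr), 0.0))

dr = jnp.array([0.0, 1.5, 2.0])  # first entry mimics a zero self-distance
sigma = 2.0

g_naive = jax.grad(masked_naive)(sigma, dr)
g_safe = jax.grad(masked_safe)(sigma, dr)
print(g_naive)  # nan
print(g_safe)   # finite
```

Both functions return the same value; only the gradient differs. Unlike adding eps to the denominator, this leaves the distances that are actually used unchanged.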

@abhijeetgangan
Contributor

Hi,

I had the same issue. It looks like the problem is with reverse-mode differentiation; forward mode works fine.

from jax import grad, jacfwd
import jax.numpy as jnp
from jax_md import energy, space, quantity

def total_force(R, sigma):
    # Free (non-periodic) space; the shift function is unused here.
    displacement_fn, shift = space.free()
    energy_fn = energy.lennard_jones_pair(displacement_fn, sigma=sigma)
    force_fn = quantity.force(energy_fn)
    F = force_fn(R)
    # Reduce the forces to a scalar so that grad can be applied.
    return jnp.sum(F**2.0)

def main():
    R = jnp.array([[1.0, 2.0, 3.0], [2.0, 3.0, 5.0]])
    print(total_force(R, jnp.float32(2.0)))
    # Same derivative with respect to sigma, computed in both modes.
    total_force_rev = grad(total_force, argnums=1)
    total_force_fwd = jacfwd(total_force, argnums=1)
    print(total_force_rev(R, 2.0))
    print(total_force_fwd(R, 2.0))

main()

The above code prints:

2.7977674
nan
-7.6302614
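
The forward/reverse difference can be reproduced on a one-line toy function, which may help narrow down the cause. This is only a sketch under the assumption that a masked-out zero distance is involved; `f` is a hypothetical stand-in, not the jax_md energy. In forward mode the NaN tangent is discarded by the `where`, while in reverse mode a zero cotangent gets multiplied by an infinite local derivative, and 0 * inf is NaN.

```python
import jax
import jax.numpy as jnp

def f(sigma, dr):
    x = sigma / dr        # inf when dr == 0
    e = x ** 12 - x ** 6  # inf - inf = nan when dr == 0
    # The mask hides the nan in the value, but not in the backward pass.
    return jnp.where(dr > 0, e, 0.0)

dr = jnp.asarray(0.0)

fwd = jax.jacfwd(f)(2.0, dr)  # tangent is nan before the mask, then dropped
rev = jax.grad(f)(2.0, dr)    # cotangent 0 times derivative inf gives nan
print(fwd)  # 0.0
print(rev)  # nan
```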
