diff --git a/jax_md/_energy/electrostatics.py b/jax_md/_energy/electrostatics.py index 619ba040..2049f4c4 100644 --- a/jax_md/_energy/electrostatics.py +++ b/jax_md/_energy/electrostatics.py @@ -129,7 +129,7 @@ def energy_fn(position, **kwargs): S2 = jnp.abs(structure_factor(g, position, charge))**2 fn = lambda g2: jnp.exp(-g2 / (4*alpha**2)) / g2 * S2 - return Z * util.high_precision_sum(safe_mask(mask, fn, g2, 1)) + return Z * util.high_precision_sum(util.safe_mask(mask, fn, g2, 1)) return energy_fn