Question: Particularities of Autodifferentiation for Forces #273

Open
ESEberhard opened this issue Jul 6, 2023 · 0 comments
ESEberhard commented Jul 6, 2023

I am wondering why the extra `jnp.where` is needed in the radial interaction term of the Stillinger-Weber potential. I am trying to understand why, without it, there is a discontinuity in the force.

I would really appreciate it if somebody could elaborate on this, if they can find the time.

```python
import jax.numpy as jnp

def _sw_radial_interaction(r,
                           sigma=2.0951,
                           B=0.6022245584,
                           p=4,
                           cutoff=1.8*2.0951):
    a = cutoff / sigma
    term1 = (B*(r/sigma)**(-p) - 1.0)
    within_cutoff = (r > 0) & (r < cutoff)
    r = jnp.where(within_cutoff, r, 0)   # Why is this line needed?
    term2 = jnp.exp(1/(r/sigma-a))
    return jnp.where(within_cutoff, term1 * term2, 0.0)
```
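The effect can be reproduced in isolation with a small sketch, assuming a recent JAX version at the default float32 precision. `radial_single_where` and `radial_double_where` are hypothetical re-implementations of the radial term above, without and with the inner `jnp.where`. For `r` slightly beyond `cutoff`, `1/(r/sigma - a)` is large and positive, so `exp(...)` overflows to `inf`. The outer `jnp.where` still returns 0.0 in the forward pass, but in the backward pass the zero cotangent of the unselected branch gets multiplied by that `inf`, and `0 * inf` is `nan`. With the inner `jnp.where`, masked entries are evaluated at `r = 0`, where every intermediate is finite, so the gradient is 0. If this is the mechanism behind the observed discontinuity, it should only appear in a narrow band just past the cutoff (in float32, roughly where `1/(r/sigma - a)` exceeds about 88); farther out `exp(...)` is large but finite, and `0 * finite` is 0.

```python
import jax
import jax.numpy as jnp

# Parameter values copied from the snippet above.
SIGMA = 2.0951
B = 0.6022245584
P = 4
CUTOFF = 1.8 * SIGMA


def radial_single_where(r):
    """Hypothetical variant WITHOUT the inner jnp.where (outer mask only)."""
    a = CUTOFF / SIGMA
    within_cutoff = (r > 0) & (r < CUTOFF)
    term1 = B * (r / SIGMA) ** (-P) - 1.0
    term2 = jnp.exp(1.0 / (r / SIGMA - a))  # overflows to inf just past the cutoff
    return jnp.where(within_cutoff, term1 * term2, 0.0)


def radial_double_where(r):
    """Hypothetical variant WITH the inner jnp.where, as in the snippet above."""
    a = CUTOFF / SIGMA
    within_cutoff = (r > 0) & (r < CUTOFF)
    term1 = B * (r / SIGMA) ** (-P) - 1.0
    r_safe = jnp.where(within_cutoff, r, 0.0)  # masked entries never reach the singular exp
    term2 = jnp.exp(1.0 / (r_safe / SIGMA - a))
    return jnp.where(within_cutoff, term1 * term2, 0.0)


r_out = CUTOFF + 1e-3  # just outside the cutoff: 1/(r/sigma - a) is about 2e3
r_in = 3.0             # well inside the cutoff

# Both versions return 0.0 in the forward pass outside the cutoff ...
e_single = float(radial_single_where(r_out))
e_double = float(radial_double_where(r_out))

# ... but only the double-where version has a finite derivative there.
g_single = float(jax.grad(radial_single_where)(r_out))  # nan (0 * inf in the backward pass)
g_double = float(jax.grad(radial_double_where)(r_out))  # zero

# Inside the cutoff both versions compute the same derivative.
gi_single = float(jax.grad(radial_single_where)(r_in))
gi_double = float(jax.grad(radial_double_where)(r_in))
print(e_single, e_double, g_single, g_double, gi_single, gi_double)
```

This is the commonly described "double where" pattern: the outer `jnp.where` selects the value, while the inner one keeps the unselected branch numerically harmless so its derivative cannot poison the gradient.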

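In my case I'm using the same pattern for the following bond potential. Here, too, the extra `jnp.where` that crops `dr` seems redundant, since the final `jnp.where` already masks the result, yet without it the force is discontinuous.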

```python
import jax.numpy as jnp
from jax import Array


def membrane_bond_gompper(dr: Array,
                          k: Array,
                          l_c0: Array,
                          l_c1: Array,
                          l_max: Array,
                          l_min: Array,
                          **unused_kwargs) -> Array:
  """.. _membrane-bond-gompper:

  Bond model as used by the Gompper group in
  https://doi.org/10.1038/s41586-020-2730-x

  Args:
    dr: An ndarray of shape `[n, m]` of pairwise distances between particles.
    k: Bond stiffness.
    l_c0: Lower cutoff of the attraction term.
    l_c1: Upper cutoff of the repulsion term.
    l_max: Maximum bond length (must satisfy l_max > l_c0).
    l_min: Minimum bond length (must satisfy l_min < l_c1).
  """
  # Python asserts need concrete (non-traced) parameter values.
  assert l_max > l_c0
  assert l_min < l_c1

  # Repulsive term of membrane bonds, active for dr in (l_min, l_c1).
  within_rep_cutoff = (dr > l_min) & (dr < l_c1)
  dr_rep = jnp.where(within_rep_cutoff, dr, 0)  # This is needed to fix the discontinuity in the force
  rep_pot = jnp.exp(1.0 / (dr_rep - l_c1)) / (dr_rep - l_min)
  rep_term = jnp.where(within_rep_cutoff, rep_pot, 0.0)

  # Attractive term of membrane bonds, active for dr in (l_c0, l_max).
  within_attr_cutoff = (dr > l_c0) & (dr < l_max)
  dr_attr = jnp.where(within_attr_cutoff, dr, 0)  # This is needed to fix the discontinuity in the force
  attr_pot = jnp.exp(1.0 / (l_c0 - dr_attr)) / (l_max - dr_attr)
  attr_term = jnp.where(within_attr_cutoff, attr_pot, 0.0)

  return k * (attr_term + rep_term)
```
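Assuming the bond force is obtained as F(dr) = -dE/d(dr) via `jax.grad`, one quick sanity check of the double-where version is to evaluate the force on a dense grid of bond lengths and confirm that every value is finite. The sketch below uses a condensed copy of the function above; the parameter values are hypothetical, chosen only to satisfy the two asserts, and are not taken from the referenced paper. `bond_energy` and `bond_force` are illustrative helpers, not a JAX MD API.

```python
import jax
import jax.numpy as jnp
from jax import Array


def membrane_bond_gompper(dr: Array, k, l_c0, l_c1, l_max, l_min) -> Array:
    """Condensed copy of the bond potential above (double-where version)."""
    # Repulsive window (l_min, l_c1): crop dr before the singular expression.
    within_rep = (dr > l_min) & (dr < l_c1)
    dr_rep = jnp.where(within_rep, dr, 0.0)
    rep_pot = jnp.exp(1.0 / (dr_rep - l_c1)) / (dr_rep - l_min)
    rep_term = jnp.where(within_rep, rep_pot, 0.0)
    # Attractive window (l_c0, l_max): same pattern.
    within_attr = (dr > l_c0) & (dr < l_max)
    dr_attr = jnp.where(within_attr, dr, 0.0)
    attr_pot = jnp.exp(1.0 / (l_c0 - dr_attr)) / (l_max - dr_attr)
    attr_term = jnp.where(within_attr, attr_pot, 0.0)
    return k * (attr_term + rep_term)


# Hypothetical parameters: they only satisfy l_min < l_c1 and l_c0 < l_max
# and are NOT taken from the Gompper paper.
PARAMS = dict(k=1.0, l_c0=1.2, l_c1=0.8, l_max=1.4, l_min=0.6)


def bond_energy(dr):
    return membrane_bond_gompper(dr, **PARAMS)


# Hypothetical helper: scalar bond force F(dr) = -dE/d(dr), vectorised over dr.
bond_force = jax.vmap(lambda dr: -jax.grad(bond_energy)(dr))

drs = jnp.linspace(0.5, 1.5, 1001)
forces = bond_force(drs)

all_finite = bool(jnp.all(jnp.isfinite(forces)))
f_rep = float(bond_force(jnp.array([0.65]))[0])   # just above l_min
f_gap = float(bond_force(jnp.array([1.0]))[0])    # between the two windows
f_attr = float(bond_force(jnp.array([1.35]))[0])  # just below l_max
print(all_finite, f_rep, f_gap, f_attr)
```

With these values the force should be positive (repulsive) just above `l_min`, negative (attractive) just below `l_max`, and exactly zero between the two windows.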