Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add triangulated membrane potentials #292

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
182 changes: 180 additions & 2 deletions jax_md/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,7 @@ def _ters_bij(R, D, c, d, h, lam3, beta, n, m,
# compute g_ijk - angle penalty value
costheta = quantity.cosine_angles(dRij)
gijk = 1.0 + (c**2 / d**2) - (c**2 / (d**2 + (h - costheta)**2))

# compute exponential term - distance penalty value
dr_diff = drij[:, None, :] - drik[:, :, None]
dr_diff = jnp.where(mask_ijk, dr_diff, 0)
Expand Down Expand Up @@ -1158,7 +1158,7 @@ def compute_fn(R, **kwargs):
dR = space.map_product(d)(R, R)
dr = space.distance(dR)
N = R.shape[0]
mask = jnp.where(1 - jnp.eye(N),
mask = jnp.where(1 - jnp.eye(N),
dr < params['R'] + params['D'], 0)
mask = mask.astype(R.dtype)
mask_ijk = mask[:, None, :] * mask[:, :, None]
Expand Down Expand Up @@ -2032,3 +2032,181 @@ def energy_fn(position, neighbor, **kwargs):

return neighbor_fn, energy_fn


# TRIANGULATED SURFACE POTENTIALS / MEMBRANE POTENTIALS

def triangle_area_potential(R_mem: Array,
triangles: Array,
displacement_fn: DisplacementOrMetricFn,
A_0: Array,
k: Array) -> Array:
""".. _triangle_area_potential

Local area conservation of the triangles in the vesicle discretization as
proposed by Vutukuri HR et al. (Gompper Group)
https://doi.org/10.1038/s41586-020-2730-x

Args:
R_mem (Array, shape: (N, spatial dim)): Positions of the membrane
vertices. (R_mem is allowed to contain non-membrane particles too,
only particles with indices in 'triangles' are considered in the
calculation)
triangles (Array, shape: (2(N-1), 3)): Array of triangles storing the indices
of the vertices that comprise the triangle
displacement_fn (DisplacementOrMetricFn): _description_
A_0 (Array): desired triangle area
k (Array): local-area conservation coefficient

Returns:
energy contribution due to local area conservation
"""
areas = _calc_triangle_areas(R_mem, triangles, displacement_fn)
energy = 0.5 * k * util.high_precision_sum((areas - A_0)**2 / A_0)
return energy


def _calc_triangle_areas(R_mem: Array,
triangles: Array,
displacement_fn: DisplacementOrMetricFn) -> Array:
"""
Calculate the areas of the triangle given an a point cloud and
its triangulation

Args:
R_mem (Array, shape: (N, spatial dim)): Positions of the membrane
vertices. (R_mem is allowed to contain non-membrane particles too,
only particles with indices in 'triangles' are considered in the
calculation)
triangles (Array, shape: (2(N-1), 3)): Array of triangles storing the indices
of the vertices that comprise the triangle
displacement_fn (DisplacementOrMetricFn): _description_

Returns:
Array of shape (N) where the i-th entry is the area of the i-th triangle
of the triangles array.
"""
R0 = R_mem[triangles[:,0]]
vec_displacement_fn = vmap(displacement_fn)
dR1 = vec_displacement_fn(R_mem[triangles[:,1]], R0)
dR2 = vec_displacement_fn(R_mem[triangles[:,2]], R0)
dr1 = space.distance(dR1)
dr2 = space.distance(dR2)
cos = vmap(quantity.cosine_angle_between_two_vectors)(dR1, dR2)
sin = jnp.sqrt(1 - cos**2)
return 0.5 * dr1 * dr2 * sin


def volume_potential(R_mem: Array,
triangles: Array,
V_0: float,
k: float):
""".. _volume_potential
Global volume conservation of the vesicle as used by Vutukuri HR et al.
(Gompper Group) https://doi.org/10.1038/s41586-020-2730-x

Args:
R_mem (Array, shape: (N, spatial dim)): Positions of the membrane
vertices. (R_mem is allowed to contain non-membrane particles too,
only particles with indices in 'triangles' are considered in the
calculation)
triangles (Array, shape: (2(N-1), 3)): Array of triangles storing the indices
of the vertices that comprise the triangle
V_0 (float): desired vesicle volume
k (float): volume stiffness

Returns:
energy contribution due to global volume conservation
"""
V = _calc_volume(R_mem, triangles)
return k * (V - V_0)**2 / (2 * V_0)


def _calc_volume(R_mem: Array,
triangles: Array) -> Array:
"""
Calculates the volume enclosed by a triangulated surface
(arbitrary non-convex polyhedron).
Assumes vertices are ordered in triangle list.
Implementation Follows https://doi.org/10.1109/MCG.1984.6429334

Args:
R_mem (Array, shape: (N, spatial dim)): Positions of the membrane
vertices. (R_mem is allowed to contain non-membrane particles too,
only particles with indices in 'triangles' are considered in the
calculation)
triangles (Array, shape: (2(N-1), 3)): Array of triangles storing the indices
of the vertices that comprise the triangle
Returns:
Volume enclosed by triangulated surface
"""
a = R_mem[triangles[:,0]]
b = R_mem[triangles[:,1]]
c = R_mem[triangles[:,2]]

det = a[:, 0] * (b[:,1] * c[:,2] - c[:,1] * b[:,2]) \
- a[:,1] * (b[:,0] * c[:,2] - c[:,0] * b[:,2]) \
+ a[:,2] * (b[:,0] * c[:,1] - c[:,0] * b[:,1])

return jnp.abs(util.high_precision_sum(det)) / 6


def bending_potential(R_mem: Array,
triangles: Array,
displacement_fn: DisplacementOrMetricFn,
kappa: Array) -> Array:
""".. _bending_potential
Calculates the bending potential of an triangulated surface,
based on discretization by Gompper: https://doi.org/10.1051/jp1:1996246

Args:
R_mem (Array, shape: (N, spatial dim)): Positions of the membrane
vertices
triangles (Array, shape: (N, 3)): Array of triangles storing the indices
of the vertices that comprise the triangle
displacement_fn (DisplacementOrMetricFn): _description_
kappa (Array, scalar): Bending rigidity

Returns:
Energy of the membrane conformation due to the bending potential
"""
displacement_fn = vmap(displacement_fn)
cos =vmap(quantity.cosine_angle_between_two_vectors)
N = R_mem.shape[0]

R0 = R_mem[triangles[:,0]]
R1 = R_mem[triangles[:,1]]
R2 = R_mem[triangles[:,2]]

dR01 = displacement_fn(R1, R0) # R1 - R0
dR12 = displacement_fn(R2, R1) # R2 - R1
dR20 = displacement_fn(R0, R2) # R0 - R2

dr01 = space.distance(dR01)
dr12 = space.distance(dR12)
dr20 = space.distance(dR20)

cos01 = cos(dR20, -dR12)
cos12 = cos(dR01, -dR20)
cos20 = cos(dR12, -dR01)

cot01 = cos01 / (jnp.sqrt(1 - cos01**2) + 1e-7)
cot12 = cos12 / (jnp.sqrt(1 - cos12**2) + 1e-7)
cot20 = cos20 / (jnp.sqrt(1 - cos20**2) + 1e-7)

sigma0 = dr01**2 * cot01 + dr20**2 * cot20
sigma1 = dr01**2 * cot01 + dr12**2 * cot12
sigma2 = dr12**2 * cot12 + dr20**2 * cot20

sigma = jnp.zeros(N) # sigma per vertex
sigma = sigma.at[triangles[:,0]].add(sigma0)
sigma = sigma.at[triangles[:,1]].add(sigma1)
sigma = sigma.at[triangles[:,2]].add(sigma2)
sigma = sigma / 8

rho = jnp.zeros((N, 3)) # per Vertex
rho = rho.at[triangles[:,0]].add(cot20[:,None] * dR20 - cot01[:,None] * dR01)
rho = rho.at[triangles[:,1]].add(cot01[:,None] * dR01 - cot12[:,None] * dR12)
rho = rho.at[triangles[:,2]].add(cot12[:,None] * dR12 - cot20[:,None] * dR20)

per_particle = jnp.sum(rho * rho, axis = 1) / (sigma + 1e-7)
return (kappa / 8) * util.high_precision_sum(per_particle)