Skip to content

Commit

Permalink
Merge pull request #4 from cagrikymk/reaxff_dev
Browse files Browse the repository at this point in the history
Reaxff dev
  • Loading branch information
cagrikymk committed Jun 26, 2023
2 parents 2040465 + e4c1a50 commit c36c715
Show file tree
Hide file tree
Showing 21 changed files with 6,006 additions and 525 deletions.
6 changes: 0 additions & 6 deletions jax_md/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,6 @@ def clz_from_iterable(meta, data):
kwargs = dict(meta_args + data_args)
return data_clz(**kwargs)

def replace(self, **updates):
""""Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)

data_clz.replace = replace

jax.tree_util.register_pytree_node(data_clz,
iterate_clz,
clz_from_iterable)
Expand Down
170 changes: 169 additions & 1 deletion jax_md/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ def load_lammps_tersoff_parameters(file: TextIO) -> Array:
else:
skip = False

words[3:] = f32(words[3:])
words[3:] = f64(words[3:])
params.append({
'element1': words[0],
'element2': words[1],
Expand Down Expand Up @@ -1268,6 +1268,174 @@ def tersoff_from_lammps_parameters_neighbor_list(
fractional_coordinates=fractional_coordinates,
**neighbor_kwargs)

# (EDIP) Environment-dependent interatomic potential

def _edip_cutoff_function(r: Array, cutoff: f64, c: f64, alpha: f64) -> Array:
x_term = (r - c) / (cutoff - c)
expo_term = jnp.exp(alpha / (1 - (x_term ** (-3))))
outer = jnp.where(r > cutoff,
0,
expo_term)
inner = jnp.where(r > c, outer, 1)
return inner


def _edip_radial_interaction(A: f64, B: f64, rho: f64, sigma: f64, c: f64,
alpha: f64, beta: f64, cutoff: f64, mask,
r: Array) -> Array:
within_cutoff = (r > 0) & (r < cutoff)
repul = (B / r) ** (rho)
r = jnp.where(within_cutoff, r, 0)
Z_i = util.high_precision_sum(_edip_cutoff_function(r, cutoff, c, alpha) * mask,
axis=1, keepdims=True)
p_Z = jnp.exp(-beta * (Z_i ** 2))
term1 = repul - p_Z
term2 = jnp.exp(sigma / (r - cutoff))
return jnp.where(within_cutoff, A * term1 * term2, 0.0)


def _edip_angle_interaction(lam: f64, gamma: f64, Q_0: f64, cutoff: f64,
u1: f64, u2: f64, u3: f64, u4: f64, c: f64, eta: f64,
alpha: f64, mu: f64, mask, dR12: Array, dR13: Array) -> Array:
dr12 = space.distance(dR12)
dr13 = space.distance(dR13)
dr12 = jnp.where(dr12 < cutoff, dr12, 0)
dr13 = jnp.where(dr13 < cutoff, dr13, 0)
term1 = jnp.exp(gamma / (dr12 - cutoff) + gamma / (dr13 - cutoff))
l_ijk = quantity.cosine_angle_between_two_vectors(dR12, dR13)
Z_i = util.high_precision_sum(_edip_cutoff_function(dr13, cutoff, c, alpha) * mask,
keepdims=False)
tau_Z = u1 + u2 * ((u3 * jnp.exp(-u4 * Z_i)) - jnp.exp(-2 * u4 * Z_i))
Q_Z = Q_0 * jnp.exp(-mu * Z_i)
l_tau = (l_ijk + tau_Z) ** 2
term2 = ((1 - jnp.exp(-Q_Z * l_tau)) + (eta * Q_Z * l_tau))

within_cutoff = (dr12 > 0) & (dr13 > 0) & (jnp.linalg.norm(dR12 - dR13) > 1e-5)
return jnp.where(within_cutoff, lam * term1 * term2, 0)

def edip(displacement: DisplacementFn,
u1: f64 = -0.165799,
u2: f64 = 32.557,
u3: f64 = 0.286198,
u4: f64 = 0.66,
rho: f64 = 1.2085196,
eta: f64 = 0.2523244,
Q_0: f64 = 312.1341346,
mu: f64 = 0.6966326,
beta: f64 = 0.0070975,
alpha: f64 = 3.1083847,
A: f64 = 7.9821730,
lam: f64 = 1.4533108,
B: f64 = 1.5075463,
gamma: f64 = 1.1247945,
sigma: f64 = 0.5774108,
c: f64 = 2.5609104,
cutoff: f64 = 3.1213820) -> Callable[[Array], Array]:
"""
Computes the the Environment-dependent interatomic potential (EDIP).
The parameter values are for bulk Silicon [1,2]. The EDIP potential is a bond
order potential which depends on the local coordination number of the atom.
:param displacement: displacement function for the space.
:param u1: parameter for the three-body bond order function tau(Z) (pure number)
:param u2: parameter for the three-body bond order function tau(Z) (pure number)
:param u3: parameter for the three-body bond order function tau(Z) (pure number)
:param u4: parameter for the three-body bond order function tau(Z) (pure number)
:param rho: exponent for the repulsive part of two-body potential (pure number)
:param eta: parameter for the three-body term (pure number)
:param Q_0: parameter for the three-body bond order function Q(Z) (pure number)
:param mu: parameter for the three-body bond order function Q(Z) (pure number)
:param beta: parameter for the two-body bond order function p(Z) (pure number)
:param alpha: parameter for the cutoff function (pure number)
:param A: parameter that determines the energy scale of two-body term (eV)
:param lam: parameter that determines the energy scale of three-body term (eV)
:param B: parameter for the repulsive part of two-body potential (Angstrom)
:param gamma: parameter for the radial part of three-body term (Angstrom)
:param sigma: parameter that determines the distance scale between neighbors (Angstrom)
:param c: inner cutoff for the cutoff function f(r) (Angstrom)
:param cutoff: outer cutoff (a) for the cutoff function f(r) (Angstrom)
:return: A function that computes the potential energy.
References:
[1] - Martin Z. Bazant, Efthimios Kaxiras, and J. F. Justo.
"Environment-dependent interatomic potential for bulk silicon".
Phys. Rev. B 56, 8542 (1997).
[2] - João F. Justo, Martin Z. Bazant, Efthimios Kaxiras, V. V. Bulatov,
and Sidney Yip. "Interatomic potential for silicon defects and
disordered phases". Phys. Rev. B 58, 2539 (1998).
"""
two_body_fn = partial(_edip_radial_interaction, A, B, rho, sigma, c, alpha, beta, cutoff)
three_body_fn = partial(_edip_angle_interaction, lam, gamma, Q_0, cutoff,
u1, u2, u3, u4, c, eta,
alpha, mu)

def compute_fn(R, **kwargs):
_three_body_fn = vmap(vmap(vmap(three_body_fn, (None, 0, None)), (None, None, 0)))
d = partial(displacement, **kwargs)
dR = space.map_product(d)(R, R)
dr = space.distance(dR)
N = R.shape[0]
mask = (1 - jnp.eye(N)) * (dr < cutoff)
first_term = util.high_precision_sum(two_body_fn(mask, dr))
second_term = util.high_precision_sum(_three_body_fn(mask, dR, dR)) / 2.0
return first_term + second_term

return compute_fn


def edip_neighbor_list(displacement: DisplacementFn,
box_size: f64,
u1: f64 = -0.165799,
u2: f64 = 32.557,
u3: f64 = 0.286198,
u4: f64 = 0.66,
rho: f64 = 1.2085196,
eta: f64 = 0.2523244,
Q_0: f64 = 312.1341346,
mu: f64 = 0.6966326,
beta: f64 = 0.0070975,
alpha: f64 = 3.1083847,
A: f64 = 7.9821730,
lam: f64 = 1.4533108,
B: f64 = 1.5075463,
gamma: f64 = 1.1247945,
sigma: f64 = 0.5774108,
c: f64 = 2.5609104,
cutoff: f64 = 3.1213820,
dr_threshold: f64 = 0.0,
fractional_coordinates: bool = True,
format: NeighborListFormat = partition.Dense,
**neighbor_kwargs) -> Tuple[NeighborFn, Callable[[Array, NeighborList], Array]]:
two_body_fn = partial(_edip_radial_interaction, A, B, rho, sigma, c, alpha, beta, cutoff)
three_body_fn = partial(_edip_angle_interaction, lam, gamma, Q_0, cutoff,
u1, u2, u3, u4, c, eta,
alpha, mu)

neighbor_fn = partition.neighbor_list(displacement,
box_size,
cutoff,
dr_threshold,
format=format,
fractional_coordinates=fractional_coordinates,
**neighbor_kwargs)

def compute_fn(R, neighbor, **kwargs):
d = partial(displacement, **kwargs)
mask = partition.neighbor_list_mask(neighbor, mask_self=True)
if format is partition.Dense:
_three_body_fn = vmap(vmap(vmap(three_body_fn, (None, 0, None)), (None, None, 0)))
dR = space.map_neighbor(d)(R, R[neighbor.idx])
dr = space.distance(dR)
first_term = util.high_precision_sum(two_body_fn(mask, dr) * mask)
mask_ijk = mask[:, None, :] * mask[:, :, None]
second_term = util.high_precision_sum(_three_body_fn(mask, dR, dR) * mask_ijk) / 2.0
else:
raise NotImplementedError('EDIP potential only implemented '
'with Dense neighbor lists.')

return first_term + second_term

return neighbor_fn, compute_fn

# Embedded Atom Method

Expand Down
56 changes: 28 additions & 28 deletions jax_md/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def volume(dimension: int, box: Box) -> float:
def kinetic_energy(*unused_args,
momentum: Array=None,
velocity: Array=None,
mass: Array=f32(1.0)
mass: Array=1.0,
) -> float:
"""Computes the kinetic energy of a system.
Expand Down Expand Up @@ -153,7 +153,7 @@ def kinetic_energy(*unused_args,
def temperature(*unused_args,
momentum: Array=None,
velocity: Array=None,
mass: Array=f32(1.0)
mass: Array=1.0,
) -> float:
"""Computes the temperature of a system.
Expand Down Expand Up @@ -298,18 +298,18 @@ def average_pair_correlation_results(gofr, species=None):
returned.
When species is specified, gofr is expected to be a list of nspecies arrays,
each of shape (N,nr), where nspecies is the number of unique species types.
each of shape (N,nr), where nspecies is the number of unique species types.
Here, the average is carried out separately for every pair of species, so the
returned array has shape (nspecies, nspecies, nr).
returned array has shape (nspecies, nspecies, nr).
Args:
gofr: array of shape (N,nr) or a list of arrays of shape (N,nr), where nr is
the number of radii for which :math:`g(r)` is calculated.
species: Optional. Array of shape (N,) specifying the species of each
species: Optional. Array of shape (N,) specifying the species of each
particle.
Returns:
An array of shape (nr,) for species=None, otherwise an array of shape
An array of shape (nr,) for species=None, otherwise an array of shape
(nspecies, nspecies, nr), where nspecies is the number of unique species.
"""
if species is None:
Expand Down Expand Up @@ -354,16 +354,16 @@ def pair_correlation(displacement_or_metric: Union[DisplacementFn, MetricFn],
collection of particles.
:math:`g(r)` is calculated separately for each particle. For species=None, the
output of `g_fn` is an array of shape (N, nr), where N is the number of
particles passed to `g_fn` and nr is the size of radii (the number of points
output of `g_fn` is an array of shape (N, nr), where N is the number of
particles passed to `g_fn` and nr is the size of radii (the number of points
at which we calculate :math:`g(r)`. When species is specified, the output is a
list of nspecies arrays, each of shape (N, nr), where nspecies is the number
of unique species. If `gofr` is the output of `g_fn`, then gofr[si][i] gives
the :math:`g(r)` for particle i considering only pair particles of species si.
Note: when species is specified, the returned list is in the order of the
sorted unique species indices, not the order in which they appear.
Note: when species is specified, the returned list is in the order of the
sorted unique species indices, not the order in which they appear.
"""
d = space.canonicalize_displacement_or_metric(displacement_or_metric)
d = space.map_product(d)
Expand Down Expand Up @@ -415,11 +415,11 @@ def pair_correlation_neighbor_list(
"""Computes the pair correlation function at a mesh of distances.
The pair correlation function measures the number of particles at a given
distance from a central particle. The pair correlation function is defined by
distance from a central particle. The pair correlation function is defined by
.. math::
g(r) = <\sum_{i \\neq j}\delta(r - |r_i - r_j|)>
We make the approximation,
.. math::
Expand Down Expand Up @@ -450,15 +450,15 @@ def pair_correlation_neighbor_list(
position and a neighbor list.
:math:`g(r)` is calculated separately for each particle. For species=None, the
output of `g_fn` is an array of shape (N, nr), where N is the number of
particles passed to `g_fn` and nr is the size of radii (the number of points
output of `g_fn` is an array of shape (N, nr), where N is the number of
particles passed to `g_fn` and nr is the size of radii (the number of points
at which we calculate :math:`g(r)`. When species is specified, the output is a
list of nspecies arrays, each of shape (N, nr), where nspecies is the number
of unique species. If `gofr` is the output of `g_fn`, then gofr[si][i] gives
the :math:`g(r)` for particle i considering only pair particles of species si.
Note: when species is specified, the returned list is in the order of the
sorted unique species indices, not the order in which they appear.
sorted unique species indices, not the order in which they appear.
"""
metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
inv_rad = 1 / (radii + eps)
Expand Down Expand Up @@ -537,21 +537,21 @@ def nball_unit_volume(spatial_dimension: int) -> float:
return jnp.power(jnp.pi, spatial_dimension / 2) / \
jnp.exp( gammaln(spatial_dimension / 2 + 1))

def particle_volume(radii: Array,
spatial_dimension: int,
particle_count: Array = 1,
def particle_volume(radii: Array,
spatial_dimension: int,
particle_count: Array = 1,
species: Array = None) -> float:
""" Calculate the volume of a collection of particles
Args:
radii: array of shape (n,) giving particle radii, where n can be 1, the
number of species, or the number of particles depending on the values of
particle_count and species.
particle_count and species.
spatial_dimension: int giving the spatial dimension
particle_count: number of particles with each radii. Broadcastable to radii.
species: list of particle species. If provided, this overrides
species: list of particle species. If provided, this overrides
particle_count and radii is expected to give per-species radii
Returns: the sum of the volume of all the particles
"""
V_unit = nball_unit_volume(spatial_dimension)
Expand All @@ -562,17 +562,17 @@ def particle_volume(radii: Array,

return jnp.sum(particle_count * V_particle)

def volume_fraction(box: Box,
radii: Array,
spatial_dimension: int,
particle_count: Array = 1,
def volume_fraction(box: Box,
radii: Array,
spatial_dimension: int,
particle_count: Array = 1,
species: Array = None) -> float:
""" Calculate the volume fraction
See documentation for particle_volume for explanation of parameters
"""
Vparticle = particle_volume(radii, spatial_dimension, particle_count, species)
return Vparticle / quantity.volume(spatial_dimension, box)
return Vparticle / volume(spatial_dimension, box)

def box_size_at_volume_fraction(volume_fraction: float,
radii: Array,
Expand Down

0 comments on commit c36c715

Please sign in to comment.