Skip to content

Commit

Permalink
Merge pull request #245 from cagrikymk/aux_info
Browse files Browse the repository at this point in the history
Add support for auxiliary data returned by energy/force function
  • Loading branch information
ekindogus committed Feb 17, 2024
2 parents b5082ed + 434be0a commit 389d7c3
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 25 deletions.
36 changes: 29 additions & 7 deletions jax_md/quantity.py
Expand Up @@ -57,19 +57,34 @@
# Functions


def force(energy_fn: EnergyFn) -> ForceFn:
def force(energy_fn: EnergyFn, has_aux: bool = False) -> ForceFn:
"""Computes the force as the negative gradient of an energy."""
return grad(lambda R, *args, **kwargs: -energy_fn(R, *args, **kwargs))
def energy_fn_w_aux(R, *args, **kwargs):
'''
Returns negative energy and auxiliary data
'''
energy, aux = energy_fn(R, *args, **kwargs)
return -energy, aux

def energy_fn_wo_aux(R, *args, **kwargs):
'''
Returns negative energy and None as auxiliary data
'''
return -energy_fn(R, *args, **kwargs), None

updated_energy_fn = energy_fn_w_aux if has_aux else energy_fn_wo_aux
return grad(lambda R, *args, **kwargs: updated_energy_fn(R, *args, **kwargs),
has_aux=True)


def clipped_force(energy_fn: EnergyFn, max_force: float) -> ForceFn:
force_fn = force(energy_fn)
def wrapped_force_fn(R, *args, **kwargs):
force = force_fn(R, *args, **kwargs)
force, aux = force_fn(R, *args, **kwargs)
force_norm = jnp.linalg.norm(force, axis=-1, keepdims=True)
return jnp.where(force_norm > max_force,
force / force_norm * max_force,
force)
force), aux

return wrapped_force_fn

Expand All @@ -80,8 +95,12 @@ def force_fn(R, **kwargs):
nonlocal _force_fn
if _force_fn is None:
out_shaped = eval_shape(energy_or_force_fn, R, **kwargs)
has_aux = False
if isinstance(out_shaped, ShapeDtypeStruct) == False:
has_aux = True
out_shaped = out_shaped[0]
if isinstance(out_shaped, ShapeDtypeStruct) and out_shaped.shape == ():
_force_fn = force(energy_or_force_fn)
_force_fn = force(energy_or_force_fn, has_aux)
else:
# Check that the output has the right shape to be a force.
is_valid_force = tree_reduce(
Expand All @@ -93,8 +112,11 @@ def force_fn(R, **kwargs):
raise ValueError('Provided function should be compatible with '
'either an energy or a force. Found a function '
f'whose output has shape {out_shaped}.')

_force_fn = energy_or_force_fn
if has_aux == False:
_force_fn = lambda R, *args, **kwargs: \
energy_or_force_fn(R, *args, **kwargs), None
else:
_force_fn = energy_or_force_fn
return _force_fn(R, **kwargs)

return force_fn
Expand Down
49 changes: 32 additions & 17 deletions jax_md/simulate.py
Expand Up @@ -223,7 +223,8 @@ def velocity_verlet(force_fn: Callable[..., Array],

state = momentum_step(state, dt_2)
state = position_step(state, shift_fn, dt, **kwargs)
state = state.set(force=force_fn(state.position, **kwargs))
force, aux = force_fn(state.position, **kwargs)
state = state.set(force=force, aux=aux)
state = momentum_step(state, dt_2)

return state
Expand All @@ -249,11 +250,13 @@ class NVEState:
acting on particles from the previous step.
mass: A float or an ndarray of shape `[n]` containing the masses of the
particles.
aux: Auxiliary data (Ex. charge array).
"""
position: Array
momentum: Array
force: Array
mass: Array
aux: Any

@property
def velocity(self) -> Array:
Expand Down Expand Up @@ -283,8 +286,8 @@ def nve(energy_or_force_fn, shift_fn, dt=1e-3, **sim_kwargs):

@jit
def init_fn(key, R, kT, mass=f32(1.0), **kwargs):
force = force_fn(R, **kwargs)
state = NVEState(R, None, force, mass)
force, aux = force_fn(R, **kwargs)
state = NVEState(R, None, force, mass, aux)
state = canonicalize_mass(state)
return initialize_momenta(state, key, kT)

Expand Down Expand Up @@ -513,12 +516,14 @@ class NVTNoseHooverState:
mass: The mass of the particles. Can either be a float or an ndarray
of floats with shape `[n]`.
chain: The variables describing the Nose-Hoover chain.
aux: Auxiliary data (Ex. charge array).
"""
position: Array
momentum: Array
force: Array
mass: Array
chain: NoseHooverChain
aux: Any

@property
def velocity(self):
Expand Down Expand Up @@ -593,8 +598,8 @@ def init_fn(key, R, mass=f32(1.0), **kwargs):
_kT = kT if 'kT' not in kwargs else kwargs['kT']

dof = quantity.count_dof(R)

state = NVTNoseHooverState(R, None, force_fn(R, **kwargs), mass, None)
force, aux = force_fn(R, **kwargs)
state = NVTNoseHooverState(R, None, force, mass, None, aux)
state = canonicalize_mass(state)
state = initialize_momenta(state, key, _kT)
KE = kinetic_energy(state)
Expand Down Expand Up @@ -677,6 +682,7 @@ class NPTNoseHooverState:
barostat.
thermostsat: The variables describing the Nose-Hoover chain coupled to the
thermostat.
aux: Auxiliary data (Ex. charge array).
"""
position: Array
momentum: Array
Expand All @@ -692,6 +698,8 @@ class NPTNoseHooverState:
barostat: NoseHooverChain
thermostat: NoseHooverChain

aux: Any

@property
def velocity(self) -> Array:
return self.momentum / self.mass
Expand Down Expand Up @@ -794,12 +802,13 @@ def init_fn(key, R, box, mass=f32(1.0), **kwargs):
if jnp.isscalar(box) or box.ndim == 0:
# TODO(schsam): This is necessary because of JAX issue #5849.
box = jnp.eye(R.shape[-1]) * box

force, aux = force_fn(R, box=box, **kwargs)
state = NPTNoseHooverState(
R, None, force_fn(R, box=box, **kwargs),
R, None, force,
mass, box, box_position, box_momentum, box_mass,
barostat.initialize(1, KE_box, _kT),
None) # pytype: disable=wrong-arg-count
None,
aux) # pytype: disable=wrong-arg-count
state = canonicalize_mass(state)
state = initialize_momenta(state, key, _kT)
KE = kinetic_energy(state)
Expand Down Expand Up @@ -865,14 +874,15 @@ def inner_step(state, **kwargs):

box = box_fn(vol)
R = exp_iL1(box, R, P / M, P_b / M_b)
F = force_fn(R, box=box, **kwargs)
F, aux = force_fn(R, box=box, **kwargs)

P = exp_iL2(alpha, P, F, P_b / M_b)
G_e = box_force(alpha, vol, box_fn, R, P, M, F, _pressure, **kwargs)
P_b = P_b + dt_2 * G_e

return state.set(position=R, momentum=P, mass=M, force=F,
box_position=R_b, box_momentum=P_b, box_mass=M_b)
box_position=R_b, box_momentum=P_b, box_mass=M_b,
aux=aux)

def apply_fn(state, **kwargs):
S = state
Expand Down Expand Up @@ -979,12 +989,14 @@ class NVTLangevinState:
mass: The mass of particles. Will either be a float or an ndarray of floats
with shape `[n]`.
rng: The current state of the random number generator.
aux: Auxiliary data (Ex. charge array).
"""
position: Array
momentum: Array
force: Array
mass: Array
rng: Array
aux: Any

@property
def velocity(self) -> Array:
Expand Down Expand Up @@ -1050,8 +1062,8 @@ def nvt_langevin(energy_or_force_fn: Callable[..., Array],
def init_fn(key, R, mass=f32(1.0), **kwargs):
_kT = kwargs.pop('kT', kT)
key, split = random.split(key)
force = force_fn(R, **kwargs)
state = NVTLangevinState(R, None, force, mass, key)
force, aux = force_fn(R, **kwargs)
state = NVTLangevinState(R, None, force, mass, key, aux)
state = canonicalize_mass(state)
return initialize_momenta(state, split, _kT)

Expand All @@ -1065,7 +1077,8 @@ def step_fn(state, **kwargs):
state = position_step(state, shift_fn, dt_2, **kwargs)
state = stochastic_step(state, _dt, _kT, gamma)
state = position_step(state, shift_fn, dt_2, **kwargs)
state = state.set(force=force_fn(state.position, **kwargs))
force, aux = force_fn(state.position, **kwargs)
state = state.set(force=force, aux=aux)
state = momentum_step(state, dt_2)

return state
Expand All @@ -1083,10 +1096,12 @@ class BrownianState:
mass: The mass of particles. Will either be a float or an ndarray of floats
with shape `[n]`.
rng: The current state of the random number generator.
aux: Auxiliary data (Ex. charge array).
"""
position: Array
mass: Array
rng: Array
aux: Any


def brownian(energy_or_force: Callable[..., Array],
Expand Down Expand Up @@ -1125,25 +1140,25 @@ def brownian(energy_or_force: Callable[..., Array],
dt, gamma = static_cast(dt, gamma)

def init_fn(key, R, mass=f32(1)):
state = BrownianState(R, mass, key)
state = BrownianState(R, mass, key, None)
return canonicalize_mass(state)

def apply_fn(state, **kwargs):
_kT = kT if 'kT' not in kwargs else kwargs['kT']

R, mass, key = dataclasses.astuple(state)
R, mass, key, aux = dataclasses.astuple(state)

key, split = random.split(key)

F = force_fn(R, **kwargs)
F, aux = force_fn(R, **kwargs)
xi = random.normal(split, R.shape, R.dtype)

nu = f32(1) / (mass * gamma)

dR = F * dt * nu + jnp.sqrt(f32(2) * _kT * dt * nu) * xi
R = shift(R, dR, **kwargs)

return BrownianState(R, mass, key) # pytype: disable=wrong-arg-count
return BrownianState(R, mass, key, aux) # pytype: disable=wrong-arg-count

return init_fn, apply_fn

Expand Down
2 changes: 1 addition & 1 deletion tests/quantity_test.py
Expand Up @@ -616,7 +616,7 @@ def U(r):
return np.sum(1 / np.linalg.norm(r, axis=-1) ** 2)
force_fn = quantity.clipped_force(U, 1.5)
R = random.normal(random.PRNGKey(0), (N, dim))
self.assertTrue(np.all(np.linalg.norm(force_fn(R), axis=-1) <= 1.5))
self.assertTrue(np.all(np.linalg.norm(force_fn(R)[0], axis=-1) <= 1.5))

if __name__ == '__main__':
absltest.main()

0 comments on commit 389d7c3

Please sign in to comment.