Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Add support for auxiliary data returned by energy/force function" #301

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 7 additions & 29 deletions jax_md/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,34 +57,19 @@
# Functions


def force(energy_fn: EnergyFn, has_aux: bool = False) -> ForceFn:
def force(energy_fn: EnergyFn) -> ForceFn:
"""Computes the force as the negative gradient of an energy."""
def energy_fn_w_aux(R, *args, **kwargs):
'''
Returns negative energy and auxiliary data
'''
energy, aux = energy_fn(R, *args, **kwargs)
return -energy, aux

def energy_fn_wo_aux(R, *args, **kwargs):
'''
Returns negative energy and None as auxiliary data
'''
return -energy_fn(R, *args, **kwargs), None

updated_energy_fn = energy_fn_w_aux if has_aux else energy_fn_wo_aux
return grad(lambda R, *args, **kwargs: updated_energy_fn(R, *args, **kwargs),
has_aux=True)
return grad(lambda R, *args, **kwargs: -energy_fn(R, *args, **kwargs))


def clipped_force(energy_fn: EnergyFn, max_force: float) -> ForceFn:
force_fn = force(energy_fn)
def wrapped_force_fn(R, *args, **kwargs):
force, aux = force_fn(R, *args, **kwargs)
force = force_fn(R, *args, **kwargs)
force_norm = jnp.linalg.norm(force, axis=-1, keepdims=True)
return jnp.where(force_norm > max_force,
force / force_norm * max_force,
force), aux
force)

return wrapped_force_fn

Expand All @@ -95,12 +80,8 @@ def force_fn(R, **kwargs):
nonlocal _force_fn
if _force_fn is None:
out_shaped = eval_shape(energy_or_force_fn, R, **kwargs)
has_aux = False
if isinstance(out_shaped, ShapeDtypeStruct) == False:
has_aux = True
out_shaped = out_shaped[0]
if isinstance(out_shaped, ShapeDtypeStruct) and out_shaped.shape == ():
_force_fn = force(energy_or_force_fn, has_aux)
_force_fn = force(energy_or_force_fn)
else:
# Check that the output has the right shape to be a force.
is_valid_force = tree_reduce(
Expand All @@ -112,11 +93,8 @@ def force_fn(R, **kwargs):
raise ValueError('Provided function should be compatible with '
'either an energy or a force. Found a function '
f'whose output has shape {out_shaped}.')
if has_aux == False:
_force_fn = lambda R, *args, **kwargs: \
energy_or_force_fn(R, *args, **kwargs), None
else:
_force_fn = energy_or_force_fn

_force_fn = energy_or_force_fn
return _force_fn(R, **kwargs)

return force_fn
Expand Down
49 changes: 17 additions & 32 deletions jax_md/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,7 @@ def velocity_verlet(force_fn: Callable[..., Array],

state = momentum_step(state, dt_2)
state = position_step(state, shift_fn, dt, **kwargs)
force, aux = force_fn(state.position, **kwargs)
state = state.set(force=force, aux=aux)
state = state.set(force=force_fn(state.position, **kwargs))
state = momentum_step(state, dt_2)

return state
Expand All @@ -250,13 +249,11 @@ class NVEState:
acting on particles from the previous step.
mass: A float or an ndarray of shape `[n]` containing the masses of the
particles.
aux: Auxiliary data (Ex. charge array).
"""
position: Array
momentum: Array
force: Array
mass: Array
aux: Any

@property
def velocity(self) -> Array:
Expand Down Expand Up @@ -286,8 +283,8 @@ def nve(energy_or_force_fn, shift_fn, dt=1e-3, **sim_kwargs):

@jit
def init_fn(key, R, kT, mass=f32(1.0), **kwargs):
force, aux = force_fn(R, **kwargs)
state = NVEState(R, None, force, mass, aux)
force = force_fn(R, **kwargs)
state = NVEState(R, None, force, mass)
state = canonicalize_mass(state)
return initialize_momenta(state, key, kT)

Expand Down Expand Up @@ -516,14 +513,12 @@ class NVTNoseHooverState:
mass: The mass of the particles. Can either be a float or an ndarray
of floats with shape `[n]`.
chain: The variables describing the Nose-Hoover chain.
aux: Auxiliary data (Ex. charge array).
"""
position: Array
momentum: Array
force: Array
mass: Array
chain: NoseHooverChain
aux: Any

@property
def velocity(self):
Expand Down Expand Up @@ -598,8 +593,8 @@ def init_fn(key, R, mass=f32(1.0), **kwargs):
_kT = kT if 'kT' not in kwargs else kwargs['kT']

dof = quantity.count_dof(R)
force, aux = force_fn(R, **kwargs)
state = NVTNoseHooverState(R, None, force, mass, None, aux)

state = NVTNoseHooverState(R, None, force_fn(R, **kwargs), mass, None)
state = canonicalize_mass(state)
state = initialize_momenta(state, key, _kT)
KE = kinetic_energy(state)
Expand Down Expand Up @@ -682,7 +677,6 @@ class NPTNoseHooverState:
barostat.
thermostsat: The variables describing the Nose-Hoover chain coupled to the
thermostat.
aux: Auxiliary data (Ex. charge array).
"""
position: Array
momentum: Array
Expand All @@ -698,8 +692,6 @@ class NPTNoseHooverState:
barostat: NoseHooverChain
thermostat: NoseHooverChain

aux: Any

@property
def velocity(self) -> Array:
return self.momentum / self.mass
Expand Down Expand Up @@ -802,13 +794,12 @@ def init_fn(key, R, box, mass=f32(1.0), **kwargs):
if jnp.isscalar(box) or box.ndim == 0:
# TODO(schsam): This is necessary because of JAX issue #5849.
box = jnp.eye(R.shape[-1]) * box
force, aux = force_fn(R, box=box, **kwargs)

state = NPTNoseHooverState(
R, None, force,
R, None, force_fn(R, box=box, **kwargs),
mass, box, box_position, box_momentum, box_mass,
barostat.initialize(1, KE_box, _kT),
None,
aux) # pytype: disable=wrong-arg-count
None) # pytype: disable=wrong-arg-count
state = canonicalize_mass(state)
state = initialize_momenta(state, key, _kT)
KE = kinetic_energy(state)
Expand Down Expand Up @@ -874,15 +865,14 @@ def inner_step(state, **kwargs):

box = box_fn(vol)
R = exp_iL1(box, R, P / M, P_b / M_b)
F, aux = force_fn(R, box=box, **kwargs)
F = force_fn(R, box=box, **kwargs)

P = exp_iL2(alpha, P, F, P_b / M_b)
G_e = box_force(alpha, vol, box_fn, R, P, M, F, _pressure, **kwargs)
P_b = P_b + dt_2 * G_e

return state.set(position=R, momentum=P, mass=M, force=F,
box_position=R_b, box_momentum=P_b, box_mass=M_b,
aux=aux)
box_position=R_b, box_momentum=P_b, box_mass=M_b)

def apply_fn(state, **kwargs):
S = state
Expand Down Expand Up @@ -989,14 +979,12 @@ class NVTLangevinState:
mass: The mass of particles. Will either be a float or an ndarray of floats
with shape `[n]`.
rng: The current state of the random number generator.
aux: Auxiliary data (Ex. charge array).
"""
position: Array
momentum: Array
force: Array
mass: Array
rng: Array
aux: Any

@property
def velocity(self) -> Array:
Expand Down Expand Up @@ -1062,8 +1050,8 @@ def nvt_langevin(energy_or_force_fn: Callable[..., Array],
def init_fn(key, R, mass=f32(1.0), **kwargs):
_kT = kwargs.pop('kT', kT)
key, split = random.split(key)
force, aux = force_fn(R, **kwargs)
state = NVTLangevinState(R, None, force, mass, key, aux)
force = force_fn(R, **kwargs)
state = NVTLangevinState(R, None, force, mass, key)
state = canonicalize_mass(state)
return initialize_momenta(state, split, _kT)

Expand All @@ -1077,8 +1065,7 @@ def step_fn(state, **kwargs):
state = position_step(state, shift_fn, dt_2, **kwargs)
state = stochastic_step(state, _dt, _kT, gamma)
state = position_step(state, shift_fn, dt_2, **kwargs)
force, aux = force_fn(state.position, **kwargs)
state = state.set(force=force, aux=aux)
state = state.set(force=force_fn(state.position, **kwargs))
state = momentum_step(state, dt_2)

return state
Expand All @@ -1096,12 +1083,10 @@ class BrownianState:
mass: The mass of particles. Will either be a float or an ndarray of floats
with shape `[n]`.
rng: The current state of the random number generator.
aux: Auxiliary data (Ex. charge array).
"""
position: Array
mass: Array
rng: Array
aux: Any


def brownian(energy_or_force: Callable[..., Array],
Expand Down Expand Up @@ -1140,25 +1125,25 @@ def brownian(energy_or_force: Callable[..., Array],
dt, gamma = static_cast(dt, gamma)

def init_fn(key, R, mass=f32(1)):
state = BrownianState(R, mass, key, None)
state = BrownianState(R, mass, key)
return canonicalize_mass(state)

def apply_fn(state, **kwargs):
_kT = kT if 'kT' not in kwargs else kwargs['kT']

R, mass, key, aux = dataclasses.astuple(state)
R, mass, key = dataclasses.astuple(state)

key, split = random.split(key)

F, aux = force_fn(R, **kwargs)
F = force_fn(R, **kwargs)
xi = random.normal(split, R.shape, R.dtype)

nu = f32(1) / (mass * gamma)

dR = F * dt * nu + jnp.sqrt(f32(2) * _kT * dt * nu) * xi
R = shift(R, dR, **kwargs)

return BrownianState(R, mass, key, aux) # pytype: disable=wrong-arg-count
return BrownianState(R, mass, key) # pytype: disable=wrong-arg-count

return init_fn, apply_fn

Expand Down
2 changes: 1 addition & 1 deletion tests/quantity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ def U(r):
return np.sum(1 / np.linalg.norm(r, axis=-1) ** 2)
force_fn = quantity.clipped_force(U, 1.5)
R = random.normal(random.PRNGKey(0), (N, dim))
self.assertTrue(np.all(np.linalg.norm(force_fn(R)[0], axis=-1) <= 1.5))
self.assertTrue(np.all(np.linalg.norm(force_fn(R), axis=-1) <= 1.5))

if __name__ == '__main__':
absltest.main()