You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
I am running into some problems when running an NPT simulation of a Lennard-Jones system with jax-md 0.2.5. With the following code (which runs with jax-md 0.2.0), executing the fori_loop `state, nbrs, log = lax.fori_loop(0, steps, step_fn, (state, nbrs, log))` raises the following error: `TypeError: Argument 'DIFFERENT ShapedArray(uint8[2,2]) vs. ShapedArray(uint8[])' of type <class 'str'> is not a valid JAX type.` Could that have something to do with how JAX MD catches errors? Please take a look at the code below. Thank you for any suggestions!
from jax_md import (
space, energy,
quantity, simulate,
minimize)
import jax.numpy as np
from jax import random
from jax import lax
import jax_md.partition as partition
from jax.config import config
# Turn on 64-bit (double-precision) arrays in JAX via the `jax_enable_x64` flag.
config.update('jax_enable_x64', True)
def run_npt_simulation(
        P, N_rep, dt,
        kT, r_cutoff_lj,
        sigma_lj, epsilon_lj,
        steps, write_every,
        key=None):
    """Run a 2D Lennard-Jones NPT (Nose-Hoover) simulation with a neighbor list.

    Particles start on an ``N_rep x N_rep`` square lattice expressed in
    fractional coordinates. The system is advanced with ``lax.fori_loop``, and
    temperature, pressure and the Nose-Hoover invariant are logged every step.

    Args:
      P: Callable mapping simulation time ``t`` to the target pressure
        (called as ``P(0.)`` and ``P(i * dt)``).
      N_rep: Number of lattice sites per side; the system has ``N_rep**2``
        particles.
      dt: Integration time step.
      kT: Temperature passed to ``simulate.npt_nose_hoover``.
      r_cutoff_lj: Lennard-Jones cutoff radius.
      sigma_lj: Lennard-Jones ``sigma``.
      epsilon_lj: Lennard-Jones ``epsilon``.
      steps: Number of integration steps.
      write_every: Store particle positions every ``write_every`` steps.
      key: Optional ``jax.random`` PRNG key used to initialise the state.
        Defaults to ``random.PRNGKey(0)``.

    Returns:
      ``(state, nbrs, log)``: the final NPT state, the final neighbor list,
      and a dict with per-step arrays ``'P'``, ``'kT'``, ``'H'`` (length
      ``steps``) plus ``'position'``, the stored frames, with shape
      ``(ceil(steps / write_every), N_rep**2, 2)``.
    """
    # The original body split an undefined `key`, which always raised.
    if key is None:
        key = random.PRNGKey(0)
    key, init_key = random.split(key)

    # Populate the simulation box with a square lattice.
    dimension = 2
    lattice_constant = 1.37820
    box_size = N_rep * lattice_constant
    displacement, shift = space.periodic_general(box_size)

    # Fractional coordinates in [0, 1): site (i, j) sits at
    # (i, j) * lattice_constant / box_size == (i, j) / N_rep.
    R = np.array([[i / N_rep, j / N_rep]
                  for i in range(N_rep)
                  for j in range(N_rep)])
    N = R.shape[0]
    phi = N / box_size ** dimension
    print(
        f'Created a system of {N} LJ particles with number density {phi:.3f}')

    # Lennard-Jones energy with an ordered sparse neighbor list. The format is
    # passed inline to avoid shadowing the builtin `format`.
    neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
        displacement,
        box_size,
        r_cutoff=r_cutoff_lj,
        sigma=sigma_lj,
        epsilon=epsilon_lj,
        dr_threshold=1.0,
        format=partition.OrderedSparse,
        fractional_coordinates=True)
    init, apply = simulate.npt_nose_hoover(
        energy_fn, shift, dt, P(0.), kT)

    nbrs = neighbor_fn.allocate(R, box=box_size, extra_capacity=8)
    state = init(init_key, R, box_size, neighbor=nbrs)

    # lax.fori_loop requires the loop carry to have identical types (shapes
    # and dtypes) on input and output. Inside the loop, the neighbor list is
    # updated with the box returned by `simulate.npt_box`, whereas it was
    # allocated with the scalar `box_size`. The reported
    # "uint8[2,2] vs uint8[]" mismatch most likely comes from the neighbor
    # list's error-code leaf changing shape between those two boxes.
    # Performing one update with the in-loop box here makes the initial carry
    # match what `step_fn` returns.
    # NOTE(review): confirm against the installed jax_md version.
    nbrs = nbrs.update(state.position, box=simulate.npt_box(state))

    # Ceil division. With floor division, the frame written at the last
    # multiple of `write_every` below `steps` had no slot when
    # `steps % write_every != 0`, and JAX silently drops out-of-bounds
    # scatter updates.
    n_frames = -(-steps // write_every)

    def step_fn(i, state_nbrs_log):
        state, nbrs, log = state_nbrs_log
        log = dict(log)  # work on a copy rather than the incoming carry dict
        t = i * dt

        # Log thermodynamic quantities for step i.
        # NOTE(review): no mass is passed, so quantity's default mass is used;
        # the original had `mass=state.mass` commented out.
        T = quantity.temperature(momentum=state.momentum)
        log['kT'] = log['kT'].at[i].set(T)
        box = simulate.npt_box(state)
        KE = quantity.kinetic_energy(momentum=state.momentum)
        P_measured = quantity.pressure(
            energy_fn, state.position,
            box=box, kinetic_energy=KE, neighbor=nbrs)
        log['P'] = log['P'].at[i].set(P_measured)
        H = simulate.npt_nose_hoover_invariant(
            energy_fn, state, P(t), kT, neighbor=nbrs)
        log['H'] = log['H'].at[i].set(H)

        # Record positions, transformed by the current box, every
        # `write_every` steps.
        pos = space.transform(box, state.position)
        log['position'] = lax.cond(i % write_every == 0,
                                   lambda p: p.at[i // write_every].set(pos),
                                   lambda p: p,
                                   log['position'])

        # Take a simulation step, then refresh the neighbor list for the
        # new box.
        state = apply(state, neighbor=nbrs, pressure=P(t))
        box = simulate.npt_box(state)
        nbrs = nbrs.update(state.position, box=box)
        return state, nbrs, log

    # Run the simulation and store the results in log.
    log = {
        'P': np.zeros((steps,)),
        'kT': np.zeros((steps,)),
        'H': np.zeros((steps,)),
        'position': np.zeros((n_frames,) + R.shape),
    }
    print('Simulation is running...')
    state, nbrs, log = lax.fori_loop(0, steps, step_fn, (state, nbrs, log))
    # If this prints True, the neighbor list overflowed during the run.
    # The trajectory should presumably be rerun with a reallocated list.
    print(nbrs.did_buffer_overflow)
    return state, nbrs, log
The text was updated successfully, but these errors were encountered:
Hi,
Firstly, thank you for all your efforts!
I am running into some problems when running an NPT simulation of a Lennard-Jones system with jax-md 0.2.5. With the following code (which runs with jax-md 0.2.0), when executing the fori_loop
state, nbrs, log = lax.fori_loop(0, steps, step_fn, (state, nbrs, log))
, I receive the following error: TypeError: Argument 'DIFFERENT ShapedArray(uint8[2,2]) vs. ShapedArray(uint8[])' of type <class 'str'> is not a valid JAX type.
Could that have something to do with how JAX MD catches errors? Please take a look at the code below. Thank you for any suggestions! The text was updated successfully, but these errors were encountered: