Issues with NPT + Lennard Jones using jax-md-0.2.5 #255

Open
maggiezimon opened this issue Feb 2, 2023 · 2 comments

Comments


maggiezimon commented Feb 2, 2023

Hi,

Firstly, thank you for all your efforts!

I am running into problems with an NPT simulation of a Lennard-Jones system using jax-md-0.2.5. The code below ran without errors with jax-md-0.2.0. With 0.2.5, executing the loop state, nbrs, log = lax.fori_loop(0, steps, step_fn, (state, nbrs, log)) raises the following error: TypeError: Argument 'DIFFERENT ShapedArray(uint8[2,2]) vs. ShapedArray(uint8[])' of type <class 'str'> is not a valid JAX type. Could this be related to how JAX MD catches errors? Please take a look at the code below. Thank you for any suggestions!

from jax_md import (
    space, energy,
    quantity, simulate,
    minimize)

import jax.numpy as np
from jax import random
from jax import lax
import jax_md.partition as partition
from jax.config import config

config.update('jax_enable_x64', True)


def run_npt_simulation(
    P, N_rep, dt,
    kT, r_cutoff_lj,
    sigma_lj, epsilon_lj,
        steps, write_every):
    
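    # Create the PRNG key explicitly; `key` was not defined in this scope
    # (the seed 0 is an arbitrary choice).
    key = random.PRNGKey(0)
    key, split = random.split(key)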

    # Populate the simulation box
    dimension = 2

    lattice_constant = 1.37820
    box_size = N_rep * lattice_constant

    displacement, shift = space.periodic_general(box_size)

    # Lattice box
    R = []
    for i in range(N_rep):
        for j in range(N_rep):
            R += [[i/box_size, j/box_size]]

    R = np.array(R) * lattice_constant
    N = R.shape[0]

    phi = N / (lattice_constant * N_rep) ** dimension
    print(
        f'Created a system of {N} LJ particles with number density {phi:.3f}')
    # Construct the neighbourhood object
    format = partition.OrderedSparse

    neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
                displacement,
                box_size,
                r_cutoff=r_cutoff_lj,
                sigma=sigma_lj,
                epsilon=epsilon_lj,
                dr_threshold=1.0,
                format=format,
                fractional_coordinates=True)

    init, apply = simulate.npt_nose_hoover(
            energy_fn, shift, dt, P(0.), kT)
    nbrs = neighbor_fn.allocate(R, box=box_size, extra_capacity=8)
    state = init(key, R, box_size, neighbor=nbrs)

    def step_fn(i, state_nbrs_log):
        state, nbrs, log = state_nbrs_log

        t = i * dt

        # Log information about the simulation.
        T = quantity.temperature(momentum=state.momentum)  # , mass=state.mass)
        log['kT'] = log['kT'].at[i].set(T)

        box = simulate.npt_box(state)
        KE = quantity.kinetic_energy(momentum=state.momentum)  # , mass=Mass)
        P_measured = quantity.pressure(
                energy_fn, state.position,
                box=box, kinetic_energy=KE, neighbor=nbrs)
        log['P'] = log['P'].at[i].set(P_measured)

        H = simulate.npt_nose_hoover_invariant(
                energy_fn, state, P(t), kT, neighbor=nbrs)
        log['H'] = log['H'].at[i].set(H)

        # Record positions every `write_every` steps.
        pos = space.transform(box, state.position)
        log['position'] = lax.cond(i % write_every == 0,
                                   lambda p: p.at[i // write_every].set(pos),
                                   lambda p: p,
                                   log['position'])

        # Take a simulation step.
        state = apply(state, neighbor=nbrs, pressure=P(t))
        box = simulate.npt_box(state)
        nbrs = nbrs.update(state.position, box=box)

        return state, nbrs, log

    # Run simulation and store the results in log
    log = {
        'P': np.zeros((steps,)),
        'kT': np.zeros((steps,)),
        'H': np.zeros((steps,)),
        'position': np.zeros((steps // write_every,) + R.shape)
    }

    print('Simulation is running...')
    state, nbrs, log = lax.fori_loop(0, steps, step_fn, (state, nbrs, log))

    print(nbrs.did_buffer_overflow)

    return state, nbrs, log

maggiezimon commented Feb 2, 2023

To run the code above, you can use the following settings:

    units = {
      'mass': 1,
      'distance': 1,
      'time': 98.22694788,
      'energy': 1,
      'velocity': 0.01018051,
      'force': 1.0,
      'torque ': 1,
      'temperature': 8.617330337217213e-05,
      'pressure': 6.241509125883258e-07
    }

    fs = 1e-5 * units['time']
    ps = units['time']

    T_init = 300 * units['temperature']
    P_start = 0.0 * units['pressure']
    P_end = 0.1 * units['pressure']

    P = lambda t: np.where(t < 100.0, P_start, P_end)
    kT = T_init
    # lj
    sigma_lj = 1.0
    epsilon_lj = 1.0
    r_cutoff_lj = 3.5

    N_rep = 40
    # Run the simulation
    dt = fs
    steps = 10000
    write_every = 10
    state, nbrs, log = run_npt_simulation(
            P=P,
            N_rep=N_rep, dt=dt, kT=kT, r_cutoff_lj=r_cutoff_lj,
            sigma_lj=sigma_lj, epsilon_lj=epsilon_lj,
            steps=steps, write_every=write_every)

maggiezimon changed the title from "Issues with NPT + Lennard Jones" to "Issues with NPT + Lennard Jones using jax-md-0.2.5" on Feb 2, 2023
abhijeetgangan (Contributor) commented

Hi,

Defining box_size as a (dimension, dimension) matrix will fix the error.

lattice_constant = 1.37820
box_size = N_rep * lattice_constant

box_size = np.array([[box_size, 0.0], [0.0, box_size]])

displacement, shift = space.periodic_general(box_size)
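
One possible reading of the error message, offered as an assumption rather than a confirmed diagnosis: lax.fori_loop requires the loop carry to keep the same shape and dtype on every iteration, and the uint8[2,2] vs. uint8[] in the message suggests that something created from the scalar box comes back with a (2, 2) shape after the first step. The sketch below reproduces that kind of mismatch using plain JAX only. The function grow_box is a hypothetical stand-in for a step that always returns a full box matrix; it is not part of JAX MD.

```python
import jax.numpy as jnp
from jax import lax

dimension = 2
N_rep = 40
lattice_constant = 1.37820
side = N_rep * lattice_constant

# Scalar box, as in the original post, and the matrix form suggested above.
scalar_box = jnp.asarray(side)
matrix_box = jnp.eye(dimension) * side  # shape (dimension, dimension)


def grow_box(i, box):
    # Hypothetical step: always returns a (dimension, dimension) matrix,
    # whatever shape the incoming box has.
    return jnp.eye(dimension) * box * 1.001


# A scalar carry changes shape inside the loop, which JAX rejects.
try:
    lax.fori_loop(0, 3, grow_box, scalar_box)
    mismatch_detected = False
except (TypeError, ValueError):
    mismatch_detected = True

# A matrix carry keeps its shape, so the loop runs.
final_box = lax.fori_loop(0, 3, grow_box, matrix_box)

print(mismatch_detected, final_box.shape)
```

If this reading is right, starting from a matrix-valued box keeps every quantity derived from it at a consistent shape from allocation onwards, which would explain why the change above removes the error.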
