Skip to content

Commit

Permalink
Fix failing build (test_coulomb_direct_octions) (#287)
Browse files Browse the repository at this point in the history
* remove species=96 and set fractional_coordinates=True

* Increase swap size

* Use a predefined key when adding randomness to R

---------
  • Loading branch information
MarcBerneman committed Aug 29, 2023
1 parent b65c777 commit e107c26
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/build.yml
Expand Up @@ -41,6 +41,10 @@ jobs:
pip install pytest-xdist
pip install pytest-cov
pip install netCDF4
- name: Set Swap Space
uses: pierotofy/set-swap-space@master
with:
swap-size-gb: 10
- name: Test with pytest and generate coverage report
run: |
JAX_ENABLE_X64=1 pytest --cov=jax_md --cov-report=xml --cov-report=term --cov-config=setup.cfg
Expand Down
2 changes: 1 addition & 1 deletion tests/energy_test.py
Expand Up @@ -1200,7 +1200,7 @@ def test_coulomb_direct_octions(self):
R_frac = jnp.mod(R_frac, 1.0)

neighbor_fn, energy_fn = energy.coulomb_direct_neighbor_list(
displacement, box, jnp.array(Q), 96, alpha=0.3488,
displacement, box, jnp.array(Q), alpha=0.3488,
fractional_coordinates=True)
energy_fn = jit(energy_fn)

Expand Down
19 changes: 10 additions & 9 deletions tests/tpu_test.py
Expand Up @@ -43,7 +43,7 @@
FLAGS = jax_config.FLAGS


def get_test_grid(topology=None, num_dims=2, add_aux=False, rng_key=None):
def get_test_grid(rng_key, topology=None, num_dims=2, add_aux=False, ):
# magic numbers to make the gird fold evenly, after splitting
# across devices and padding see propose_tpu_box_size.
cell_size = 1./4.
Expand Down Expand Up @@ -90,12 +90,11 @@ def get_test_grid(topology=None, num_dims=2, add_aux=False, rng_key=None):
elif num_dims == 1:
R = points[0].reshape((-1, 1)) + 0.1

R += onp.random.randn(*R.shape) * 0.1
R += random.normal(rng_key, R.shape) * 0.1
R = onp.array(R, onp.float64)
if add_aux:
if rng_key is None:
rng_key = random.PRNGKey(1)
# these are used as velocities
rng_key = random.split(rng_key)[0]
V = random.normal(rng_key, R.shape)
R_grid, V_grid = tpu.to_grid(R, box_size_in_cells, cell_size, interaction_distance, topology, aux=V, strategy='linear')
print(f"R.shape {R.shape}, aux.shape {V.shape}, grid shape {V_grid.shape}, occupancy {R.shape[0]/float(onp.prod(V_grid.shape[:-1]))}")
Expand Down Expand Up @@ -124,7 +123,8 @@ def test_position_recovery(self, num_dims, topology):
self.skipTest('Skipping non-trivial topology; only one device detected.')
topology = topology + (1,) * (num_dims - 1)

sim_tpu, sim_cpu = get_test_grid(topology, num_dims)
key = random.PRNGKey(0)
sim_tpu, sim_cpu = get_test_grid(key, topology, num_dims)

(R_grid, tpu_energy_fn, tpu_force_fn) = sim_tpu
(R, energy_fn, shift_fn) = sim_cpu
Expand All @@ -147,7 +147,8 @@ def test_position_and_aux_recovery(self, num_dims, topology):
self.skipTest('Skipping non-trivial topology; only one device detected.')
topology = topology + (1,) * (num_dims - 1)

sim_tpu, sim_cpu = get_test_grid(topology, num_dims, True)
key = random.PRNGKey(0)
sim_tpu, sim_cpu = get_test_grid(key, topology, num_dims, True)

((R_grid, V_grid), tpu_energy_fn, tpu_force_fn) = sim_tpu
((R, V), energy_fn, shift_fn) = sim_cpu
Expand All @@ -171,9 +172,9 @@ def test_forces(self, num_dims, topology):
if jax.device_count() == 1:
self.skipTest('Skipping non-trivial topology; only one device detected.')
topology = topology + (1,) * (num_dims - 1)
key = random.PRNGKey(0)

sim_tpu, sim_cpu = get_test_grid(topology, num_dims)
key = random.PRNGKey(0)
sim_tpu, sim_cpu = get_test_grid(key, topology, num_dims)

(R_grid, tpu_energy_fn, tpu_force_fn) = sim_tpu
(R, energy_fn, shift_fn) = sim_cpu
Expand All @@ -199,8 +200,8 @@ def test_nve(self, num_dims, topology):
topology = topology + (1,) * (num_dims - 1)

key = random.PRNGKey(0)
sim_tpu, sim_cpu = get_test_grid(key, topology, num_dims, True)

sim_tpu, sim_cpu = get_test_grid(topology, num_dims, True, rng_key=key)
((R_grid, V_grid), tpu_energy_fn, tpu_force_fn) = sim_tpu
((R, V), energy_fn, shift_fn) = sim_cpu

Expand Down

0 comments on commit e107c26

Please sign in to comment.