Skip to content

Commit

Permalink
Bugfixes (Python version and Python tests) (#295)
Browse files Browse the repository at this point in the history
* Haiku does not support Python 3.8 anymore

* FLAGS.jax_enable_x64 replaced by jax_config.jax_enable_x64

* FLAGS.jax_enable_x64 replaced by jax_config.jax_enable_x64
  • Loading branch information
MarcBerneman committed Dec 19, 2023
1 parent e107c26 commit a4851b8
Show file tree
Hide file tree
Showing 14 changed files with 18 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Expand Up @@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10']
python-version: ['3.9', '3.10', '3.11']
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.11.0
Expand Down
2 changes: 1 addition & 1 deletion .readthedocs.yml
Expand Up @@ -15,6 +15,6 @@ formats:

# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.8
version: 3.10
install:
- requirements: docs/requirements.txt
7 changes: 4 additions & 3 deletions setup.py
Expand Up @@ -54,10 +54,11 @@
long_description=long_description,
long_description_content_type='text/markdown',
description='Differentiable, Hardware Accelerated, Molecular Dynamics',
python_requires='>=2.7',
python_requires='>=3.9',
classifiers=[
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'License :: OSI Approved :: Apache Software License',
'Operating System :: MacOS',
'Operating System :: POSIX :: Linux',
Expand Down
3 changes: 1 addition & 2 deletions tests/elasticity_test.py
Expand Up @@ -35,14 +35,13 @@
from jax_md.util import *

jax_config.parse_flags_with_absl()
FLAGS = jax_config.FLAGS

PARTICLE_COUNT = 64
NUM_SAMPLES = 2
SPATIAL_DIMENSION = [2, 3]
LOWPRESSURE = [True, False]

if FLAGS.jax_enable_x64:
if jax_config.jax_enable_x64:
DTYPE = [f32, f64]
else:
DTYPE = [f32]
Expand Down
1 change: 0 additions & 1 deletion tests/energy_test.py
Expand Up @@ -37,7 +37,6 @@
from jax_md.interpolate import spline

config.parse_flags_with_absl()
FLAGS = config.FLAGS

# TODO: Replace np by jnp everywhere.
jnp = np
Expand Down
3 changes: 1 addition & 2 deletions tests/minimize_test.py
Expand Up @@ -32,14 +32,13 @@
from jax_md import test_util

jax_config.parse_flags_with_absl()
FLAGS = jax_config.FLAGS

PARTICLE_COUNT = 10
OPTIMIZATION_STEPS = 10
STOCHASTIC_SAMPLES = 10
SPATIAL_DIMENSION = [2, 3]

if FLAGS.jax_enable_x64:
if jax_config.jax_enable_x64:
DTYPE = [f32, f64]
else:
DTYPE = [f32]
Expand Down
7 changes: 3 additions & 4 deletions tests/nn_test.py
Expand Up @@ -31,9 +31,8 @@
from jax_md import test_util

jax_config.parse_flags_with_absl()
FLAGS = jax_config.FLAGS

if FLAGS.jax_enable_x64:
if jax_config.jax_enable_x64:
DTYPES = [f32, f64]
else:
DTYPES = [f32]
Expand Down Expand Up @@ -106,7 +105,7 @@ def test_radial_symmetry_functions_neighbor_list(self,
gr_exact = gr(R)
gr_nbrs = gr_neigh(R, neighbor=nbrs)

tol = 1e-13 if FLAGS.jax_enable_x64 else 1e-6
tol = 1e-13 if jax_config.jax_enable_x64 else 1e-6
self.assertAllClose(gr_exact, gr_nbrs, atol=tol, rtol=tol)

@parameterized.named_parameters(test_util.cases_from_list(
Expand Down Expand Up @@ -160,7 +159,7 @@ def test_angular_symmetry_functions_neighbor_list(self,
gr_exact = gr(R)
gr_nbrs = gr_neigh(R, neighbor=nbrs)

tol = 1e-13 if FLAGS.jax_enable_x64 else 1e-6
tol = 1e-13 if jax_config.jax_enable_x64 else 1e-6
self.assertAllClose(gr_exact, gr_nbrs, atol=tol, rtol=tol)

@parameterized.named_parameters(test_util.cases_from_list(
Expand Down
4 changes: 1 addition & 3 deletions tests/partition_test.py
Expand Up @@ -36,14 +36,12 @@

jax_config.parse_flags_with_absl()

FLAGS = jax_config.FLAGS


PARTICLE_COUNT = 1000
STOCHASTIC_SAMPLES = 10
SPATIAL_DIMENSION = [2, 3]

if FLAGS.jax_enable_x64:
if jax_config.jax_enable_x64:
POSITION_DTYPE = [f32, f64]
else:
POSITION_DTYPE = [f32]
Expand Down
5 changes: 2 additions & 3 deletions tests/quantity_test.py
Expand Up @@ -32,7 +32,6 @@


jax_config.parse_flags_with_absl()
FLAGS = jax_config.FLAGS

PARTICLE_COUNT = 10
STOCHASTIC_SAMPLES = 10
Expand All @@ -42,7 +41,7 @@
partition.Sparse,
partition.OrderedSparse]

DTYPES = [f32, f64] if FLAGS.jax_enable_x64 else [f32]
DTYPES = [f32, f64] if jax_config.jax_enable_x64 else [f32]
COORDS = ['fractional', 'real']


Expand Down Expand Up @@ -603,7 +602,7 @@ def test_phop(self, spatial_dim, dtype, window):


def test_maybe_downcast(self):
if not FLAGS.jax_enable_x64:
if not jax_config.jax_enable_x64:
self.skipTest('Maybe downcast only works for float32 mode.')

x = np.array([1, 2, 3], np.float64)
Expand Down
3 changes: 1 addition & 2 deletions tests/rigid_body_test.py
Expand Up @@ -45,7 +45,6 @@

jax_config.parse_flags_with_absl()

FLAGS = jax_config.FLAGS


f32 = util.f32
Expand All @@ -64,7 +63,7 @@
BROWNIAN_DYNAMICS_STEPS = 8000

DTYPE = [f32]
if FLAGS.jax_enable_x64:
if jax_config.jax_enable_x64:
DTYPE += [f64]


Expand Down
3 changes: 1 addition & 2 deletions tests/simulate_test.py
Expand Up @@ -41,7 +41,6 @@
from functools import partial

jax_config.parse_flags_with_absl()
FLAGS = jax_config.FLAGS


PARTICLE_COUNT = 1000
Expand All @@ -58,7 +57,7 @@
BROWNIAN_DYNAMICS_STEPS = 8000

DTYPE = [f32]
if FLAGS.jax_enable_x64:
if jax_config.jax_enable_x64:
DTYPE += [f64]


Expand Down
3 changes: 1 addition & 2 deletions tests/smap_test.py
Expand Up @@ -31,7 +31,6 @@
from jax_md import test_util

jax_config.parse_flags_with_absl()
FLAGS = jax_config.FLAGS

test_util.update_test_tolerance(f32_tol=5e-6, f64_tol=1e-14)

Expand All @@ -44,7 +43,7 @@

NEIGHBOR_LIST_PARTICLE_COUNT = 100

if FLAGS.jax_enable_x64:
if jax_config.jax_enable_x64:
POSITION_DTYPE = [f32, f64]
else:
POSITION_DTYPE = [f32]
Expand Down
3 changes: 1 addition & 2 deletions tests/space_test.py
Expand Up @@ -32,15 +32,14 @@
test_util.update_test_tolerance(5e-5, 5e-13)

jax_config.parse_flags_with_absl()
FLAGS = jax_config.FLAGS

PARTICLE_COUNT = 10
STOCHASTIC_SAMPLES = 10
SHIFT_STEPS = 10
SPATIAL_DIMENSION = [2, 3]
BOX_FORMATS = ['scalar', 'vector', 'matrix']

if FLAGS.jax_enable_x64:
if jax_config.jax_enable_x64:
POSITION_DTYPE = [f32, f64]
else:
POSITION_DTYPE = [f32]
Expand Down
2 changes: 0 additions & 2 deletions tests/tpu_test.py
Expand Up @@ -40,8 +40,6 @@

update_test_tolerance(5e-5, 1e-7)

FLAGS = jax_config.FLAGS


def get_test_grid(rng_key, topology=None, num_dims=2, add_aux=False, ):
# magic numbers to make the gird fold evenly, after splitting
Expand Down

0 comments on commit a4851b8

Please sign in to comment.