diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9f6ccc58..15cbe14b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -18,7 +18,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.8', '3.9', '3.10'] + python-version: ['3.9', '3.10', '3.11'] steps: - name: Cancel previous uses: styfle/cancel-workflow-action@0.11.0 diff --git a/.readthedocs.yml b/.readthedocs.yml index c6a4de20..92fcaf3c 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -15,6 +15,6 @@ formats: # Optionally set the version of Python and requirements required to build your docs python: - version: 3.8 + version: 3.10 install: - requirements: docs/requirements.txt diff --git a/setup.py b/setup.py index b25d28f2..2983aad4 100644 --- a/setup.py +++ b/setup.py @@ -54,10 +54,11 @@ long_description=long_description, long_description_content_type='text/markdown', description='Differentiable, Hardware Accelerated, Molecular Dynamics', - python_requires='>=2.7', + python_requires='>=3.9', classifiers=[ - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', 'License :: OSI Approved :: Apache Software License', 'Operating System :: MacOS', 'Operating System :: POSIX :: Linux', diff --git a/tests/elasticity_test.py b/tests/elasticity_test.py index c882e71c..14105fad 100644 --- a/tests/elasticity_test.py +++ b/tests/elasticity_test.py @@ -35,14 +35,13 @@ from jax_md.util import * jax_config.parse_flags_with_absl() -FLAGS = jax_config.FLAGS PARTICLE_COUNT = 64 NUM_SAMPLES = 2 SPATIAL_DIMENSION = [2, 3] LOWPRESSURE = [True, False] -if FLAGS.jax_enable_x64: +if jax_config.jax_enable_x64: DTYPE = [f32, f64] else: DTYPE = [f32] diff --git a/tests/energy_test.py b/tests/energy_test.py index ccae6b4e..97e09a15 100644 --- a/tests/energy_test.py +++ b/tests/energy_test.py @@ -37,7 +37,6 @@ from jax_md.interpolate import spline config.parse_flags_with_absl() -FLAGS = config.FLAGS # TODO: Replace np by jnp everywhere. jnp = np diff --git a/tests/minimize_test.py b/tests/minimize_test.py index 8f12cfac..dea9ec95 100644 --- a/tests/minimize_test.py +++ b/tests/minimize_test.py @@ -32,14 +32,13 @@ from jax_md import test_util jax_config.parse_flags_with_absl() -FLAGS = jax_config.FLAGS PARTICLE_COUNT = 10 OPTIMIZATION_STEPS = 10 STOCHASTIC_SAMPLES = 10 SPATIAL_DIMENSION = [2, 3] -if FLAGS.jax_enable_x64: +if jax_config.jax_enable_x64: DTYPE = [f32, f64] else: DTYPE = [f32] diff --git a/tests/nn_test.py b/tests/nn_test.py index c18c1e19..8dae2580 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -31,9 +31,8 @@ from jax_md import test_util jax_config.parse_flags_with_absl() -FLAGS = jax_config.FLAGS -if FLAGS.jax_enable_x64: +if jax_config.jax_enable_x64: DTYPES = [f32, f64] else: DTYPES = [f32] @@ -106,7 +105,7 @@ def test_radial_symmetry_functions_neighbor_list(self, gr_exact = gr(R) gr_nbrs = gr_neigh(R, neighbor=nbrs) - tol = 1e-13 if FLAGS.jax_enable_x64 else 1e-6 + tol = 1e-13 if jax_config.jax_enable_x64 else 1e-6 self.assertAllClose(gr_exact, gr_nbrs, atol=tol, rtol=tol) @parameterized.named_parameters(test_util.cases_from_list( @@ -160,7 +159,7 @@ def test_angular_symmetry_functions_neighbor_list(self, gr_exact = gr(R) gr_nbrs = gr_neigh(R, neighbor=nbrs) - tol = 1e-13 if FLAGS.jax_enable_x64 else 1e-6 + tol = 1e-13 if jax_config.jax_enable_x64 else 1e-6 self.assertAllClose(gr_exact, gr_nbrs, atol=tol, rtol=tol) @parameterized.named_parameters(test_util.cases_from_list( diff --git a/tests/partition_test.py b/tests/partition_test.py index 3a42c18f..a4536326 100644 --- a/tests/partition_test.py +++ b/tests/partition_test.py @@ -36,14 +36,12 @@ jax_config.parse_flags_with_absl() -FLAGS = jax_config.FLAGS - PARTICLE_COUNT = 1000 STOCHASTIC_SAMPLES = 10 SPATIAL_DIMENSION = [2, 3] -if FLAGS.jax_enable_x64: +if jax_config.jax_enable_x64: POSITION_DTYPE = [f32, f64] else: POSITION_DTYPE = [f32] diff --git a/tests/quantity_test.py b/tests/quantity_test.py index 9267a3ca..d858b6d3 100644 --- a/tests/quantity_test.py +++ b/tests/quantity_test.py @@ -32,7 +32,6 @@ jax_config.parse_flags_with_absl() -FLAGS = jax_config.FLAGS PARTICLE_COUNT = 10 STOCHASTIC_SAMPLES = 10 @@ -42,7 +41,7 @@ partition.Sparse, partition.OrderedSparse] -DTYPES = [f32, f64] if FLAGS.jax_enable_x64 else [f32] +DTYPES = [f32, f64] if jax_config.jax_enable_x64 else [f32] COORDS = ['fractional', 'real'] @@ -603,7 +602,7 @@ def test_phop(self, spatial_dim, dtype, window): def test_maybe_downcast(self): - if not FLAGS.jax_enable_x64: + if not jax_config.jax_enable_x64: self.skipTest('Maybe downcast only works for float32 mode.') x = np.array([1, 2, 3], np.float64) diff --git a/tests/rigid_body_test.py b/tests/rigid_body_test.py index fd1577b2..758ccd76 100644 --- a/tests/rigid_body_test.py +++ b/tests/rigid_body_test.py @@ -45,7 +45,6 @@ jax_config.parse_flags_with_absl() -FLAGS = jax_config.FLAGS f32 = util.f32 @@ -64,7 +63,7 @@ BROWNIAN_DYNAMICS_STEPS = 8000 DTYPE = [f32] -if FLAGS.jax_enable_x64: +if jax_config.jax_enable_x64: DTYPE += [f64] diff --git a/tests/simulate_test.py b/tests/simulate_test.py index e947a24a..380f8d1a 100644 --- a/tests/simulate_test.py +++ b/tests/simulate_test.py @@ -41,7 +41,6 @@ from functools import partial jax_config.parse_flags_with_absl() -FLAGS = jax_config.FLAGS PARTICLE_COUNT = 1000 @@ -58,7 +57,7 @@ BROWNIAN_DYNAMICS_STEPS = 8000 DTYPE = [f32] -if FLAGS.jax_enable_x64: +if jax_config.jax_enable_x64: DTYPE += [f64] diff --git a/tests/smap_test.py b/tests/smap_test.py index bf009cd7..a8e8fc38 100644 --- a/tests/smap_test.py +++ b/tests/smap_test.py @@ -31,7 +31,6 @@ from jax_md import test_util jax_config.parse_flags_with_absl() -FLAGS = jax_config.FLAGS test_util.update_test_tolerance(f32_tol=5e-6, f64_tol=1e-14) @@ -44,7 +43,7 @@ NEIGHBOR_LIST_PARTICLE_COUNT = 100 -if FLAGS.jax_enable_x64: +if jax_config.jax_enable_x64: POSITION_DTYPE = [f32, f64] else: POSITION_DTYPE = [f32] diff --git a/tests/space_test.py b/tests/space_test.py index 09accf0e..116888fd 100644 --- a/tests/space_test.py +++ b/tests/space_test.py @@ -32,7 +32,6 @@ test_util.update_test_tolerance(5e-5, 5e-13) jax_config.parse_flags_with_absl() -FLAGS = jax_config.FLAGS PARTICLE_COUNT = 10 STOCHASTIC_SAMPLES = 10 @@ -40,7 +39,7 @@ SPATIAL_DIMENSION = [2, 3] BOX_FORMATS = ['scalar', 'vector', 'matrix'] -if FLAGS.jax_enable_x64: +if jax_config.jax_enable_x64: POSITION_DTYPE = [f32, f64] else: POSITION_DTYPE = [f32] diff --git a/tests/tpu_test.py b/tests/tpu_test.py index 9bd891f4..a737dff8 100644 --- a/tests/tpu_test.py +++ b/tests/tpu_test.py @@ -40,8 +40,6 @@ update_test_tolerance(5e-5, 1e-7) -FLAGS = jax_config.FLAGS - def get_test_grid(rng_key, topology=None, num_dims=2, add_aux=False, ): # magic numbers to make the gird fold evenly, after splitting