Skip to content

Commit

Permalink
Merge branch 'develop' into feat/jaxWindows
Browse files Browse the repository at this point in the history
  • Loading branch information
kratman committed Mar 29, 2024
2 parents c6467f1 + 60ba076 commit 47feff2
Show file tree
Hide file tree
Showing 15 changed files with 87 additions and 73 deletions.
8 changes: 3 additions & 5 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ def run_coverage(session):
session.install("-e", ".[all,dev]", silent=False)
else:
session.install("-e", ".[all,dev,jax]", silent=False)
session.run("coverage", "run", "run-tests.py", "--nosub")
session.run("coverage", "combine")
session.run("coverage", "xml")
session.run("pytest", "--cov=pybamm", "--cov-report=xml", "tests/unit")


@nox.session(name="integration")
Expand All @@ -84,7 +82,7 @@ def run_integration(session):
@nox.session(name="doctests")
def run_doctests(session):
"""Run the doctests and generate the output(s) in the docs/build/ directory."""
session.install("-e", ".[all,docs]", silent=False)
session.install("-e", ".[all,dev,docs]", silent=False)
session.run("python", "run-tests.py", "--doctest")


Expand Down Expand Up @@ -116,7 +114,7 @@ def run_scripts(session):
# https://bitbucket.org/pybtex-devs/pybtex/issues/169/replace-pkg_resources-with
# is fixed
session.install("setuptools", silent=False)
session.install("-e", ".[all]", silent=False)
session.install("-e", ".[all,dev]", silent=False)
session.run("python", "run-tests.py", "--scripts")


Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/operations/latexify.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _get_geometry_displays(self, var):
for _, rng in self.model.default_geometry[var.domain[-1]].items():
rng_max = get_rng_min_max_name(rng, "max")

geo_latex = f"\quad {rng_min} < {name} < {rng_max}"
geo_latex = rf"\quad {rng_min} < {name} < {rng_max}"
geo.append(geo_latex)

return geo
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/printing/sympy_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def _print_Derivative(self, expr):
eqn = super()._print_Derivative(expr)
if getattr(expr, "force_partial", False) and "partial" not in eqn:
var1, var2 = re.findall(r"^\\frac{(\w+)}{(\w+) .+", eqn)[0]
eqn = eqn.replace(var1, "\partial").replace(var2, "\partial")
eqn = eqn.replace(var1, r"\partial").replace(var2, r"\partial")

return eqn

Expand Down
2 changes: 1 addition & 1 deletion pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(
self.spatial_unit = "mm"
elif spatial_unit == "um": # micrometers
self.spatial_factor = 1e6
self.spatial_unit = "$\mu$m"
self.spatial_unit = r"$\mu$m"
else:
raise ValueError(f"spatial unit '{spatial_unit}' not recognized")

Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ dev = [
# For running testing sessions
"nox",
# For coverage
"coverage[toml]",
"pytest-cov",
# For test parameterization
"parameterized>=0.9",
# For testing Jupyter notebooks
Expand Down Expand Up @@ -244,6 +244,10 @@ filterwarnings = [
# ignore internal nbmake warnings
'ignore:unclosed \<socket.socket:ResourceWarning',
'ignore:unclosed event loop \<:ResourceWarning',
# ignore warnings generated while running tests
"ignore::DeprecationWarning",
"ignore::UserWarning",
"ignore::RuntimeWarning",
]

# Logging configuration
Expand Down
22 changes: 14 additions & 8 deletions run-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import pybamm
import sys
import argparse
import unittest
import subprocess
import pytest
import unittest


def run_code_tests(executable=False, folder: str = "unit", interpreter="python"):
Expand All @@ -36,12 +37,18 @@ def run_code_tests(executable=False, folder: str = "unit", interpreter="python")
# currently activated virtual environment
interpreter = sys.executable
if executable is False:
suite = unittest.defaultTestLoader.discover(tests, pattern="test*.py")
result = unittest.TextTestRunner(verbosity=2).run(suite)
ret = int(not result.wasSuccessful())
if tests == "tests/unit":
ret = pytest.main(["-v", tests])
else:
suite = unittest.defaultTestLoader.discover(tests, pattern="test*.py")
result = unittest.TextTestRunner(verbosity=2).run(suite)
ret = int(not result.wasSuccessful())
else:
print(f"Running {folder} tests with executable '{interpreter}'")
cmd = [interpreter, "-m", "unittest", "discover", "-v", tests]
print(f"Running {folder} tests with executable {interpreter}")
if tests == "tests/unit":
cmd = [interpreter, "-m", "pytest", "-v", tests]
else:
cmd = [interpreter, "-m", "unittest", "discover", "-v", tests]
p = subprocess.Popen(cmd)
try:
ret = p.wait()
Expand Down Expand Up @@ -148,7 +155,7 @@ def test_script(path, executable="python"):

# Tell matplotlib not to produce any figures
env = dict(os.environ)
env["MPLBACKEND"] = "Template"
env["MPLBACKEND"] = "Agg"

# Run in subprocess
cmd = [executable, path]
Expand Down Expand Up @@ -243,7 +250,6 @@ def test_script(path, executable="python"):
metavar="python",
help="Give the name of the Python interpreter if it is not 'python'",
)

# Parse!
args = parser.parse_args()

Expand Down
3 changes: 3 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
get_discretisation_for_testing,
get_p2d_discretisation_for_testing,
get_size_distribution_disc_for_testing,
function_test,
multi_var_function_test,
multi_var_function_cube_test,
get_1p1d_discretisation_for_testing,
get_2p1d_discretisation_for_testing,
get_unit_2p1D_mesh_for_testing,
Expand Down
12 changes: 12 additions & 0 deletions tests/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,18 @@ def get_size_distribution_disc_for_testing(xpts=None, rpts=10, Rpts=10, zpts=15)
)


def function_test(arg):
return arg + arg


def multi_var_function_test(arg1, arg2):
return arg1 + arg2


def multi_var_function_cube_test(arg1, arg2):
return arg1 + arg2**3


def get_1p1d_discretisation_for_testing(xpts=None, rpts=10, zpts=15):
return get_discretisation_for_testing(
mesh=get_1p1d_mesh_for_testing(xpts, rpts, zpts),
Expand Down
47 changes: 20 additions & 27 deletions tests/unit/test_expression_tree/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,11 @@

import pybamm
import sympy


def test_function(arg):
return arg + arg


def test_multi_var_function(arg1, arg2):
return arg1 + arg2


def test_multi_var_function_cube(arg1, arg2):
return arg1 + arg2**3
from tests import (
function_test,
multi_var_function_test,
multi_var_function_cube_test,
)


class TestFunction(TestCase):
Expand All @@ -31,16 +24,16 @@ def test_number_input(self):
self.assertIsInstance(log.children[0], pybamm.Scalar)
self.assertEqual(log.evaluate(), np.log(10))

summ = pybamm.Function(test_multi_var_function, 1, 2)
summ = pybamm.Function(multi_var_function_test, 1, 2)
self.assertIsInstance(summ.children[0], pybamm.Scalar)
self.assertIsInstance(summ.children[1], pybamm.Scalar)
self.assertEqual(summ.evaluate(), 3)

def test_function_of_one_variable(self):
a = pybamm.Symbol("a")
funca = pybamm.Function(test_function, a)
self.assertEqual(funca.name, "function (test_function)")
self.assertEqual(str(funca), "test_function(a)")
funca = pybamm.Function(function_test, a)
self.assertEqual(funca.name, "function (function_test)")
self.assertEqual(str(funca), "function_test(a)")
self.assertEqual(funca.children[0].name, a.name)

b = pybamm.Scalar(1)
Expand All @@ -61,7 +54,7 @@ def test_diff(self):
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
y = np.array([5])
func = pybamm.Function(test_function, a)
func = pybamm.Function(function_test, a)
self.assertEqual(func.diff(a).evaluate(y=y), 2)
self.assertEqual(func.diff(func).evaluate(), 1)
func = pybamm.sin(a)
Expand All @@ -72,38 +65,38 @@ def test_diff(self):
self.assertEqual(func.diff(a).evaluate(y=y), np.exp(a.evaluate(y=y)))

# multiple variables
func = pybamm.Function(test_multi_var_function, 4 * a, 3 * a)
func = pybamm.Function(multi_var_function_test, 4 * a, 3 * a)
self.assertEqual(func.diff(a).evaluate(y=y), 7)
func = pybamm.Function(test_multi_var_function, 4 * a, 3 * b)
func = pybamm.Function(multi_var_function_test, 4 * a, 3 * b)
self.assertEqual(func.diff(a).evaluate(y=np.array([5, 6])), 4)
self.assertEqual(func.diff(b).evaluate(y=np.array([5, 6])), 3)
func = pybamm.Function(test_multi_var_function_cube, 4 * a, 3 * b)
func = pybamm.Function(multi_var_function_cube_test, 4 * a, 3 * b)
self.assertEqual(func.diff(a).evaluate(y=np.array([5, 6])), 4)
self.assertEqual(
func.diff(b).evaluate(y=np.array([5, 6])), 3 * 3 * (3 * 6) ** 2
)

# exceptions
func = pybamm.Function(
test_multi_var_function_cube, 4 * a, 3 * b, derivative="derivative"
multi_var_function_cube_test, 4 * a, 3 * b, derivative="derivative"
)
with self.assertRaises(ValueError):
func.diff(a)

def test_function_of_multiple_variables(self):
a = pybamm.Variable("a")
b = pybamm.Parameter("b")
func = pybamm.Function(test_multi_var_function, a, b)
self.assertEqual(func.name, "function (test_multi_var_function)")
self.assertEqual(str(func), "test_multi_var_function(a, b)")
func = pybamm.Function(multi_var_function_test, a, b)
self.assertEqual(func.name, "function (multi_var_function_test)")
self.assertEqual(str(func), "multi_var_function_test(a, b)")
self.assertEqual(func.children[0].name, a.name)
self.assertEqual(func.children[1].name, b.name)

# test eval and diff
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
y = np.array([5, 2])
func = pybamm.Function(test_multi_var_function, a, b)
func = pybamm.Function(multi_var_function_test, a, b)

self.assertEqual(func.evaluate(y=y), 7)
self.assertEqual(func.diff(a).evaluate(y=y), 1)
Expand All @@ -114,7 +107,7 @@ def test_exceptions(self):
a = pybamm.Variable("a", domain="something")
b = pybamm.Variable("b", domain="something else")
with self.assertRaises(pybamm.DomainError):
pybamm.Function(test_multi_var_function, a, b)
pybamm.Function(multi_var_function_test, a, b)

def test_function_unnamed(self):
fun = pybamm.Function(np.cos, pybamm.t)
Expand Down Expand Up @@ -148,7 +141,7 @@ def test_to_equation(self):

def test_to_from_json_error(self):
a = pybamm.Symbol("a")
funca = pybamm.Function(test_function, a)
funca = pybamm.Function(function_test, a)

with self.assertRaises(NotImplementedError):
funca.to_json()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,10 @@

if pybamm.have_jax():
import jax


def test_function(arg):
return arg + arg


def test_function2(arg1, arg2):
return arg1 + arg2
from tests import (
function_test,
multi_var_function_test,
)


class TestEvaluate(TestCase):
Expand Down Expand Up @@ -93,10 +89,10 @@ def test_find_symbols(self):
# test function
constant_symbols = OrderedDict()
variable_symbols = OrderedDict()
expr = pybamm.Function(test_function, a)
expr = pybamm.Function(function_test, a)
pybamm.find_symbols(expr, constant_symbols, variable_symbols)
self.assertEqual(next(iter(constant_symbols.keys())), expr.id)
self.assertEqual(next(iter(constant_symbols.values())), test_function)
self.assertEqual(next(iter(constant_symbols.values())), function_test)
self.assertEqual(next(iter(variable_symbols.keys())), a.id)
self.assertEqual(list(variable_symbols.keys())[1], expr.id)
self.assertEqual(next(iter(variable_symbols.values())), "y[0:1]")
Expand Down Expand Up @@ -283,9 +279,9 @@ def test_to_python(self):
expr = a + b
constant_str, variable_str = pybamm.to_python(expr)
expected_str = (
"var_[0-9m]+ = y\[0:1\].*\\n"
"var_[0-9m]+ = y\[1:2\].*\\n"
"var_[0-9m]+ = var_[0-9m]+ \+ var_[0-9m]+"
r"var_[0-9m]+ = y\[0:1\].*\n"
r"var_[0-9m]+ = y\[1:2\].*\n"
r"var_[0-9m]+ = var_[0-9m]+ \+ var_[0-9m]+"
)

self.assertRegex(variable_str, expected_str)
Expand All @@ -306,12 +302,12 @@ def test_evaluator_python(self):
self.assertEqual(result, 3)

# test function(a*b)
expr = pybamm.Function(test_function, a * b)
expr = pybamm.Function(function_test, a * b)
evaluator = pybamm.EvaluatorPython(expr)
result = evaluator(t=None, y=np.array([[2], [3]]))
self.assertEqual(result, 12)

expr = pybamm.Function(test_function2, a, b)
expr = pybamm.Function(multi_var_function_test, a, b)
evaluator = pybamm.EvaluatorPython(expr)
result = evaluator(t=None, y=np.array([[2], [3]]))
self.assertEqual(result, 5)
Expand Down Expand Up @@ -486,7 +482,7 @@ def test_evaluator_jax(self):
self.assertEqual(result, 3)

# test function(a*b)
expr = pybamm.Function(test_function, a * b)
expr = pybamm.Function(function_test, a * b)
evaluator = pybamm.EvaluatorJax(expr)
result = evaluator(t=None, y=np.array([[2], [3]]))
self.assertEqual(result, 12)
Expand Down
7 changes: 2 additions & 5 deletions tests/unit/test_expression_tree/test_operations/test_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
import unittest
from scipy.sparse import eye
from tests import get_mesh_for_testing


def test_multi_var_function(arg1, arg2):
return arg1 + arg2
from tests import multi_var_function_test


class TestJacobian(TestCase):
Expand Down Expand Up @@ -217,7 +214,7 @@ def test_functions(self):
np.testing.assert_array_equal(0, dfunc_dy)

# several children
func = pybamm.Function(test_multi_var_function, 2 * y, 3 * y)
func = pybamm.Function(multi_var_function_test, 2 * y, 3 * y)
jacobian = np.diag(5 * np.ones(4))
dfunc_dy = func.jac(y).evaluate(y=y0)
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())
Expand Down
11 changes: 5 additions & 6 deletions tests/unit/test_expression_tree/test_operations/test_jac_2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
import numpy as np
import unittest
from scipy.sparse import eye
from tests import get_1p1d_discretisation_for_testing


def test_multi_var_function(arg1, arg2):
return arg1 + arg2
from tests import (
get_1p1d_discretisation_for_testing,
multi_var_function_test,
)


class TestJacobian(TestCase):
Expand Down Expand Up @@ -202,7 +201,7 @@ def test_functions(self):
np.testing.assert_array_equal(0, dfunc_dy)

# several children
func = pybamm.Function(test_multi_var_function, 2 * y, 3 * y)
func = pybamm.Function(multi_var_function_test, 2 * y, 3 * y)
jacobian = np.diag(5 * np.ones(8))
dfunc_dy = func.jac(y).evaluate(y=y0)
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())
Expand Down

0 comments on commit 47feff2

Please sign in to comment.