Skip to content

Commit

Permalink
chore: cleanup mypy config
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Feb 29, 2024
1 parent 94aa498 commit f62d36a
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 15 deletions.
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,12 @@ repos:
- id: rst-backticks
- id: rst-directive-colons
- id: rst-inline-touching-normal

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
- id: mypy
files: pybamm
additional_dependencies:
- numpy
- packaging
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def run_unit(session):
"scikits.odes",
external=True,
)
session.install("-e", ".[all,dev,jax,odes]", silent=False)
session.install("-e", ".[all,dev,jax]", silent=False)
else:
if sys.version_info < (3, 9):
session.install("-e", ".[all,dev]", silent=False)
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _preprocess_binary(
right = pybamm.Vector(right)

# Check both left and right are pybamm Symbols
if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)):
if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)): # type: ignore[redundant-expr]
raise NotImplementedError(
"""BinaryOperator not implemented for symbols of type {} and {}""".format(
type(left), type(right)
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def __init__(
self._slices = copy.copy(copy_this._slices)
self._size = copy.copy(copy_this._size)
self._children_slices = copy.copy(copy_this._children_slices)
self.secondary_dimensions_npts = copy_this.secondary_dimensions_npts
self.secondary_dimensions_npts: int = copy_this.secondary_dimensions_npts # type: ignore[no-redef]

@classmethod
def _from_json(cls, snippet: dict):
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
from scipy import special
from typing import Sequence, Callable
from typing import Sequence, Callable, Any
from typing_extensions import TypeVar

import pybamm
Expand Down Expand Up @@ -33,7 +33,7 @@ class Function(pybamm.Symbol):

def __init__(
self,
function: Callable,
function: Callable[[Any, Any], Any],
*children: pybamm.Symbol,
name: str | None = None,
derivative: str | None = "autograd",
Expand Down
2 changes: 1 addition & 1 deletion pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
# Set colors, linestyles, figsize, axis limits
# call LoopList to make sure list index never runs out
if colors is None:
self.colors = LoopList(colors or ["r", "b", "k", "g", "m", "c"])
self.colors = LoopList(["r", "b", "k", "g", "m", "c"])
else:
self.colors = LoopList(colors)
self.linestyles = LoopList(linestyles or ["-", ":", "--", "-."])
Expand Down
17 changes: 9 additions & 8 deletions pybamm/solvers/idaklu_jax.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import pybamm
import numpy as np
import logging
import warnings
import numbers

from typing import Union
from typing import List

from functools import lru_cache

Expand Down Expand Up @@ -273,9 +274,9 @@ def f_isolated(*args, **kwargs):

def jax_value(
self,
t: np.ndarray = None,
inputs: Union[dict, None] = None,
output_variables: Union[List[str], None] = None,
t: np.ndarray | None = None,
inputs: dict | None = None,
output_variables: list[str] | None = None,
):
"""Helper function to compute the gradient of a jaxified expression
Expand Down Expand Up @@ -306,9 +307,9 @@ def jax_value(

def jax_grad(
self,
t: np.ndarray = None,
inputs: Union[dict, None] = None,
output_variables: Union[List[str], None] = None,
t: np.ndarray | None = None,
inputs: dict | None = None,
output_variables: list[str] | None = None,
):
"""Helper function to compute the gradient of a jaxified expression
Expand Down Expand Up @@ -506,7 +507,7 @@ def _jax_vjp_impl(
for ix, y_outvar in enumerate(y_bar.T):
y_dot += jnp.dot(y_outvar, js[:, ix])
logger.debug(f"_jax_vjp_impl [exit]: {type(y_dot)}, {y_dot}, {y_dot.shape}")
y_dot = np.array(y_dot)
y_dot = jnp.array(y_dot)
return y_dot

def _jax_vjp_impl_array_inputs(
Expand Down
10 changes: 9 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,21 @@ concurrency = ["multiprocessing"]

[tool.repo-review]
ignore = [
"PP003" # list wheel as a build-dep
"PP003", # list wheel as a build-dep
"PC160", # codespell
"PC180", # prettier

]

[tool.mypy]
ignore_missing_imports = true
allow_redefinition = true
disable_error_code = ["call-overload", "operator"]
enable_error_code = [
"ignore-without-code",
"truthy-bool",
"redundant-expr",
]

[[tool.mypy.overrides]]
module = [
Expand Down

0 comments on commit f62d36a

Please sign in to comment.