Skip to content

Commit

Permalink
chore: cleanup mypy config
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Feb 29, 2024
1 parent 94aa498 commit fa53c17
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 21 deletions.
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,12 @@ repos:
- id: rst-backticks
- id: rst-directive-colons
- id: rst-inline-touching-normal

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
- id: mypy
files: pybamm
additional_dependencies:
- numpy
- packaging
2 changes: 1 addition & 1 deletion pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _preprocess_binary(
right = pybamm.Vector(right)

# Check both left and right are pybamm Symbols
if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)):
if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)): # type: ignore[redundant-expr]
raise NotImplementedError(
"""BinaryOperator not implemented for symbols of type {} and {}""".format(
type(left), type(right)
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def __init__(
self._slices = copy.copy(copy_this._slices)
self._size = copy.copy(copy_this._size)
self._children_slices = copy.copy(copy_this._children_slices)
self.secondary_dimensions_npts = copy_this.secondary_dimensions_npts
self.secondary_dimensions_npts: int = copy_this.secondary_dimensions_npts # type: ignore[no-redef]

@classmethod
def _from_json(cls, snippet: dict):
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
from scipy import special
from typing import Sequence, Callable
from typing import Sequence, Callable, Any
from typing_extensions import TypeVar

import pybamm
Expand Down Expand Up @@ -33,7 +33,7 @@ class Function(pybamm.Symbol):

def __init__(
self,
function: Callable,
function: Callable[[Any, Any], Any],
*children: pybamm.Symbol,
name: str | None = None,
derivative: str | None = "autograd",
Expand Down
2 changes: 1 addition & 1 deletion pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
# Set colors, linestyles, figsize, axis limits
# call LoopList to make sure list index never runs out
if colors is None:
self.colors = LoopList(colors or ["r", "b", "k", "g", "m", "c"])
self.colors = LoopList(["r", "b", "k", "g", "m", "c"])
else:
self.colors = LoopList(colors)
self.linestyles = LoopList(linestyles or ["-", ":", "--", "-."])
Expand Down
17 changes: 9 additions & 8 deletions pybamm/solvers/idaklu_jax.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import pybamm
import numpy as np
import logging
import warnings
import numbers

from typing import Union
from typing import List

from functools import lru_cache

Expand Down Expand Up @@ -273,9 +274,9 @@ def f_isolated(*args, **kwargs):

def jax_value(
self,
t: np.ndarray = None,
inputs: Union[dict, None] = None,
output_variables: Union[List[str], None] = None,
t: np.ndarray | None = None,
inputs: dict | None = None,
output_variables: list[str] | None = None,
):
"""Helper function to compute the gradient of a jaxified expression
Expand Down Expand Up @@ -306,9 +307,9 @@ def jax_value(

def jax_grad(
self,
t: np.ndarray = None,
inputs: Union[dict, None] = None,
output_variables: Union[List[str], None] = None,
t: np.ndarray | None = None,
inputs: dict | None = None,
output_variables: list[str] | None = None,
):
"""Helper function to compute the gradient of a jaxified expression
Expand Down Expand Up @@ -506,7 +507,7 @@ def _jax_vjp_impl(
for ix, y_outvar in enumerate(y_bar.T):
y_dot += jnp.dot(y_outvar, js[:, ix])
logger.debug(f"_jax_vjp_impl [exit]: {type(y_dot)}, {y_dot}, {y_dot.shape}")
y_dot = np.array(y_dot)
y_dot = jnp.array(y_dot)
return y_dot

def _jax_vjp_impl_array_inputs(
Expand Down
13 changes: 6 additions & 7 deletions pybamm/type_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
#
from __future__ import annotations

from typing import Union, List, Dict
from typing_extensions import TypeAlias
import numpy as np
import pybamm

# numbers.Number should not be used for type hints
Numeric: TypeAlias = Union[int, float, np.number]
Numeric: TypeAlias = int | float | np.number

# expression tree
ChildValue: TypeAlias = Union[float, np.ndarray]
ChildSymbol: TypeAlias = Union[float, np.ndarray, pybamm.Symbol]
ChildValue: TypeAlias = float | np.ndarray
ChildSymbol: TypeAlias = float | np.ndarray | pybamm.Symbol

DomainType: TypeAlias = Union[List[str], str, None]
AuxiliaryDomainType: TypeAlias = Union[Dict[str, str], None]
DomainsType: TypeAlias = Union[Dict[str, Union[List[str], str]], None]
DomainType: TypeAlias = list[str] | str | None
AuxiliaryDomainType: TypeAlias = dict[str, str] | None
DomainsType: TypeAlias = dict[str, list[str] | str] | None
10 changes: 9 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,21 @@ concurrency = ["multiprocessing"]

[tool.repo-review]
ignore = [
"PP003" # list wheel as a build-dep
"PP003", # list wheel as a build-dep
"PC160", # codespell
"PC180", # prettier

]

[tool.mypy]
ignore_missing_imports = true
allow_redefinition = true
disable_error_code = ["call-overload", "operator"]
enable_error_code = [
"ignore-without-code",
"truthy-bool",
"redundant-expr",
]

[[tool.mypy.overrides]]
module = [
Expand Down

0 comments on commit fa53c17

Please sign in to comment.