Skip to content

Commit

Permalink
Better hint for y_dot
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Feb 29, 2024
1 parent b5feac5 commit da42761
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def __init__(
self._slices = copy.copy(copy_this._slices)
self._size = copy.copy(copy_this._size)
self._children_slices = copy.copy(copy_this._children_slices)
self.secondary_dimensions_npts: int = copy_this.secondary_dimensions_npts # type: ignore[no-redef]
self.secondary_dimensions_npts = copy_this.secondary_dimensions_npts

@classmethod
def _from_json(cls, snippet: dict):
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
from scipy import special
from typing import Sequence, Callable, Any
from typing import Sequence, Callable
from typing_extensions import TypeVar

import pybamm
Expand Down Expand Up @@ -33,7 +33,7 @@ class Function(pybamm.Symbol):

def __init__(
self,
function: Callable[[Any, Any], Any],
function: Callable,
*children: pybamm.Symbol,
name: str | None = None,
derivative: str | None = "autograd",
Expand Down
4 changes: 2 additions & 2 deletions pybamm/solvers/idaklu_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def _jax_vjp_impl(
if t.ndim == 0 or (t.ndim == 1 and t.shape[0] == 1):
# scalar time input
logger.debug("scalar time")
y_dot = jnp.zeros_like(t)
y_dot: jax.Array | np.ndarray = jnp.zeros_like(t)
js = self._jaxify_solve(t, invar, *inputs)
if js.ndim == 0:
js = jnp.array([js])
Expand All @@ -507,7 +507,7 @@ def _jax_vjp_impl(
for ix, y_outvar in enumerate(y_bar.T):
y_dot += jnp.dot(y_outvar, js[:, ix])
logger.debug(f"_jax_vjp_impl [exit]: {type(y_dot)}, {y_dot}, {y_dot.shape}")
y_dot = jnp.array(y_dot)
y_dot = np.array(y_dot)
return y_dot

def _jax_vjp_impl_array_inputs(
Expand Down

0 comments on commit da42761

Please sign in to comment.