Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalised symbol copying #4023

Open
wants to merge 16 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ def _process_symbol(self, symbol):

elif isinstance(symbol, pybamm.Function):
disc_children = [self.process_symbol(child) for child in symbol.children]
return symbol._function_new_copy(disc_children)
return symbol.new_copy(disc_children)
martinjrobins marked this conversation as resolved.
Show resolved Hide resolved

elif isinstance(symbol, pybamm.VariableDot):
# Add symbol's reference and multiply by the symbol's scale
Expand Down
6 changes: 5 additions & 1 deletion pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def _jac(self, variable) -> pybamm.Matrix:
jac = csr_matrix((self.size, variable.evaluation_array.count(True)))
return pybamm.Matrix(jac)

def create_copy(self):
def create_copy(
self,
new_children=None,
perform_simplifications: bool = True,
):
"""See :meth:`pybamm.Symbol.new_copy()`."""
return self.__class__(
self.entries,
Expand Down
100 changes: 83 additions & 17 deletions pybamm/expression_tree/averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ def __init__(
self,
child: pybamm.Symbol,
name: str,
integration_variable: list[pybamm.IndependentVariable]
| pybamm.IndependentVariable,
integration_variable: (
list[pybamm.IndependentVariable] | pybamm.IndependentVariable
),
) -> None:
super().__init__(child, integration_variable)
self.name = name
Expand All @@ -38,9 +39,22 @@ def __init__(self, child: pybamm.Symbol) -> None:
integration_variable = x
super().__init__(child, "x-average", integration_variable)

def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return x_average(child)
def _unary_new_copy(
self, child: pybamm.Symbol, perform_simplifications: bool = True
):
"""
Creates a new copy of the operator with the child `child`.

Uses the convenience function :meth:`x_average` to perform checks before
creating an XAverage object.
"""
if perform_simplifications:
return x_average(child)
else:
raise NotImplementedError(
f"{self.__class__.__name__} should always be copied using "
martinjrobins marked this conversation as resolved.
Show resolved Hide resolved
"simplification checks"
)


class YZAverage(_BaseAverage):
Expand All @@ -50,9 +64,22 @@ def __init__(self, child: pybamm.Symbol) -> None:
integration_variable: list[pybamm.IndependentVariable] = [y, z]
super().__init__(child, "yz-average", integration_variable)

def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return yz_average(child)
def _unary_new_copy(
self, child: pybamm.Symbol, perform_simplifications: bool = True
):
"""
Creates a new copy of the operator with the child `child`.

Uses the convenience function :meth:`yz_average` to perform checks before
creating an YZAverage object.
"""
if perform_simplifications:
return yz_average(child)
else:
raise NotImplementedError(
f"{self.__class__.__name__} should always be copied using "
"simplification checks"
)


class ZAverage(_BaseAverage):
Expand All @@ -62,9 +89,22 @@ def __init__(self, child: pybamm.Symbol) -> None:
]
super().__init__(child, "z-average", integration_variable)

def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return z_average(child)
def _unary_new_copy(
self, child: pybamm.Symbol, perform_simplifications: bool = True
):
"""
Creates a new copy of the operator with the child `child`.

Uses the convenience function :meth:`z_average` to perform checks before
creating an ZAverage object.
"""
if perform_simplifications:
return z_average(child)
else:
raise NotImplementedError(
f"{self.__class__.__name__} should always be copied using "
"simplification checks"
)


class RAverage(_BaseAverage):
Expand All @@ -74,9 +114,22 @@ def __init__(self, child: pybamm.Symbol) -> None:
]
super().__init__(child, "r-average", integration_variable)

def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return r_average(child)
def _unary_new_copy(
self, child: pybamm.Symbol, perform_simplifications: bool = True
):
"""
Creates a new copy of the operator with the child `child`.

Uses the convenience function :meth:`r_average` to perform checks before
creating an RAverage object.
"""
if perform_simplifications:
return r_average(child)
else:
raise NotImplementedError(
f"{self.__class__.__name__} should always be copied using "
"simplification checks"
)


class SizeAverage(_BaseAverage):
Expand All @@ -86,9 +139,22 @@ def __init__(self, child: pybamm.Symbol, f_a_dist) -> None:
super().__init__(child, "size-average", integration_variable)
self.f_a_dist = f_a_dist

def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return size_average(child, f_a_dist=self.f_a_dist)
def _unary_new_copy(
self, child: pybamm.Symbol, perform_simplifications: bool = True
):
"""
Creates a new copy of the operator with the child `child`.

Uses the convenience function :meth:`size_average` to perform checks before
creating an SizeAverage object.
"""
if perform_simplifications:
return size_average(child, f_a_dist=self.f_a_dist)
else:
raise NotImplementedError(
f"{self.__class__.__name__} should always be copied using "
"simplification checks"
)


def x_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
Expand Down
48 changes: 31 additions & 17 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,24 +113,39 @@ def __str__(self):
right_str = f"{self.right!s}"
return f"{left_str} {self.name} {right_str}"

def create_copy(self):
def create_copy(
self,
new_children: list[pybamm.Symbol] | None = None,
perform_simplifications: bool = True,
):
"""See :meth:`pybamm.Symbol.new_copy()`."""

# process children
new_left = self.left.new_copy()
new_right = self.right.new_copy()
if new_children and len(new_children) != 2:
raise ValueError(
f"Symbol of type {type(self)} must have exactly two children."
)
children = self._children_for_copying(new_children)

if not perform_simplifications:
out = self.__class__(children[0], children[1])
else:
# creates a new instance using the overloaded binary operator to perform
# additional simplifications, rather than just calling the constructor
out = self._binary_new_copy(children[0], children[1])

# make new symbol, ensure domain(s) remain the same
out = self._binary_new_copy(new_left, new_right)
out.copy_domains(self)

return out

def _binary_new_copy(self, left: ChildSymbol, right: ChildSymbol):
"""
Default behaviour for new_copy.
This copies the behaviour of `_binary_evaluate`, but since `left` and `right`
are symbols creates a new symbol instead of returning a value.
Performs the overloaded binary operation on the two symbols `left` and `right`,
to create a binary class instance after performing appropriate simplifying
checks.

Default behaviour for _binary_new_copy copies the behaviour of `_binary_evaluate`,
but since `left` and `right` are symbols this creates a new symbol instead of
returning a value.
"""
return self._binary_evaluate(left, right)

Expand Down Expand Up @@ -553,7 +568,10 @@ def _binary_new_copy(
left: ChildSymbol,
right: ChildSymbol,
):
"""See :meth:`pybamm.BinaryOperator._binary_new_copy()`."""
"""
Overwrites `pybamm.BinaryOperator._binary_new_copy()` to return a new instance of
`pybamm.Equality` rather than using `binary_evaluate` to return a value.
"""
return pybamm.Equality(left, right)


Expand Down Expand Up @@ -834,13 +852,11 @@ def _simplified_binary_broadcast_concatenation(
left, pybamm.ConcatenationVariable
):
if right.evaluates_to_constant_number():
return left._concatenation_new_copy(
[operator(child, right) for child in left.orphans]
)
return left.new_copy([operator(child, right) for child in left.orphans])
elif isinstance(right, pybamm.Concatenation) and not isinstance(
right, pybamm.ConcatenationVariable
):
return left._concatenation_new_copy(
return left.new_copy(
[
operator(left_child, right_child)
for left_child, right_child in zip(left.orphans, right.orphans)
Expand All @@ -850,9 +866,7 @@ def _simplified_binary_broadcast_concatenation(
right, pybamm.ConcatenationVariable
):
if left.evaluates_to_constant_number():
return right._concatenation_new_copy(
[operator(left, child) for child in right.orphans]
)
return right.new_copy([operator(left, child) for child in right.orphans])
return None


Expand Down
18 changes: 5 additions & 13 deletions pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def _from_json(cls, snippet):
"pybamm.Broadcast: Please use a discretised model when reading in from JSON"
)

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class PrimaryBroadcast(Broadcast):
"""
Expand Down Expand Up @@ -174,10 +178,6 @@ def check_and_set_domains(self, child: pybamm.Symbol, broadcast_domain: list[str

return domains

def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)

def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
Expand Down Expand Up @@ -308,10 +308,6 @@ def check_and_set_domains(self, child: pybamm.Symbol, broadcast_domain: list[str

return domains

def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return SecondaryBroadcast(child, self.broadcast_domain)

def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
Expand Down Expand Up @@ -429,10 +425,6 @@ def check_and_set_domains(

return domains

def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)

def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
Expand Down Expand Up @@ -506,7 +498,7 @@ def check_and_set_domains(self, child: pybamm.Symbol, broadcast_domains: dict):

return broadcast_domains

def _unary_new_copy(self, child):
def _unary_new_copy(self, child, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, broadcast_domains=self.domains)

Expand Down
60 changes: 43 additions & 17 deletions pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,26 @@ def evaluate(
children_eval = [child.evaluate(t, y, y_dot, inputs) for child in self.children]
return self._concatenation_evaluate(children_eval)

def create_copy(self):
def create_copy(
self,
new_children: list[pybamm.Symbol] | None = None,
perform_simplifications: bool = True,
):
"""See :meth:`pybamm.Symbol.new_copy()`."""
new_children = [child.new_copy() for child in self.children]
return self._concatenation_new_copy(new_children)
children = self._children_for_copying(new_children)

def _concatenation_new_copy(self, children):
"""See :meth:`pybamm.Symbol.new_copy()`."""
return concatenation(*children)
return self._concatenation_new_copy(children, perform_simplifications)

def _concatenation_new_copy(self, children, perform_simplifications: bool = True):
"""
Creates a copy for the current concatenation class using the convenience
function :meth:`concatenation` to perform simplifications based on the new
children before creating the new copy.
"""
if perform_simplifications:
return concatenation(*children)
else:
return self.__class__(*children)

def _concatenation_jac(self, children_jacs):
"""Calculate the Jacobian of a concatenation."""
Expand Down Expand Up @@ -225,9 +237,19 @@ def _concatenation_jac(self, children_jacs):
else:
return SparseStack(*children_jacs)

def _concatenation_new_copy(self, children):
"""See :meth:`pybamm.Symbol.new_copy()`."""
return numpy_concatenation(*children)
def _concatenation_new_copy(
self,
children,
perform_simplifications: bool = True,
):
"""See :meth:`pybamm.Concatenation._concatenation_new_copy()`."""
if perform_simplifications:
return numpy_concatenation(*children)
else:
raise NotImplementedError(
f"{self.__class__.__name__} should always be copied using "
"simplification checks"
)


class DomainConcatenation(Concatenation):
Expand Down Expand Up @@ -373,12 +395,16 @@ def _concatenation_jac(self, children_jacs):
jacs.append(pybamm.Index(child_jac, child_slice[i]))
return SparseStack(*jacs)

def _concatenation_new_copy(self, children: list[pybamm.Symbol]):
"""See :meth:`pybamm.Symbol.new_copy()`."""
new_symbol = simplified_domain_concatenation(
children, self.full_mesh, copy_this=self
)
return new_symbol
def _concatenation_new_copy(
self, children: list[pybamm.Symbol], perform_simplifications: bool = True
):
"""See :meth:`pybamm.Concatenation._concatenation_new_copy()`."""
if perform_simplifications:
return simplified_domain_concatenation(
children, self.full_mesh, copy_this=self
)
else:
return DomainConcatenation(children, self.full_mesh, copy_this=self)

def to_json(self):
"""
Expand Down Expand Up @@ -434,8 +460,8 @@ def __init__(self, *children):
concat_fun=concatenation_function,
)

def _concatenation_new_copy(self, children):
"""See :meth:`pybamm.Symbol.new_copy()`."""
def _concatenation_new_copy(self, children, perform_simplifications=True):
"""See :meth:`pybamm.Concatenation._concatenation_new_copy()`."""
return SparseStack(*children)


Expand Down