Skip to content

Commit

Permalink
Merge pull request #225 from dynamicslab/cln_process_multiple
Browse files Browse the repository at this point in the history
Clean _process_multiple_trajectories
  • Loading branch information
akaptano committed Jul 8, 2022
2 parents 632585a + 7e02c31 commit 59d24ba
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 63 deletions.
55 changes: 7 additions & 48 deletions pysindy/pysindy.py
Expand Up @@ -620,11 +620,7 @@ def score(

def _process_multiple_trajectories(self, x, t, x_dot):
"""
Handle input data that contains multiple trajectories by doing the
necessary validation, reshaping, and computation of derivatives.
This method essentially just loops over elements of each list in parallel,
validates them, and (optionally) concatenates them together.
Calculate derivatives of input data, iterating through trajectories.
Parameters
----------
Expand All @@ -633,19 +629,16 @@ def _process_multiple_trajectories(self, x, t, x_dot):
trajectory.
t: list of np.ndarray or int
List of time points for different trajectories.
If a list of ints is passed, each entry is assumed to be the timestep
for the corresponding trajectory in x.
List of time points for different trajectories. If a list of ints
is passed, each entry is assumed to be the timestep for the
corresponding trajectory in x. If np.ndarray is passed, it is
used for each trajectory.
x_dot: list of np.ndarray
List of derivative measurements, with each entry corresponding to a
different trajectory. If None, the derivatives will be approximated
from x.
return_array: boolean, optional (default True)
Whether to return concatenated np.ndarrays.
If False, the outputs will be lists with an entry for each trajectory.
Returns
-------
x_out: np.ndarray or list
Expand All @@ -658,51 +651,17 @@ def _process_multiple_trajectories(self, x, t, x_dot):
will be an np.ndarray of concatenated trajectories.
If False, x_out will be a list.
"""
if not isinstance(x, Sequence):
raise TypeError("Input x must be a list")

if self.discrete_time:
x = [validate_input(xi) for xi in x]
if x_dot is None:
if x_dot is None:
if self.discrete_time:
x_dot = [xi[1:] for xi in x]
x = [xi[:-1] for xi in x]
else:
if not isinstance(x_dot, Sequence):
raise TypeError(
"x_dot must be a list if used with x of list type "
"(i.e. for multiple trajectories)"
)
x_dot = [validate_input(xd) for xd in x_dot]
else:
if x_dot is None:
x = [
self.feature_library.validate_input(xi, ti)
for xi, ti in _zip_like_sequence(x, t)
]
x_dot = [
self.feature_library.calc_trajectory(
self.differentiation_method, xi, ti
)
for xi, ti in _zip_like_sequence(x, t)
]
else:
if not isinstance(x_dot, Sequence):
raise TypeError(
"x_dot must be a list if used with x of list type "
"(i.e. for multiple trajectories)"
)
if isinstance(t, Sequence):
x = [
self.feature_library.validate_input(xi, ti)
for xi, ti in zip(x, t)
]
x_dot = [
self.feature_library.validate_input(xd, ti)
for xd, ti in zip(x_dot, t)
]
else:
x = [self.feature_library.validate_input(xi, t) for xi in x]
x_dot = [self.feature_library.validate_input(xd, t) for xd in x_dot]
return x, x_dot

def differentiate(self, x, t=None, multiple_trajectories=False):
Expand Down
15 changes: 0 additions & 15 deletions test/test_pysindy.py
Expand Up @@ -580,21 +580,6 @@ def test_complexity(data_lorenz):
assert model.complexity < 10


def test_multiple_trajectories_errors(data_multiple_trajctories, data_discrete_time):
x, t = data_multiple_trajctories

model = SINDy()
with pytest.raises(TypeError):
model._process_multiple_trajectories(np.array(x, dtype=object), t, x)
with pytest.raises(TypeError):
model._process_multiple_trajectories(x, t, np.array(x, dtype=object))

x = data_discrete_time
model = SINDy(discrete_time=True)
with pytest.raises(TypeError):
model._process_multiple_trajectories(x, t, np.array(x, dtype=object))


def test_simulate_errors(data_lorenz):
x, t = data_lorenz
model = SINDy()
Expand Down

0 comments on commit 59d24ba

Please sign in to comment.