From 9ffc92fa6ad61d000a19f29ab438cfc3f127fe0a Mon Sep 17 00:00:00 2001 From: Alan Kaptanoglu Date: Thu, 27 Jan 2022 09:50:13 -0800 Subject: [PATCH] Made a fix for error occuring when generalized library is used with tensor arrays and ensembling. --- pysindy/feature_library/generalized_library.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pysindy/feature_library/generalized_library.py b/pysindy/feature_library/generalized_library.py index 04fc9e67..79758a0b 100644 --- a/pysindy/feature_library/generalized_library.py +++ b/pysindy/feature_library/generalized_library.py @@ -163,6 +163,7 @@ def __init__( ) self.tensor_array_ = tensor_array self.inputs_per_library_ = inputs_per_library + self.libraries_full_ = self.libraries_ def fit(self, x, y=None): """ @@ -227,7 +228,7 @@ def fit(self, x, y=None): self.n_output_features_ = sum([lib.n_output_features_ for lib in fitted_libs]) # Save fitted libs - self.libraries_ = fitted_libs + self.libraries_full_ = fitted_libs return self @@ -246,7 +247,7 @@ def transform(self, x): generated from applying the custom functions to the inputs. """ - for lib in self.libraries_: + for lib in self.libraries_full_: check_is_fitted(lib) n_samples, n_features = x.shape @@ -262,7 +263,7 @@ def transform(self, x): xp = np.zeros((n_samples, self.n_output_features_)) current_feat = 0 - for i, lib in enumerate(self.libraries_): + for i, lib in enumerate(self.libraries_full_): # retrieve num output features from lib lib_n_output_features = lib.n_output_features_ @@ -296,7 +297,7 @@ def get_feature_names(self, input_features=None): output_feature_names : list of string, length n_output_features """ feature_names = list() - for i, lib in enumerate(self.libraries_): + for i, lib in enumerate(self.libraries_full_): if i < self.inputs_per_library_.shape[0]: if input_features is None: input_features_i = [