Skip to content

Commit

Permalink
Made a fix for error occuring when generalized library is used with t…
Browse files Browse the repository at this point in the history
…ensor arrays and ensembling.
  • Loading branch information
Alan Kaptanoglu authored and Alan Kaptanoglu committed Jan 27, 2022
1 parent c92f03e commit 9ffc92f
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions pysindy/feature_library/generalized_library.py
Expand Up @@ -163,6 +163,7 @@ def __init__(
)
self.tensor_array_ = tensor_array
self.inputs_per_library_ = inputs_per_library
self.libraries_full_ = self.libraries_

def fit(self, x, y=None):
"""
Expand Down Expand Up @@ -227,7 +228,7 @@ def fit(self, x, y=None):
self.n_output_features_ = sum([lib.n_output_features_ for lib in fitted_libs])

# Save fitted libs
self.libraries_ = fitted_libs
self.libraries_full_ = fitted_libs

return self

Expand All @@ -246,7 +247,7 @@ def transform(self, x):
generated from applying the custom functions to the inputs.
"""
for lib in self.libraries_:
for lib in self.libraries_full_:
check_is_fitted(lib)

n_samples, n_features = x.shape
Expand All @@ -262,7 +263,7 @@ def transform(self, x):
xp = np.zeros((n_samples, self.n_output_features_))

current_feat = 0
for i, lib in enumerate(self.libraries_):
for i, lib in enumerate(self.libraries_full_):

# retrieve num output features from lib
lib_n_output_features = lib.n_output_features_
Expand Down Expand Up @@ -296,7 +297,7 @@ def get_feature_names(self, input_features=None):
output_feature_names : list of string, length n_output_features
"""
feature_names = list()
for i, lib in enumerate(self.libraries_):
for i, lib in enumerate(self.libraries_full_):
if i < self.inputs_per_library_.shape[0]:
if input_features is None:
input_features_i = [
Expand Down

0 comments on commit 9ffc92f

Please sign in to comment.