diff --git a/pysindy/differentiation/sindy_derivative.py b/pysindy/differentiation/sindy_derivative.py index 7ae11133..17ace83e 100644 --- a/pysindy/differentiation/sindy_derivative.py +++ b/pysindy/differentiation/sindy_derivative.py @@ -52,7 +52,7 @@ def set_params(self, **params): # Simple optimization to gain speed (inspect is slow) return self else: - self.kwargs.update(params) + self.kwargs.update(params["kwargs"]) return self diff --git a/test/differentiation/test_differentiation_methods.py b/test/differentiation/test_differentiation_methods.py index e24481aa..6bae3da4 100644 --- a/test/differentiation/test_differentiation_methods.py +++ b/test/differentiation/test_differentiation_methods.py @@ -262,6 +262,13 @@ def test_wrapper_equivalence_with_dxdt(data, derivative_kws): ) +def test_sindy_derivative_kwarg_update(): + method = SINDyDerivative(kind="spectral", foo=2) + method.set_params(kwargs={"kind": "spline", "foo": 1}) + assert method.kwargs["kind"] == "spline" + assert method.kwargs["foo"] == 1 + + @pytest.mark.parametrize( "data, derivative_kws", [