diff --git a/docs/getting_started.rst b/docs/getting_started.rst
index 401992e..964fa87 100644
--- a/docs/getting_started.rst
+++ b/docs/getting_started.rst
@@ -108,6 +108,62 @@ Hyperparameter optimization

 Usually, the hyperparameters of a machine learning model, in particular the kernel hyperparameters of a Gaussian process regression model, should be optimized as new training data is added. However, since this is usually a computationally expensive process, it may not be desirable to do so at every iteration of the active learning process.
 The iteration frequency of the hyperparameter optimization is internally set by the :code:`_should_optimize_hyperparameters` function, which by default uses a schedule that optimizes the hyperparameters every 10th iteration. This behavior can be changed by overriding this function.
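+
+For example, a minimal sketch of a more aggressive schedule (using the :code:`linear` helper from :code:`pyepal.pal.schedules` to optimize every 5th instead of every 10th iteration; the subclass name is only illustrative) could look like
+
+.. code-block:: python
+
+    from pyepal.pal.pal_gpy import PALGPy
+    from pyepal.pal.schedules import linear
+
+    class PALGPyFrequentOpt(PALGPy):
+        def _should_optimize_hyperparameters(self):
+            # illustrative override: optimize the kernel hyperparameters every 5th iteration
+            return linear(self.iteration, 5)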
+
+Reclassification schedule
+.............................
+
+A potential problem of the e-PAL algorithm arises when the initial GPR predictions are wrong: points that are in fact Pareto-efficient may then be confidently discarded (namely, if the GPR predicts with low variance a performance that is dominated by some other point).
+Under the assumption that the GPR predictions improve over the course of the active learning, this behavior can be mitigated by "resetting" all classifications, i.e., by reconsidering the discarded points.
+In PyePAL, we do this automatically when only one unclassified point is left. You might want to customize this behavior, which you can do, for example, via `monkey patching <https://en.wikipedia.org/wiki/Monkey_patch>`_
+
+.. code-block:: python
+
+    from pyepal.pal.pal_gpy import PALGPy
+    from pyepal.pal.schedules import linear
+
+    def my_schedule(self):
+        return linear(self.iteration, 1)
+
+    PALGPy._should_reclassify = my_schedule
+
+
+or by subclassing
+
+.. code-block:: python
+
+    from pyepal.pal.pal_gpy import PALGPy
+    from pyepal.pal.schedules import linear
+
+    class PALGPyReclassify(PALGPy):
+        def _should_reclassify(self):
+            return linear(self.iteration, 1)
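+
+The subclass is then instantiated exactly like :code:`PALGPy` itself. Mirroring the test suite (where :code:`model_0`, :code:`model_1`, and :code:`model_2` are GPy models created with :code:`build_model`), a sketch of the usage could be
+
+.. code-block:: python
+
+    palinstance = PALGPyReclassify(
+        X,
+        [model_0, model_1, model_2],
+        3,
+        epsilon=0.01,
+        delta=0.01,
+    )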
"""Testing the PALGPy class""" +import logging + import numpy as np import pytest from pyepal.models.gpr import build_model from pyepal.pal.pal_gpy import PALGPy +from pyepal.pal.schedules import linear def test_pal_gpy(make_random_dataset): @@ -101,6 +104,42 @@ def test_orchestration_run_one_step(make_random_dataset, binh_korn_points): assert sum(palinstance.discarded) == 0 +def test_reclassification_schedule(make_random_dataset, caplog): + """Ensure that we can patch in a re-classification schedule + as described in the docs""" + X, y = make_random_dataset # pylint:disable=invalid-name + sample_idx = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + model_0 = build_model(X[sample_idx], y[sample_idx], 0) + model_1 = build_model(X[sample_idx], y[sample_idx], 1) + model_2 = build_model(X[sample_idx], y[sample_idx], 2) + + class PALGPyReclassify(PALGPy): # pylint:disable=missing-class-docstring + def _should_reclassify(self): + return linear(self.iteration, 1) + + palinstance = PALGPyReclassify( + X, + [model_0, model_1, model_2], + 3, + beta_scale=1, + epsilon=0.01, + delta=0.01, + restarts=3, + ) + palinstance.cross_val_points = 0 + + palinstance.update_train_set(sample_idx, y[sample_idx]) + idx = palinstance.run_one_step() + assert "Resetting the classifications." in caplog.text + + palinstance.update_train_set(idx, y[idx]) + old_length = len(caplog.records) + with caplog.at_level(logging.INFO): + _ = palinstance.run_one_step() + assert "Resetting the classifications." in caplog.text + assert len(caplog.records) == 2 * old_length + + def test_orchestration_run_one_step_parallel(binh_korn_points): """Test if the parallelization works""" np.random.seed(10)