feat: add option to schedule reclassification of all points (#177)
* feat: add option to schedule reclassification of all points

* add docs and one more test case
kjappelbaum committed Jun 15, 2021
1 parent 0ae381c commit 01d97a4
Showing 4 changed files with 112 additions and 1 deletion.
29 changes: 29 additions & 0 deletions docs/getting_started.rst
@@ -108,6 +108,35 @@ Hyperparameter optimization
Usually, the hyperparameters of a machine learning model, in particular the kernel hyperparameters of a Gaussian process regression model, should be optimized as new training data is added.
However, since this is usually a computationally expensive process, it may not be desirable to perform it at every iteration of the active learning loop. The frequency of the hyperparameter optimization is set internally by the :code:`_should_optimize_hyperparameters` function, which by default uses a schedule that optimizes the hyperparameters every 10th iteration. This behavior can be changed by overriding this function, for example as in the sketch below.
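
A minimal sketch (the class name :code:`MyPALGPy` and the interval of 5 are illustrative, not part of the library) that uses the built-in :code:`linear` schedule to optimize the hyperparameters every 5th iteration:

.. code-block:: python

    from pyepal.pal.pal_gpy import PALGPy
    from pyepal.pal.schedules import linear

    class MyPALGPy(PALGPy):
        def _should_optimize_hyperparameters(self) -> bool:
            # linear(iteration, 5) is True on every 5th iteration
            return linear(self.iteration, 5)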


Reclassification schedule
.............................

A potential problem of the e-PAL algorithm is that the initial GPR predictions can be wrong. This may lead to points that are actually Pareto-efficient being confidently discarded (if the GPR predicts, with low variance, a performance that is dominated by some other point).
Under the assumption that the GPR predictions improve over the course of the algorithm, this behavior can be mitigated by "resetting" all the classifications, i.e., reconsidering discarded points.
In PyePAL, we do this automatically once at most one unclassified point is left. You can customize this behavior, for example, by `monkey patching <https://stackoverflow.com/questions/5626193/what-is-monkey-patching>`_

.. code-block:: python

    from pyepal.pal.pal_gpy import PALGPy
    from pyepal.pal.schedules import linear

    def my_schedule(self):
        return linear(self.iteration, 1)

    PALGPy._should_reclassify = my_schedule

or by subclassing

.. code-block:: python

    from pyepal.pal.pal_gpy import PALGPy
    from pyepal.pal.schedules import linear

    class PALGPyReclassify(PALGPy):
        def _should_reclassify(self):
            return linear(self.iteration, 1)
In the examples above, the full design space would be reclassified at every iteration; to reclassify less often, a larger interval can be passed to :code:`linear`, as in the sketch below.
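
A minimal sketch that reclassifies only every 10th iteration (the class name and the interval are illustrative):

.. code-block:: python

    from pyepal.pal.pal_gpy import PALGPy
    from pyepal.pal.schedules import linear

    class PALGPyReclassifyEvery10(PALGPy):
        def _should_reclassify(self):
            # linear(iteration, 10) is True on every 10th iteration
            return linear(self.iteration, 10)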

Logging
........
Basic information such as the current iteration and the classification status is logged and can be viewed by printing the :code:`PAL` object.
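
For example (a sketch, assuming :code:`palinstance` is an instantiated :code:`PAL` object; the exact output depends on its state):

.. code-block:: python

    # prints, e.g., the current iteration and the numbers of Pareto-optimal,
    # discarded, and unclassified points
    print(palinstance)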
17 changes: 16 additions & 1 deletion pyepal/pal/pal_base.py
@@ -45,7 +45,7 @@
)

PAL_LOGGER = logging.getLogger("PALLogger")
PAL_LOGGER.setLevel(logging.INFO)
PAL_LOGGER.setLevel(level="INFO")
CONSOLE_HANDLER = logging.StreamHandler()
CONSOLE_FORMAT = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
CONSOLE_HANDLER.setFormatter(CONSOLE_FORMAT)
@@ -175,6 +175,13 @@ def _reset(self):
        self.measurement_uncertainty = np.zeros((self.number_design_points, self.ndim))
        self._has_train_set = False

    def _reset_classification(self):
        """Reset the mask arrays that keep track of the classifications,
        but do *not* reset the sampling status."""
        self.pareto_optimal = np.array([False] * self.number_design_points)
        self.discarded = np.array([False] * self.number_design_points)
        self.unclassified = np.array([True] * self.number_design_points)

    @property
    def sampled_mask(self):
        """Create a mask for the sampled points
@@ -274,6 +281,11 @@ def _log(self):
    def _should_optimize_hyperparameters(self) -> bool:  # pylint:disable=no-self-use
        return True

    def _should_reclassify(self) -> bool:  # pylint:disable=no-self-use
        # by default, reset the classifications once at most one
        # unclassified point is left
        return self.number_unclassified_points <= 1

    def _predict(self):
        raise NotImplementedError("The predict function is not implemented")

@@ -392,6 +404,9 @@ def _update_coef_var_mask(self):

    def _classify(self):
        self._update_coef_var_mask()
        if self._should_reclassify():
            PAL_LOGGER.info("Resetting the classifications.")
            self._reset_classification()
        if self.uses_fixed_epsilon:
            pareto_optimal, discarded, unclassified = _pareto_classify(
                self.pareto_optimal[self.coef_var_mask],
28 changes: 28 additions & 0 deletions tests/test_pal_base.py
@@ -58,6 +58,34 @@ def test_pal_base(make_random_dataset):
    assert palinstance.uses_fixed_epsilon


def test_reset_classification(make_random_dataset):
    """Make sure the reset of the classification status works as expected."""
    X, _ = make_random_dataset  # pylint: disable=invalid-name

    palinstance = PALBase(X, ["model"], 3)
    lows = np.zeros((100, 3))
    highs = np.zeros((100, 3))

    means = np.full((100, 3), 1)
    palinstance.means = means
    palinstance.std = np.full((100, 3), 0.1)
    pareto_optimal = np.array([False] * 98 + [True, True])
    sampled = np.array([[False] * 3, [False] * 3, [False] * 3, [False] * 3])
    unclassified = np.array([True] * 98 + [False, False])

    palinstance.rectangle_lows = lows
    palinstance.rectangle_ups = highs
    palinstance.sampled = sampled
    palinstance.pareto_optimal = pareto_optimal
    palinstance.unclassified = unclassified

    palinstance._reset_classification()

    assert palinstance.number_unclassified_points == len(X)
    assert palinstance.number_pareto_optimal_points == 0
    assert palinstance.number_discarded_points == 0


def test_update_train_set(make_random_dataset):
    """Check if the update of the training set works"""
    X, y = make_random_dataset  # pylint:disable=invalid-name
39 changes: 39 additions & 0 deletions tests/test_pal_gpy.py
@@ -14,11 +14,14 @@
# limitations under the License.

"""Testing the PALGPy class"""
import logging

import numpy as np
import pytest

from pyepal.models.gpr import build_model
from pyepal.pal.pal_gpy import PALGPy
from pyepal.pal.schedules import linear


def test_pal_gpy(make_random_dataset):
@@ -101,6 +104,42 @@ def test_orchestration_run_one_step(make_random_dataset, binh_korn_points):
    assert sum(palinstance.discarded) == 0


def test_reclassification_schedule(make_random_dataset, caplog):
    """Ensure that we can patch in a re-classification schedule
    as described in the docs."""
    X, y = make_random_dataset  # pylint:disable=invalid-name
    sample_idx = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    model_0 = build_model(X[sample_idx], y[sample_idx], 0)
    model_1 = build_model(X[sample_idx], y[sample_idx], 1)
    model_2 = build_model(X[sample_idx], y[sample_idx], 2)

    class PALGPyReclassify(PALGPy):  # pylint:disable=missing-class-docstring
        def _should_reclassify(self):
            return linear(self.iteration, 1)

    palinstance = PALGPyReclassify(
        X,
        [model_0, model_1, model_2],
        3,
        beta_scale=1,
        epsilon=0.01,
        delta=0.01,
        restarts=3,
    )
    palinstance.cross_val_points = 0

    palinstance.update_train_set(sample_idx, y[sample_idx])
    idx = palinstance.run_one_step()
    assert "Resetting the classifications." in caplog.text

    palinstance.update_train_set(idx, y[idx])
    old_length = len(caplog.records)
    with caplog.at_level(logging.INFO):
        _ = palinstance.run_one_step()
    assert "Resetting the classifications." in caplog.text
    assert len(caplog.records) == 2 * old_length


def test_orchestration_run_one_step_parallel(binh_korn_points):
"""Test if the parallelization works"""
np.random.seed(10)
