From 69a8becf01d2009de8c6138d559a6028c4264c75 Mon Sep 17 00:00:00 2001 From: Kevin Jablonka <32935233+kjappelbaum@users.noreply.github.com> Date: Tue, 15 Jun 2021 19:04:18 +0200 Subject: [PATCH] feat: add basic GPflow support (#178) * feat: add basic GPflow support * feat: add basic GPflow support --- .gitignore | 1 + dev/play_w_gpflow.ipynb | 477 ++++++++++++++++++++++++++++++++++++ docs/api.rst | 8 + pyepal/__init__.py | 2 + pyepal/pal/pal_gpflowgpr.py | 130 ++++++++++ setup.py | 7 +- tests/test_pal_gpflowgpr.py | 59 +++++ 7 files changed, 683 insertions(+), 1 deletion(-) create mode 100644 dev/play_w_gpflow.ipynb create mode 100644 pyepal/pal/pal_gpflowgpr.py create mode 100644 tests/test_pal_gpflowgpr.py diff --git a/.gitignore b/.gitignore index 30dfc02..419da7b 100644 --- a/.gitignore +++ b/.gitignore @@ -539,3 +539,4 @@ MigrationBackup/ *.pkl *.npy *.joblib +dev/ diff --git a/dev/play_w_gpflow.ipynb b/dev/play_w_gpflow.ipynb new file mode 100644 index 0000000..1bad416 --- /dev/null +++ b/dev/play_w_gpflow.ipynb @@ -0,0 +1,477 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 11, + "id": "7d0a37ee-7e5c-41fe-881f-34d92f7b584d", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c41cfbc3-2832-4741-a55a-c3206805bf16", + "metadata": {}, + "outputs": [], + "source": [ + "import gpflow\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import tensorflow as tf\n", + "from gpflow.utilities import print_summary\n", + "\n", + "# The lines below are specific to the notebook format\n", + "%matplotlib inline\n", + "\n", + "plt.style.use('ggplot')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2e4f3026-0cc2-4a4c-b8f1-d9d43b1054f0", + "metadata": {}, + "outputs": [], + "source": [ + "def binh_korn(x, y): # pylint:disable=invalid-name\n", + " 
\"\"\"https://en.wikipedia.org/wiki/Test_functions_for_optimization\"\"\"\n", + " obj1 = 4 * x ** 2 + 4 * y ** 2\n", + " obj2 = (x - 5) ** 2 + (y - 5) ** 2\n", + " return -obj1, -obj2\n", + "\n", + "\n", + "def binh_korn_points():\n", + " \"\"\"Create a dataset based on the Binh-Korn test function\"\"\"\n", + " x = np.linspace(0, 5, 100) # pylint:disable=invalid-name\n", + " y = np.linspace(0, 3, 100) # pylint:disable=invalid-name\n", + " array = np.array([binh_korn(xi, yi) for xi, yi in zip(x, y)])\n", + " return np.hstack([x.reshape(-1, 1), y.reshape(-1, 1)]), array" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "1c868223-b4fb-4e5d-9b87-ff661c087055", + "metadata": {}, + "outputs": [], + "source": [ + "x, points = binh_korn_points()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c95ebb69-50fb-49f1-9a78-a895d8456b1b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0, 0.5, 'objective 2')" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(points[:,0], points[:,1])\n", + "plt.xlabel('objective 1')\n", + "plt.ylabel('objective 2')" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "0dab5f4c-225d-42e4-8664-caeec9c82713", + "metadata": {}, + "outputs": [], + "source": [ + "x = (x - x.mean()) / x.std()\n", + "points = (points - points.mean()) / points.std()" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "ebc93fff-eb5c-4e62-a6fd-3db8d79aef62", + "metadata": {}, + "outputs": [], + "source": [ + "indices = np.random.choice(range(len(x)), 20)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "8a146c77-ce03-4402-9734-50a98bf11e12", + "metadata": {}, + "outputs": [], + "source": [ + "k = gpflow.kernels.RationalQuadratic()" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "df30b95c-2298-4f3b-bf11-d1d76e72915c", + "metadata": {}, + "outputs": [], + "source": [ + "m = gpflow.models.GPR(data=(x[indices], points[indices]), kernel=k, mean_function=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "90c39dbc-4dd2-4faf-8d2d-cac21546fe6e", + "metadata": {}, + "outputs": [], + "source": [ + "opt = gpflow.optimizers.Scipy()" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "695e9038-c319-4ace-9cb4-6800c929993a", + "metadata": {}, + "outputs": [], + "source": [ + "opt_logs = opt.minimize(m.training_loss, m.trainable_variables, options=dict(maxiter=10000))" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "9ea2a0fe-0202-43e0-a772-7e31f0bf7f6b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + " fun: -172.66386574809877\n", + " hess_inv: <4x4 LbfgsInvHessProduct with dtype=float64>\n", + " jac: array([-2.81214986e-06, -1.56898375e-04, 3.27001326e-05, 1.13801794e-02])\n", + " message: 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'\n", + " nfev: 92\n", + " nit: 
9\n", + " njev: 92\n", + " status: 0\n", + " success: True\n", + " x: array([4571.19693555, 26.02637078, 5668.52441827, -21.0171565 ])" + ] + }, + "execution_count": 65, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "opt_logs" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "0d907d93-cf78-4320-a59d-93a78b3777c4", + "metadata": {}, + "outputs": [], + "source": [ + "mean, var = m.predict_f(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "fe77c5c5-59e9-47c4-bc10-6637dd8d866a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAntklEQVR4nO3de2BU9Z338fdccr/PTDK5QmAAAbkJQSEq12zt1qpZ21pv7XbxqU+ri5Rutd7QtpQ227WlrXW3tUvZWt0WrRVbb7UBAUukBjFcBCHhnguE3BNym5kzzx9UnlIMiUwmZ5L5vP7izPnl/L7zdfxk8psz51gCgUAAEREZ8axmFyAiIkNDgS8iEiEU+CIiEUKBLyISIRT4IiIRQoEvIhIh7GYXcCG1tbUhO7bL5aKhoSFkxx8p1KeBUZ/6px4NTLB9ys7O7nOf3uGLiEQIBb6ISIRQ4IuIRAgFvohIhFDgi4hECAW+iEiEUOCLiESIsD4P/2L09vTy6/V/JjXWTkpCDKlJcaSmJpHqTCEpLQWbzWZ2iSIiphhxgd/e3MqLvW78Pht0ACcBeoB6rIE6kn1dpBldOKx+Umw+0qKtpMbZcSTF4khNxOFIIc3tJCY21twnIiIyyEZc4Dsz0/ntbU46WttpaWyhtaWdlrZOWjp6aOny0hII0BKw0BKI5rAvgVbiz/xyaAdq4cw/2kn0dZHm78Rh6cVp9+OMseJIiMaVHI/TmYwzw0GyIxWrVatiIjI8jLjAB7BarSSnpZCcltLnmA++vuz3+2lvbqW5oZmmpnaa2jppOt1Lc7efRsNCc8DOcV8iLZYEDK8VWoBjPqCeaH8NTl8HTnpw2f04YyxkJEWTnppIenoqrqx04hMThuppi4hc0IgM/I/CZrOR6nKQ6nIw5gLjfF4fLQ1NNJ5qprGpjYa2LhpO+2gMBGjw29njT6TJm4jRbINm4HA3cJxEXyfp/tOkW7y4YwKkJ0aRmZpARkYqGdluEpITh+iZikiki/jAHyh7lB1XVgaurIw+x/h8PppPNnKqvpFTje2cauvmVJePesNCXSCGnb4ketqjz6waHffCO9UkeTtxGx24rV7ccRbcSTFkuZLIzE7HmZmO3a7/RCIyOJQmg8hut5Oe4yY9x/2h+w3DoK2phfraBuobWzjZ0snJ0z5O9to4ZMTzl94kfM32M38hVLZiNxpxe9twW7rJigmQnRRNtjOJrJx00nPc+mUgIh+JEmMIWa3Ws8tHE
z5kv8/no/HEKU7WNVB7qo0TbT3U+QOcNKLZ60umuy0G2oDD7diNZtzeNrIs3eTEQk5qDDkZqeSMyiTFmaYPk0XkPAr8MGK323HnZuHOzWLa3+0zDIOWU03UVp+gtr6V2tYzvwxqjVh2+pLxNkVBE/D+KRJ9x8jxt5ET5SU3MYo8VxK5eRlk5GbqrwKRCKb/+4cJq9WKw+3C4XYx5e/2+Xw+GmpOUlNziuqGNmraeqnx29nhT2VjZyIcA451EO3fR46vhTxbL3mJVkalJzFqVCbuvCx9IU0kAijwRwC73U7m6BwyR+cw6+/2tTe3UX20luMnmqhu7uKY38JeI5Et3SlwHDjeQYz/PfJ8LYyK6mV0cjSj3Snke3JIS3eZ8XREJEQU+CNcUloyk9KSmfR3j59u66D6cDVHaxs51tzNUZ/1zF8EHYlnvqF8sIFU7xHG0MGoGIN8ZzxjRmWQOzaPqOgoM56KiARpUAK/oqKCtWvXYhgGixcvpri4+Jz9DQ0NPPHEE5w+fRrDMLj11luZOXPmYEwtFykhOZFLpk/kkunnPt7S0MTRQ9UcqWvhSKuXo/4YXvGl4D0VBad6iSp/n1HeFsZE9TImLRpPjosxE/KJTYgz54mIyIAFHfiGYbBmzRoefvhhnE4nDzzwAAUFBeTm5p4d8/zzzzN37lw+9rGPUV1dzXe/+10Ffpj64CyiD34PuFwuTtSdoOZwNYeOnuBwQyeHvRa2GWmUtsZDK1jfO0xObzMeezdjU6MYl+ti7CX5xCXEm/pcRORcQQd+VVUVmZmZuN1nzj0vLCykvLz8nMC3WCx0dnYC0NnZSVpaWrDTyhCyR9kZPSGf0RPyWfjXxwzDoKGunsMHa6g60cYhb4CdRjKb2pNgH1j3HiGvtwlPVDfj02IYP9rN6An5RMdEm/pcRCKZJRAIBII5wLZt26ioqOBLX/oSAFu2bKGyspI77rjj7Jjm5ma+/e1vc/r0aXp6elixYgVjx44971ilpaWUlpYCUFJSQm9vbzClXZDdbsfn84Xs+CPFR+3TyeoT7H2vin3HGni/1UclybRFnXmnH2V48fiamZQQYHJOGlOneMgZkzcivjOg11P/1KOBCbZP0dF9v6kakg9tt27dyoIFC7juuus4cOAAjz/+ON///vfP+x+9qKiIoqKis9sNDQ0hq+mDi6fJhX3UPtli7UydNZGpfz1dyDAM6mtOUFlZzYGT7VR67bzcncYLR6LhyHFSvfuYEGhjYqqVS0alM27yWGLjht/nAXo99U89Gphg+5Sdnd3nvqAD3+Fw0NjYeHa7sbERh8NxzpiNGzfy4IMPAjBhwgS8Xi/t7e2kpPR9NUsZGaxWK5l52WTmZXP1Xx/z9no5VnWU/Yfq2N/Uy35fPG93psH7YN97kDHeJibHeZmUlcLES8eQlu409TmIjBRBB77H46Guro76+nocDgdlZWXcc88954xxuVzs2bOHBQsWUF1djdfrJTk5OdipZZiKio7CM3kcnsnj+MRfH2s+1ciBfUfYV9vK+14br/gyeLEmCmpOkd1zgMlRnVzqTmDypFFk5GSOiGUgkaEWdODbbDaWLFnCqlWrMAyDhQsXkpeXx7p16/B4PBQUFPD5z3+en/3sZ7z88ssA3HXXXVgslqCLl5EjLd3JFelOrvjrdm93Lwf3HWTvkZPs7TV4y3BQeioOTrXh6q3mUls7U9JjmTJpNJmjsvULQGQAgv7QNpRqa2tDdmytJw5MuPTJ7/dzrPII71XV8l5DL++RQmvUmXsJuHrbmGJrZ2pGPNOmjCEjN3PI6wuXPoUz9WhgwnoNX2Qo2Gw2xkz0MGaih09y5sPg6kPH2f3+cfY09LDDSGVTYwJsbiGr5zBTozuZlp3EtGnjSXHpNGARUODLMGW1Whk1bjSjxo3mWj74C+Aou/ZXs6vXx5t+J6/XxWKprWNs7x6mJ3iZMcbFpGkTiY7VdwEkMinwZUQ48xfAWMZMHMsNnLklZeV7V
VRUnWRnr5UXe938rtJGzPvvM9XfwGWuKGZOHUN2fm6/xxYZKRT4MiLZo+xMmjGRSTMmcgvQ2d7B7ooDvHusmXd9CWxvS4WtHWRu3MasmNPMGuNkysyJxMTGml26SMgo8CUixCclcsXVM8+eBVR7pJodew7zzikvf/Kn8/LhaKKrDjDN38CsjGhmzxjf560qRYYrBb5EpOz8XLLzc/kk0N3VxZ4d+3nnSCPv+BPY3pLKzzY1M7bnALOTfMyenItnkkenfsqwp8CXiBcbF0fBlTMouPKvZ/8cPMbbu4/ydg882+NmXYUf19vbuTy6nSvGpXPpZZN0TwAZlhT4In/DarUyanw+o8bn82nO3B9g+/b9/OVEF6X+dF6piibh/T3MtjQxNz+FGbMvHZbX/pHIpMAXuYBUl4Oij8+lCOg+3cW72/ey7Wgz2wNONh2PI/bIAWYZDSy+JIPJ0zy6B4CENQW+yADFJsQxd/4s5nLmAnB7duylrKqBvxgpbD0URUxlFQXGKa4cncysy6foLmASdhT4IhchKjqKy+ZM57I58H99Po5VVvPHHTW8ZaSwtTaR2OcPMDvQwLyxacy4fIpu/CJhQYEvEiS73c7lVxcwdlI+X/T52PvuPt7c30CZP403j8WTeOg95libWDDJzeTLJmGz2cwuWSKUAl9kENntdqbNnsq02fDFnl52lu9hy8Fm/hxwUbrfjmv3O8yL72DhzLGMGp9vdrkSYRT4IiESHRPN7KtmMvsq6Drdydtv7WbT8dOs97r53dvdjH3zTRa5YF7hFF3gTYaEAl9kCMQlxDO/6Armc+ZmL2++9R5v9Fj47/Z0/ufVGmb5d7B4bCqz5k7HHqX/LSU09MoSGWJp6U6uv34e1wOH3z/EG+/WsMmfwl+OJ5B26F0WxrVTdPl4csbkmV2qjDAKfBETfXCFz8/1enln207+dLD9zJJP2WkufWMT14yKY+5Vl+mSzjIoFPgiYSAqOoo58wqYMw8aT9Szcete/tQbyw9OppD0mz0sjGnmmismkDtW7/rl4inwRcKMMzODz3wqgxv9fnaXv8dr7zfzsi+T3791mmmb3uATnmRmF2qtXz66QXnFVFRUsHbtWgzDYPHixRQXF583pqysjOeeew6LxcLo0aNZtmzZYEwtMmLZbDZmzJnGjDnQdLKBP/15D6/3JlBSnYDz6R1ck9zJNfOnkepymF2qDBNBB75hGKxZs4aHH34Yp9PJAw88QEFBAbm5//9OQnV1daxfv56VK1eSmJhIa2trsNOKRBSH28VnP7WAT3l9bH+rgperTvO/3Vk892oNVwXe5ZOzRjPu0nFmlylhLujAr6qqIjMzE7f7zM0iCgsLKS8vPyfwN2zYwDXXXENiYiIAKSkpwU4rEpHsUfaza/3Hqo7y8l9q2OR38UaFj8l/2cT1niRmXzkdu13LPXK+oF8VTU1NOJ3Os9tOp5PKyspzxtTW1gKwYsUKDMPgM5/5DDNmzDjvWKWlpZSWlgJQUlKCy+UKtrw+2e32kB5/pFCfBsaMPrlcLmbOmcU9TS2sf+lNXjgZR0l1Apm/2s6ncqwUXzeP+KTEIa3pQvRaGphQ9mlI3gYYhkFdXR2PPvooTU1NPProozz22GMkJCScM66oqIiioqKz2w0NDSGryeVyhfT4I4X6NDBm9+naT8zlGq+PbX/ewYuHe3niVAa/fLKMT8S1cu2i6WGxzm92j4aLYPuUnZ3d576gA9/hcNDY2Hh2u7GxEYfDcd6Y8ePHY7fbycjIICsri7q6OsaN05qjyGCxR9m5auHlXLUQ9r67lxd2tvCsN5v1r9SwyPYu/zRvEpl5fYeBjHxB36TT4/FQV1dHfX09Pp+PsrIyCgoKzhlz+eWX89577wHQ1tZGXV3d2TV/ERl8ky+bzENfWMRProhnnqWeUsPNlzc3s/qpDRyrPGJ2eWKSoN/h22w2lixZwqpVqzAMg4ULF5KXl8e6devweDwUFBQwffp0du7cy
fLly7Fardx+++0kJSUNRv0icgF540axdNwobq45yYub9/C6kc6mt7uZs3UDn70in7GTPGaXKEPIEggEAmYX0ZcPPuwNBa0nDoz6NDDDpU+tDc38YeO7vNzloNMey+XeWj47e9SQnNI5XHpktlCu4Qe9pCMiw0eKK43bb1rEkzeM5eaYOt6zpPFvFT5Wrd3IoX0HzS5PQkyBLxKBklKTueXTC3myeBy3xNSxx+pg+Q4vJf+zgWNVR80uT0JEgS8SwRJTkrj50wt58vqxfCaqlgqLk2XbTvOjpzZw4njollTFHAp8ESEpLZnbb1rEz64dzSdtdbxpyeDuTU08+b8baGloMrs8GSQKfBE5K8WVxh23LOa/FmWwkBO8amTypZeP8ezzb9B9usvs8iRICnwROU96jpt//VwRP5qTyNRAE890Z/HlZ/dQ+tpb+P1+s8uTi6TAF5E+jRo3moe+sIjvTDJwBrp4vDGNr/2yjN3b95hdmlwEBb6I9OvSmZP593++kuXuVtosMTy83853/0cf7A43uoaqiAyIzWZjQdEVzOnq4sWXt/G8P50dbzRyQ+z7fPraucQmxJldovRD7/BF5COJjYvjs59eyBOL3VwRqOc5bzZ3P7uLtza/g2EYZpcnF6DAF5GLkp7j5mv/vJhVE/0kBLyUVCew8pebqD1SbXZp0gcFvogEZcqsS/n+7VfwL4n17LU5WbalmWeff4Penl6zS5O/o8AXkaBFRUdRfMM8nviHLAqMep7pzuKrz5Sz9929Zpcmf0OBLyKDxpWVwde/sJiHRnXSZbHzwF4r//XMBk63dZhdmqDAF5EQuPzqmfzkpqlcZ6nhj4Es7nl+L5v/tNXssiKeAl9EQiIuIZ7/c+tiSqZYiQt4eXBvgB89tYH2ljazS4tYCnwRCamJMybyg1tnc3NcPZusmSx74X12lFWYXVZEUuCLSMhFx0az9M4b+fdpUcQFfHzzcCz/9cwGuk53ml1aRFHgi8iQmTBtAj+4dSY3WM+s7S9ft5P9u/abXVbEGJTAr6ioYNmyZSxdupT169f3OW7btm3cdNNNHDyoW6mJRKqY2FiW3LKYb13ix2uxcf8uH+ue34TP5zO7tBEv6MA3DIM1a9bw4IMPsnr1arZu3Up19fnftOvq6uLVV19l/PjxwU4pIiPAtNlT+eE/TeRK/wn+tzuTR5/6M6dqTppd1ogWdOBXVVWRmZmJ2+3GbrdTWFhIeXn5eePWrVvHDTfcQFRUVLBTisgIkZSazFc/t5Clrmaq7Gl8pbSGt9/cYXZZI1bQV8tsamrC6XSe3XY6nVRWVp4z5tChQzQ0NDBz5kx+//vf93ms0tJSSktLASgpKcHlcgVbXp/sdntIjz9SqE8Doz7170I9uvm265iz/xArfr+LVcdcfOq5LSxd8kmiYqKHuErzhfK1FPLLIxuGwVNPPcVdd93V79iioiKKiorObjc0NISsLpfLFdLjjxTq08CoT/3rr0eJzmRKbingF89v5fneHPb86A/ce81E0nPcQ1il+YJ9LWVnZ/e5L+glHYfDQWNj49ntxsZGHA7H2e3u7m6OHz/ON7/5Te6++24qKyv53ve+pw9uReQ8MbGxfPm2xXwtq42j9lS++qdqdr69y+yyRoygA9/j8VBXV0d9fT0+n4+ysjIKCgrO7o+Pj2fNmjU88cQTPPHEE4wfP5777rsPj8cT7NQiMkJdvehyHitMI9no4RsH7Pz2d5t1rf1BEHTg22w2lixZwqpVq1i+fDlz584lLy+PdevWsX379sGoUUQiUN64UfzHTdOZ66vjV11uHnvqDX1RK0iWQCAQMLuIvtTWhu5+mVpzHRj1aWDUp/5dbI8Mw+B369/k6c50RvU28eA/jCUzr+916uEurNfwRURCyWq18ukb57NibC+nbIncu6FO19m/SAp8ERkWZhXO4D+ucpBg9LJij8GGP75ldknDjgJfRIaNXM8ovvdPk5nkPcWPG9J45tk39GHuR
6DAF5FhJTkthUdvv4pFRjXPerP40dO6f+5AKfBFZNiJio5i6W2LuDWmjk22HL71zFbdRnEAFPgiMixZrVY+++mFLEtvYW9UBg/99l2aTupMqQtR4IvIsLboY3N4yOOl1p7C/a9UUXvk/Kv1yhkKfBEZ9mYVzmDljBhOW6N5YNMJjuw/bHZJYUmBLyIjwiXTLuE7cx1YAwEe2tbC/p3vm11S2FHgi8iIMXpCPt9dlE2C0cOjFd28t0Nf0PpbCnwRGVEyR+fwnY+PJc3fxbf2eNlVvtvsksKGAl9ERhxXVgarrh1Puq+DlfsC7CrfY3ZJYUGBLyIjksPtYuV1E3H72vn2PoPd2xX6CnwRGbHS0p2svO4S0n0dfHuvP+IvuqbAF5ERLS3dxcprJ5Dm62Tlrl6q3qsyuyTTKPBFZMRzuF1862NjiDd6+GZ5K8eqjppdkikU+CISETJyM1m5MAcrAb7x5knqa06YXdKQU+CLSMTIzs/lG3PS6LJG840/HqK1qcXskoaUAl9EIsqYiR4emmKnPiqJb6/fRffpLrNLGjIKfBGJOFNmTeHf8nqojHbxg+fK8Pl8Zpc0JOyDcZCKigrWrl2LYRgsXryY4uLic/a/9NJLbNiwAZvNRnJyMl/+8pdJT08fjKlFRC7K3AUFLPn9Fta057D22c188dbFZpcUckG/wzcMgzVr1vDggw+yevVqtm7dSnX1uZcnzc/Pp6SkhMcee4w5c+bw9NNPBzutiEjQrr9+HtdZangpkMMrL71pdjkhF3TgV1VVkZmZidvtxm63U1hYSHl5+TljpkyZQkxMDADjx4+nqakp2GlFRAbFF26az6zeWn7e4uDdbTvNLiekgl7SaWpqwul0nt12Op1UVlb2OX7jxo3MmDHjQ/eVlpZSWloKQElJCS6XK9jy+mS320N6/JFCfRoY9al/4dyj79z5ce588g3+Y388T45vI/+SsabVEso+Dcoa/kBt2bKFQ4cO8Y1vfOND9xcVFVFUVHR2u6EhdLcrc7lcIT3+SKE+DYz61L9w79GDRWP42oYT3P/ibr73GSvxSYmm1BFsn7Kzs/vcF/SSjsPhoLGx8ex2Y2MjDofjvHG7du3ihRde4L777iMqKirYaUVEBlVmXjb3TrZTE53K6t++jWEYZpc06IIOfI/HQ11dHfX19fh8PsrKyigoKDhnzOHDh/n5z3/OfffdR0pKSrBTioiExPTLp/GFpEbejs7mhRe3mF3OoAt6Scdms7FkyRJWrVqFYRgsXLiQvLw81q1bh8fjoaCggKeffpru7m5+8IMfAGf+ZPn6178edPEiIoPtuuuuYv9Tb/D06SzGl+9h2uwpZpc0aCyBQCBgdhF9qa2tDdmxw309MVyoTwOjPvVvOPWos72Drz23i9PWaFb/Yz4O99B92BzWa/giIiNNfFIi912ZRac1hh++shu/3292SYNCgS8i8iHyLxnDHa5WdkZn8bv1I+NLWQp8EZE+fOzjhVzpreHXnelU7j5gdjlBU+CLiPTBarXy5Rtmk+Y7zQ/KG+k63Wl2SUFR4IuIXEBSWjLLpsRTF53C2hfKzC4nKAp8EZF+TJs9hettdfzRkjusr7ejwBcRGYBbiwvJ6Wni8X09dLS2m13ORVHgi4gMQGxcHMtmOWiOSuSXf3jb7HIuigJfRGSALpk+ketsdbxuyWH39j1ml/ORKfBFRD6CW68vxN3bwhO72unp7ja7nI9EgS8i8hHEJsRx16R46mLSeP4Pb5ldzkeiwBcR+YhmzJnG1b4anu9Op+bwcbPLGTAFvojIRfiXa6YRbfh48o2qYXPtfAW+iMhFcGamc3NqGxUxWZT/+V2zyxkQBb6IyEX6xD/OJbeniV9Ueent7jW7nH4p8EVELlJUdBR3TErgREwqf3g1/C+7oMAXEQnCzLnTmdlby2/bUmltajG7nAtS4IuIBOkLV46l2xbNutfeMbuUC1Lgi4gEafSEfBYHavmj382JozVml9Ono
G9iDlBRUcHatWsxDIPFixdTXFx8zn6v18tPfvITDh06RFJSEl/5ylfIyMgYjKlFRMLCzUXT2FR6kt9sfp+vfD7H7HI+VNDv8A3DYM2aNTz44IOsXr2arVu3Ul1dfc6YjRs3kpCQwOOPP861117LM888E+y0IiJhxZWVwSfs9Wy2ZnGs8ojZ5XyooAO/qqqKzMxM3G43drudwsJCysvLzxmzfft2FixYAMCcOXPYs2cPgUAg2KlFRMLKp/5hJtGGl3VlB80u5UMFvaTT1NSE0+k8u+10OqmsrOxzjM1mIz4+nvb2dpKTk88ZV1paSmlpKQAlJSW4XK5gy+uT3W4P6fFHCvVpYNSn/kVCj1wuF9dv3s1zXVl88WQT4y6d8JGPEco+Dcoa/mApKiqiqKjo7HZDQ0PI5nK5XCE9/kihPg2M+tS/SOnRJ+dP5fevHOO/X3uXr7kdH/nng+1TdnZ2n/uCXtJxOBw0Njae3W5sbMThcPQ5xu/309nZSVJSUrBTi4iEnRRXGh+PbmCrLYvaI9X9/8AQCjrwPR4PdXV11NfX4/P5KCsro6Cg4Jwxs2bNYtOmTQBs27aNSy+9FIvFEuzUIiJh6foFU7EGDF74836zSzlH0IFvs9lYsmQJq1atYvny5cydO5e8vDzWrVvH9u3bAVi0aBEdHR0sXbqUl156idtuuy3owkVEwpUzM4OFnOCNQAbNp8JnGcsSCOPTZWpra0N27EhZTwyW+jQw6lP/Iq1H1QePcfe2Tj4bXcetn1k44J8L6zV8ERE5X65nFAW9tbx6OjlsboWowBcRCZHrL3XRFpXA5k07zC4FUOCLiITM1IIp5PU08mqtPyzuiqXAFxEJEavVyicyDA7FpHNg1wGzy1Hgi4iE0oIFM4nzdfPaLvOvoqnAFxEJofjEBK62NbKVdNpb2kytRYEvIhJiH5ueR68tmi1/3mVqHQp8EZEQ81w6jvyeBjbUm/u1JwW+iEiIWa1WFjkNDsakm3qtfAW+iMgQmFc4GWvAz8Z3DplWgwJfRGQIpKW7uMx7kje7EvD7/abUoMAXERkiV+XG0xCdYto5+Qp8EZEhcsUVU4j2e9myN3QXhrwQBb6IyBBJSE7kMqOet7zJpizrKPBFRIbQ3JwEmqOSOLC7sv/Bg0yBLyIyhApmT8Zu+Hhr39BfakGBLyIyhJJSk5niO8XbXfFDPrcCX0RkiM1Ot1MXk0b1oeNDOq8CX0RkiBVM9wCwfffhIZ3XHswPd3R0sHr1ak6dOkV6ejrLly8nMTHxnDFHjhzh5z//OV1dXVitVm688UYKCwuDKlpEZDjLzMsmr+cw7/T4KR7CeYMK/PXr1zN16lSKi4tZv34969ev5/bbbz9nTHR0NP/6r/9KVlYWTU1N3H///UyfPp2EhISgChcRGc5mxHbzmj+D7q4uYuPihmTOoJZ0ysvLmT9/PgDz58+nvLz8vDHZ2dlkZWUB4HA4SElJoa3N3GtCi4iY7bJ8J15rFHsrhu5bt0G9w29tbSUtLQ2A1NRUWltbLzi+qqoKn8+H2+3+0P2lpaWUlpYCUFJSgsvlCqa8C7Lb7SE9/kihPg2M+tQ/9ehc8xbNxf7zcvZUt/Kxv+lLKPvUb+CvXLmSlpaW8x6/+eabz9m2WCxYLJY+j9Pc3Mzjjz/O3XffjdX64X9YFBUVUVRUdHa7oaGhv/IumsvlCunxRwr1aWDUp/6pR+eb4G1gp9d6Tl+C7VN2dnaf+/oN/BUrVvS5LyUlhebmZtLS0mhubiY5OflDx3V2dlJSUsItt9zChAkTBlCyiMjINyUxwG97XZxu6yAhObH/HwhSUGv4BQUFbN68GYDNmzcze/bs88b4fD4ee+wx5s2bx5w5c4KZTkRkRJmS78SwWNm3u2pI5gsq8IuLi9m1axf33HMPu3fvpri4GICDBw/y05/+FICysjL27dvHpk2buPfee7n33ns5cuRIsHWLiAx7Ey4dhzXgZ19105DMF9SHtklJS
TzyyCPnPe7xePB4znyxYN68ecybNy+YaURERqS4hHjG9Dbxfu/QfAdW37QVETHRxNheKu0OfF5fyOdS4IuImGh8egI9tmiOHzwW8rkU+CIiJhrvyQGg8siJkM+lwBcRMVHW6Gzifd1UNXSGfK6gPrQVEZHg2Gw2xvpbOOyPDvlceocvImKy/Fg/R+2p+Hyh/eBWgS8iYrIxjnh6bNGcOFob0nkU+CIiJhudd+ZiaceOnwzpPAp8ERGT5eafOVPneENHSOfRh7YiIiaLS4gno7eV4z4jpPPoHb6ISBjI5TTVIT5TR4EvIhIGsmIC1NmTMYzQvctX4IuIhIHspGi6bTGcqq0P2RwKfBGRMJDtTALg+NGakM2hwBcRCQMZbgcAdfXNIZtDgS8iEgZc2ekA1DWH7tRMBb6ISBiIjYsj1dtBXUfoLq+gwBcRCRNOo5PG3tAdX4EvIhImnFYfDYGokB0/qG/adnR0sHr1ak6dOkV6ejrLly8nMTHxQ8d2dnby1a9+ldmzZ3PHHXcEM62IyIjkiAqw1xcXsuMH9Q5//fr1TJ06lR//+MdMnTqV9evX9zl23bp1TJo0KZjpRERGtLQYKx32eHp7QrOuE1Tgl5eXM3/+fADmz59PeXn5h447dOgQra2tTJ8+PZjpRERGtJS4M8s57c0tITl+UEs6ra2tpKWlAZCamkpra+t5YwzD4KmnnmLp0qXs3r37gscrLS2ltLQUgJKSElwuVzDlXZDdbg/p8UcK9Wlg1Kf+qUf9y0pPgxbw+wIh6VW/gb9y5UpaWlrOe/zmm28+Z9tisWCxWM4b9/rrr3PZZZfhdDr7LaaoqIiioqKz2w0NDf3+zMVyuVwhPf5IoT4NjPrUP/Wof1H2MxlaU32CjFz3RR0jOzu7z339Bv6KFSv63JeSkkJzczNpaWk0NzeTnJx83pgDBw6wb98+Xn/9dbq7u/H5fMTGxnLbbbcNsHwRkciQmBgPdNPR2ROS4we1pFNQUMDmzZspLi5m8+bNzJ49+7wx99xzz9l/b9q0iYMHDyrsRUQ+RGLyXwO/Kww/tC0uLmbXrl3cc8897N69m+LiYgAOHjzIT3/608GoT0QkYiSknLmAWkdPaL5tG9Q7/KSkJB555JHzHvd4PHg8nvMeX7BgAQsWLAhmShGRESs6JgZrwE+XNzTXxNctDkVEwoTVaiXW76XbGgjN8UNyVBERuShxRi/dfgW+iMiIF4OfHuP8U9wHgwJfRCSMRGHQq8AXERn5ojHwhujYCnwRkTASRQBvIDTRrMAXEQkjVgKE5qRMBb6ISFixWsBAa/giIiOejUDIAl9fvBIRCSPTHFF0dIXhxdNERGRw3bnkhpBdRlpLOiIiEUKBLyISIRT4IiIRQoEvIhIhFPgiIhFCgS8iEiEU+CIiEUKBLyISISyBQCA0t1YREZGwErHv8O+//36zSxgW1KeBUZ/6px4NTCj7FLGBLyISaRT4IiIRImIDv6ioyOwShgX1aWDUp/6pRwMTyj7pQ1sRkQgRse/wRUQijQJfRCRCRPQNUH7zm9+wfft2LBYLKSkp3HXXXTgcDrPLCju/+tWveOedd7Db7bjdbu666y4SEhLMLiusvPXWWzz33HPU1NTwne98B4/HY3ZJYaWiooK1a9diGAaLFy+muLjY7JLCzn/+53+yY8cOUlJS+P73vx+SOSJ6Db+zs5P4+HgAXnnlFaqrq7nzzjtNrir87Ny5kylTpmCz2Xj66acBuP32202uKrxUV1djtVp58skn+dznPqfA/xuGYbBs2TIefvhhnE4nDzzwAMuWLSM3N9fs0sLK3r17iY2N5YknnghZ4Ef0ks4HYQ/Q09ODxRKaGwcPd9OnT8dmswEwYcIEmpqaTK4o/OTm5pKdnW12GWGpqqqKzMxM3G43drudwsJCysvLzS4r7EyePJnExMSQzhHRSzoAv/71r9myZQvx8fE8+uijZpcT9jZu3EhhYaHZZcgw0tTUhNPpPLvtd
DqprKw0saLINeIDf+XKlbS0tJz3+M0338zs2bO55ZZbuOWWW3jhhRd47bXXuOmmm4a+yDDQX58Afve732Gz2bj66quHuLrwMJAeiYSzER/4K1asGNC4q6++mu9+97sRG/j99WnTpk288847PPLIIxG79DXQ15Kcy+Fw0NjYeHa7sbFRJ0eYJKLX8Ovq6s7+u7y8XGuwfaioqODFF1/k61//OjExMWaXI8OMx+Ohrq6O+vp6fD4fZWVlFBQUmF1WRIros3Qee+wx6urqsFgsuFwu7rzzTr3z+BBLly7F5/Od/UBp/PjxOpvp77z99tv84he/oK2tjYSEBPLz83nooYfMLits7Nixg1/+8pcYhsHChQu58cYbzS4p7Pzwhz9k7969tLe3k5KSwk033cSiRYsGdY6IDnwRkUgS0Us6IiKRRIEvIhIhFPgiIhFCgS8iEiEU+CIiEUKBLyISIRT4IiIR4v8BQp0hU2W2VREAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(mean[:,0], mean[:,1])\n", + "plt.errorbar(mean[:,0], mean[:,1], var[:,1], var[:,0])\n", + "#plt.plot(points[:,0], points[:,1])" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "aecc8906-69fd-41ac-a2dc-05c0dd8d26a6", + "metadata": {}, + "outputs": [], + "source": [ + "from pyepal.pal.pal_gpflowgpr import PALGPflowGPR\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "59a73e33-d833-4931-9af4-d13c63d7028e", + "metadata": {}, + "outputs": [], + "source": [ + "def build_model(x, y):\n", + " k = gpflow.kernels.RationalQuadratic()\n", + " m = gpflow.models.GPR(data=(x, y), kernel=k, mean_function=None)\n", + " return m" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0fcc4f84-cab3-4e6c-a9e1-fccc0117ff7f", + "metadata": {}, + "outputs": [], + "source": [ + "def binh_korn_points():\n", + " \"\"\"Create a dataset based on the Binh-Korn test function\"\"\"\n", + " x = np.linspace(0, 5, 100) # pylint:disable=invalid-name\n", + " y = np.linspace(0, 3, 100) # pylint:disable=invalid-name\n", + " array = np.array([binh_korn(xi, yi) for xi, yi in zip(x, y)])\n", + " return np.hstack([x.reshape(-1, 1), y.reshape(-1, 1)]), array" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "87550ee9-16f8-447c-8ef8-b9cd3e5c1e17", + "metadata": {}, + "outputs": [], + "source": [ + "X_binh_korn, y_binh_korn = binh_korn_points()\n", + "\n", + "X_binh_korn = (X_binh_korn - X_binh_korn.mean()) / X_binh_korn.std()\n", + "y_binh_korn = (y_binh_korn - y_binh_korn.mean()) / y_binh_korn.std() + 0.01 * np.random.rand()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9add20d3-2807-4276-a841-f9a9f65f6b33", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/kevinmaikjablonka/Dropbox 
(LSMO)/Documents/open_source/PythonPAL/pyepal/pal/validate_inputs.py:150: UserWarning: Only one epsilon value provided,\n", + "will automatically expand to use the same value in every dimension\n", + " UserWarning,\n", + "/Users/kevinmaikjablonka/Dropbox (LSMO)/Documents/open_source/PythonPAL/pyepal/pal/validate_inputs.py:178: UserWarning: No goals provided, will assume that every dimension should be maximized\n", + " UserWarning,\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "training 0\n" + ] + } + ], + "source": [ + "sample_idx = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 50, 60, 70])\n", + "model_0 = build_model(X_binh_korn[sample_idx], y_binh_korn[sample_idx])\n", + "model_1 = build_model(X_binh_korn[sample_idx], y_binh_korn[sample_idx])\n", + "\n", + "palinstance = PALGPflowGPR(\n", + " X_binh_korn,\n", + " [model_0, model_1],\n", + " 2,\n", + " beta_scale=1,\n", + " epsilon=0.01,\n", + " delta=0.01,\n", + " opt_kwargs={\"maxiter\": 50}\n", + ")\n", + "palinstance.cross_val_points = 0\n", + "palinstance.update_train_set(sample_idx, y_binh_korn[sample_idx])\n", + "idx = palinstance.run_one_step()\n", + "assert idx[0] not in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 50, 60, 70]\n", + "assert palinstance.number_sampled_points > 0\n", + "assert sum(palinstance.discarded) == 0\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "072ea41d-014b-4b85-887e-e1be583c6406", + "metadata": {}, + "outputs": [], + "source": [ + "model_0" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "id": "645681f1-22c9-439e-ac05-6a46ab581622", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 99, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "palinstance.opt" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "id": "7dde3e55-2760-437a-be9a-627ba8de3661", + "metadata": {}, + "outputs": [], + "source": [ + "from functools import 
partial" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "id": "c4bbb186-07e5-442d-8471-d50fd3b30b39", + "metadata": {}, + "outputs": [], + "source": [ + "def _train_model_picklable(i, models, opt, opt_kwargs):\n", + " model = models[i]\n", + " _ = opt.minimize(model.training_loss, model.trainable_variables, options=opt_kwargs)\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "id": "cb789d19-4368-4527-8b11-a6225cdecbe7", + "metadata": {}, + "outputs": [], + "source": [ + "train_model_pickleable_partial = partial(\n", + " _train_model_picklable,\n", + " models=palinstance.models,\n", + " opt=palinstance.opt,\n", + " opt_kwargs=palinstance.opt_kwargs,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 109, + "id": "db26da04-235c-46d5-82ca-94a8697e759f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<gpflow.models.gpr.GPR object at 0x7febef6d2a10>\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
name class transform prior trainable shape dtype value
GPR.kernel.variance ParameterSoftplus True () float64 16.7525
GPR.kernel.lengthscalesParameterSoftplus True () float64 9.11525
GPR.kernel.alpha ParameterSoftplus True () float64395.715
GPR.likelihood.varianceParameterSoftplus + Shift True () float64 1.00001e-06
" + ], + "text/plain": [ + "\n", + "╒═════════════════════════╤═══════════╤══════════════════╤═════════╤═════════════╤═════════╤═════════╤═══════════════╕\n", + "│ name │ class │ transform │ prior │ trainable │ shape │ dtype │ value │\n", + "╞═════════════════════════╪═══════════╪══════════════════╪═════════╪═════════════╪═════════╪═════════╪═══════════════╡\n", + "│ GPR.kernel.variance │ Parameter │ Softplus │ │ True │ () │ float64 │ 16.7525 │\n", + "├─────────────────────────┼───────────┼──────────────────┼─────────┼─────────────┼─────────┼─────────┼───────────────┤\n", + "│ GPR.kernel.lengthscales │ Parameter │ Softplus │ │ True │ () │ float64 │ 9.11525 │\n", + "├─────────────────────────┼───────────┼──────────────────┼─────────┼─────────────┼─────────┼─────────┼───────────────┤\n", + "│ GPR.kernel.alpha │ Parameter │ Softplus │ │ True │ () │ float64 │ 395.715 │\n", + "├─────────────────────────┼───────────┼──────────────────┼─────────┼─────────────┼─────────┼─────────┼───────────────┤\n", + "│ GPR.likelihood.variance │ Parameter │ Softplus + Shift │ │ True │ () │ float64 │ 1.00001e-06 │\n", + "╘═════════════════════════╧═══════════╧══════════════════╧═════════╧═════════════╧═════════╧═════════╧═══════════════╛" + ] + }, + "execution_count": 109, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_model_pickleable_partial(1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8397db0b-334f-479c-afe3-bf9c06326240", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/api.rst 
b/docs/api.rst index 9feb355..336d3ff 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -59,6 +59,14 @@ For quantile regression with LightGBM :special-members: +For GPR with GPFlow +....................................... + +.. automodule:: pyepal.pal.pal_gpflowgpr + :members: + :show-inheritance: + :special-members: + Schedules for hyperparameter optimization ........................................... diff --git a/pyepal/__init__.py b/pyepal/__init__.py index d3e738e..98eeac7 100644 --- a/pyepal/__init__.py +++ b/pyepal/__init__.py @@ -21,6 +21,7 @@ from .pal.pal_coregionalized import PALCoregionalized from .pal.pal_finite_ensemble import PALJaxEnsemble from .pal.pal_gbdt import PALGBDT +from .pal.pal_gpflowgpr import PALGPflowGPR from .pal.pal_gpy import PALGPy from .pal.pal_neural_tangent import PALNT from .pal.pal_sklearn import PALSklearn @@ -39,6 +40,7 @@ "PALCoregionalized", "PALGBDT", "PALGPy", + "PALGPflowGPR", "PALSklearn", "PALJaxEnsemble", "PALNT", diff --git a/pyepal/pal/pal_gpflowgpr.py b/pyepal/pal/pal_gpflowgpr.py new file mode 100644 index 0000000..adeeeb5 --- /dev/null +++ b/pyepal/pal/pal_gpflowgpr.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 PyePAL authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""PAL using GPflow GPR models""" +import concurrent.futures +from functools import partial + +import numpy as np + +from .pal_base import PALBase +from .schedules import linear +from .validate_inputs import validate_njobs, validate_number_models + +__all__ = ["PALGPflowGPR"] + + +def _train_model_picklable(i, models, opt, opt_kwargs): + print(f"training {i}") + model = models[i] + _ = opt.minimize(model.training_loss, model.trainable_variables, options=opt_kwargs) + return model + + +class PALGPflowGPR(PALBase): + """PAL class for a list of GPFlow GPR models, with one model per objective. + Please consider that there are specific multioutput models + (https://gpflow.readthedocs.io/en/master/notebooks/advanced/multioutput.html) + for which the train and prediction function would need to be adjusted. + You might also consider using streaming GPRs + (https://github.com/thangbui/streaming_sparse_gp). + In future releases we might support this case automatically + (i.e., handle the case in which only one model is provided). + """ + + def __init__(self, *args, **kwargs): + """Construct the PALGPflowGPR instance + + Args: + X_design (np.array): Design space (feature matrix) + models (list): Machine learning models + ndim (int): Number of objectives + epsilon (Union[list, float], optional): Epsilon hyperparameter. + Defaults to 0.01. + delta (float, optional): Delta hyperparameter. Defaults to 0.05. + beta_scale (float, optional): Scaling parameter for beta. + If not equal to 1, the theoretical guarantees do not necessarily hold. + Also note that the parametrization depends on the kernel type. + Defaults to 1/9. + goals (List[str], optional): If a list, provide "min" for every objective + that shall be minimized and "max" for every objective + that shall be maximized. Defaults to None, which means + that the code maximizes all objectives.
+ coef_var_threshold (float, optional): Use only points with + a coefficient of variation below this threshold + in the classification step. Defaults to 3. + opt (function, optional): Optimizer function for the GPR parameters. + If None (default), then we will use ` gpflow.optimizers.Scipy()` + opt_kwargs (dict, optional): Keyword arguments passed to the optimizer. + If None, PyePAL will pass `{"maxiter": 100}` + n_jobs (int): Number of parallel threads that are used to fit + the GPR models. Defaults to 1. + """ + import gpflow # pylint:disable=import-outside-toplevel + + self.n_jobs = validate_njobs(kwargs.pop("n_jobs", 1)) + self.opt = kwargs.pop("opt", gpflow.optimizers.Scipy()) + self.opt_kwargs = kwargs.pop("opt_kwargs", {"maxiter": 100}) + super().__init__(*args, **kwargs) + + validate_number_models(self.models, self.ndim) + # validate_gpy_model(self.models) + + def _set_data(self): + from gpflow.models.util import ( # pylint:disable=import-outside-toplevel + data_input_to_tensor, + ) + + for i, model in enumerate(self.models): + model.data = data_input_to_tensor( + ( + self.design_space[self.sampled[:, i]], + self.y[self.sampled[:, i], i].reshape(-1, 1), + ) + ) + + def _train(self): + models = [] + train_model_pickleable_partial = partial( + _train_model_picklable, + models=self.models, + opt=self.opt, + opt_kwargs=self.opt_kwargs, + ) + with concurrent.futures.ThreadPoolExecutor( + max_workers=self.n_jobs, + ) as executor: + for model in executor.map(train_model_pickleable_partial, range(self.ndim)): + models.append(model) + self.models = models + print("training done") + + def _predict(self): + means, stds = [], [] + for model in self.models: + mean, std = model.predict_f(self.design_space) + mean = mean.numpy() + std = std.numpy() + means.append(mean.reshape(-1, 1)) + stds.append(np.sqrt(std.reshape(-1, 1))) + + self.means = np.hstack(means) + self.std = np.hstack(stds) + + def _set_hyperparameters(self): + pass + + def 
_should_optimize_hyperparameters(self) -> bool: + return linear(self.iteration, 10) diff --git a/setup.py b/setup.py index 275936e..d1e05f4 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ ] gbdt_requirements = ["lightgbm==3.*"] neural_tangents_requirements = ["neural_tangents==0.*", "jaxlib==0.*"] +gpflow_requirements = ["gpflow"] setup( name="pyepal", version=versioneer.get_version(), @@ -62,7 +63,11 @@ "GPy": gpy_requirements, "GBDT": gbdt_requirements, "neural_tangents": neural_tangents_requirements, - "all": neural_tangents_requirements + gbdt_requirements + gpy_requirements, + "all": neural_tangents_requirements + + gbdt_requirements + + gpy_requirements + + gpflow_requirements, + "gpflow": gpflow_requirements, }, author="PyePAL authors", author_email="kevin.jablonka@epfl.ch, brian.yoo@basf.com", diff --git a/tests/test_pal_gpflowgpr.py b/tests/test_pal_gpflowgpr.py new file mode 100644 index 0000000..6f7d0b5 --- /dev/null +++ b/tests/test_pal_gpflowgpr.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 PyePAL authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Testing the PALGPflowGPR class""" +import numpy as np + +from pyepal.pal.pal_gpflowgpr import PALGPflowGPR + + +def test_pal_gpflow(binh_korn_points): + """Test basic functionality of the PALGPflowGPR class""" + import gpflow # pylint:disable=import-outside-toplevel + + X_binh_korn, y_binh_korn = binh_korn_points # pylint:disable=invalid-name + X_binh_korn = ( # pylint:disable=invalid-name + X_binh_korn - X_binh_korn.mean() + ) / X_binh_korn.std() # pylint:disable=invalid-name + y_binh_korn = ( + y_binh_korn - y_binh_korn.mean() + ) / y_binh_korn.std() + 0.01 * np.random.rand() + + def build_model(x, y): # pylint:disable=invalid-name + k = gpflow.kernels.RationalQuadratic() + m = gpflow.models.GPR( # pylint:disable=invalid-name + data=(x, y), kernel=k, mean_function=None + ) + return m + + sample_idx = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 50, 60, 70]) + model_0 = build_model(X_binh_korn[sample_idx], y_binh_korn[sample_idx]) + model_1 = build_model(X_binh_korn[sample_idx], y_binh_korn[sample_idx]) + + palinstance = PALGPflowGPR( + X_binh_korn, + [model_0, model_1], + 2, + beta_scale=1, + epsilon=0.01, + delta=0.01, + opt_kwargs={"maxiter": 50}, + ) + palinstance.cross_val_points = 0 + palinstance.update_train_set(sample_idx, y_binh_korn[sample_idx]) + idx = palinstance.run_one_step() + assert idx[0] not in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 50, 60, 70] + assert palinstance.number_sampled_points > 0 + assert sum(palinstance.discarded) == 0