EigenPro/EigenPro-tensorflow


EigenPro

Intro

EigenPro is a preconditioned (stochastic) gradient descent iteration proposed in the paper:

Ma, Siyuan, and Mikhail Belkin. Diving into the shallows:
a computational perspective on large-scale shallow learning.
In NIPS, 2017.

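It accelerates the convergence of SGD iteration when minimizing linear and kernel least squares, defined as

$$\arg\min_{\boldsymbol{\alpha} \in \mathcal{H}} \frac{1}{n} \sum_{i=1}^{n} \left( \langle \boldsymbol{\alpha}, \boldsymbol{x}_i \rangle_{\mathcal{H}} - y_i \right)^2$$

where $\{(\boldsymbol{x}_i, y_i)\}_{i=1}^{n}$ is the labeled training data. Let $X = (\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n)^T$ and $\boldsymbol{y} = (y_1, \ldots, y_n)^T$.

Consider the linear setting where $\mathcal{H}$ is a vector space and $\langle \boldsymbol{\alpha}, \boldsymbol{x}_i \rangle_{\mathcal{H}} = \boldsymbol{\alpha}^T \boldsymbol{x}_i$. The corresponding standard gradient descent iteration is hence

$$\boldsymbol{\alpha} \leftarrow \boldsymbol{\alpha} - \eta \left( H \boldsymbol{\alpha} - \boldsymbol{b} \right)$$

where $H = \frac{1}{n} X^T X$ is the covariance matrix and $\boldsymbol{b} = \frac{1}{n} X^T \boldsymbol{y}$ (the constant factor 2 of the gradient is absorbed into $\eta$). The step size is automatically set to $\eta = 1 / \lambda_1(H)$ to ensure fast convergence, where $\lambda_1(H)$, the top eigenvalue of the covariance, is calculated approximately. We then construct the EigenPro preconditioner $P$ from the approximate top-$k$ eigensystem $\{(\lambda_i(H), \boldsymbol{e}_i(H))\}_{i=1}^{k}$ of $H$, which can be calculated efficiently when $H$ has fast eigendecay:

$$P = I - \sum_{i=1}^{k} \left( 1 - \tau \frac{\lambda_{k+1}(H)}{\lambda_i(H)} \right) \boldsymbol{e}_i(H) \, \boldsymbol{e}_i(H)^T$$

Here we select $\tau \le 1$ to counter the negative impact of eigensystem approximation error on convergence. The EigenPro iteration then runs as follows:

$$\boldsymbol{\alpha} \leftarrow \boldsymbol{\alpha} - \left( \eta \frac{\lambda_1(H)}{\lambda_{k+1}(H)} \right) P \left( H \boldsymbol{\alpha} - \boldsymbol{b} \right)$$

With a larger ratio $\lambda_1(H) / \lambda_{k+1}(H)$, the EigenPro iteration yields higher convergence acceleration over standard (stochastic) gradient descent. This is especially critical in the kernel setting, where (widely used) smooth kernels have exponential eigendecay. Note that in this setting $\mathcal{H}$ is typically an RKHS (reproducing kernel Hilbert space) of infinite dimension. Thus it is necessary to parametrize the (approximate) solution in a subspace of finite dimension (e.g. $\mathrm{span}\{k(\cdot, \boldsymbol{x}_1), \ldots, k(\cdot, \boldsymbol{x}_n)\}$). See the paper for more details on the kernel setting and some theoretical results.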
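The effect of the preconditioner can be checked on a small synthetic problem by comparing the standard iteration α ← α − η(Hα − b), with η = 1/λ1(H), against the EigenPro iteration α ← α − (η λ1(H)/λ_{k+1}(H)) P(Hα − b). The following is a minimal NumPy sketch, not the repository's implementation: it uses the exact eigensystem of H (the repository approximates it), takes τ = 1, and the synthetic data, `k = 10`, the tolerance and the helper `iterations_to_tol` are illustrative assumptions.

```python
import numpy as np

def iterations_to_tol(H, b, step, P, alpha_star, tol=1e-4, max_iter=100_000):
    """Iterate alpha <- alpha - step * P (H alpha - b) starting from zero.

    Returns the number of iterations until ||alpha - alpha_star|| <= tol * ||alpha_star||,
    or max_iter if that never happens.
    """
    alpha = np.zeros_like(b)
    target = tol * np.linalg.norm(alpha_star)
    for t in range(1, max_iter + 1):
        alpha = alpha - step * (P @ (H @ alpha - b))
        if np.linalg.norm(alpha - alpha_star) <= target:
            return t
    return max_iter

rng = np.random.default_rng(0)
n, d, k, tau = 2000, 20, 10, 1.0
# Synthetic features whose covariance has geometric eigendecay (variances 0.64**j).
X = rng.standard_normal((n, d)) * 0.8 ** np.arange(d)
y = X @ rng.standard_normal(d)

H = X.T @ X / n                        # covariance matrix
b = X.T @ y / n
alpha_star = np.linalg.solve(H, b)     # exact minimizer, used only to measure the error

lam, E = np.linalg.eigh(H)
lam, E = lam[::-1], E[:, ::-1]         # eigenpairs in descending order
eta = 1.0 / lam[0]                     # standard step size 1 / lambda_1(H)

# P = I - sum_{i=1..k} (1 - tau * lambda_{k+1} / lambda_i) e_i e_i^T.
# With 0-based indexing, lam[k] is lambda_{k+1} and E[:, :k] holds e_1 .. e_k.
coef = 1.0 - tau * lam[k] / lam[:k]
P = np.eye(d) - (E[:, :k] * coef) @ E[:, :k].T
scale = lam[0] / lam[k]                # step-size multiplier lambda_1 / lambda_{k+1}

gd_iters = iterations_to_tol(H, b, eta, np.eye(d), alpha_star)
ep_iters = iterations_to_tol(H, b, scale * eta, P, alpha_star)
print("GD:", gd_iters, "EigenPro:", ep_iters)
```

The larger step is safe because P maps the top-k eigenvalues of H to τλ_{k+1}(H) ≤ λ_{k+1}(H) and leaves the others unchanged, so the top eigenvalue of PH is λ_{k+1}(H).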

Requirements: TensorFlow (>= 1.2.1) and Keras (== 2.0.8)

pip install tensorflow tensorflow-gpu keras==2.0.8

Follow the TensorFlow installation guide for Virtualenv setup.

Running experiments

The experiments compare Pegasos, kernel EigenPro, random Fourier features with linear SGD, and random Fourier features with EigenPro on MNIST.

python run_expr.py

Users can also pass the flag "--kernel" to choose a different kernel, such as Gaussian, Laplace, or Cauchy.

python run_expr.py --kernel=Laplace
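For reference, here is one common parametrization of the three kernels named above, written as a NumPy sketch. The bandwidth convention (`s`), the function names and the `KERNELS` lookup table are assumptions for illustration and may differ from what kernels.py and run_expr.py actually use.

```python
import numpy as np

def sq_dists(X, Z):
    """Pairwise squared Euclidean distances between rows of X (n, d) and Z (m, d)."""
    d2 = (X ** 2).sum(axis=1)[:, None] + (Z ** 2).sum(axis=1)[None, :] - 2.0 * X @ Z.T
    return np.maximum(d2, 0.0)  # clip tiny negative values caused by rounding

def gaussian(X, Z, s):
    """exp(-||x - z||^2 / (2 s^2))"""
    return np.exp(-sq_dists(X, Z) / (2.0 * s ** 2))

def laplace(X, Z, s):
    """exp(-||x - z|| / s)"""
    return np.exp(-np.sqrt(sq_dists(X, Z)) / s)

def cauchy(X, Z, s):
    """1 / (1 + ||x - z||^2 / s^2)"""
    return 1.0 / (1.0 + sq_dists(X, Z) / s ** 2)

# Hypothetical lookup table keyed by the kernel names listed above.
KERNELS = {"Gaussian": gaussian, "Laplace": laplace, "Cauchy": cauchy}
```

All three are shift-invariant and equal 1 at x = z; they differ in how fast they decay with distance, which in turn affects the eigendecay that EigenPro exploits.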

Note that random Fourier features are implemented only for the Gaussian kernel.
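Random Fourier features (Rahimi and Recht) replace a shift-invariant kernel with an explicit finite-dimensional feature map, so that linear SGD or EigenPro can run on the features directly. The sketch below is the standard construction for the Gaussian kernel k(x, z) = exp(-||x - z||^2 / (2 s^2)), assuming that bandwidth convention; `make_rff_gaussian` is a hypothetical helper, not the repository's code.

```python
import numpy as np

def make_rff_gaussian(d, num_features, s, seed=0):
    """Return a feature map phi with phi(X) @ phi(Z).T ~= exp(-||x - z||^2 / (2 s^2))."""
    rng = np.random.default_rng(seed)
    # Frequencies drawn from the Gaussian kernel's spectral density N(0, s^-2 I).
    W = rng.standard_normal((d, num_features)) / s
    # Random phases, uniform on [0, 2*pi).
    b = rng.uniform(0.0, 2.0 * np.pi, size=num_features)

    def phi(X):
        return np.sqrt(2.0 / num_features) * np.cos(X @ W + b)

    return phi

# Usage: the same phi (same W and b) must be applied to training and test data.
rng = np.random.default_rng(1)
X = rng.standard_normal((50, 5))
phi = make_rff_gaussian(d=5, num_features=20000, s=2.0)
approx = phi(X) @ phi(X).T   # Monte Carlo estimate of the Gaussian kernel matrix
```

The approximation error of each kernel entry shrinks roughly like 1 / sqrt(num_features), which is why the experiments use as many random features as training points.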

An example of building and training a kernel model

First, let's import the related Keras and kernel components,

from keras.layers import Dense, Input
from keras.models import Model
from keras import backend as K

from layers import KernelEmbedding
from optimizers import PSGD

Please read this short Keras tutorial to get familiar with its components. Then we can create the input layer,

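import mnist
import utils
(x_train, y_train), (x_test, y_test) = mnist.load()
n, D = x_train.shape              # number of training samples, feature dimension
num_classes = y_train.shape[1]    # labels are one-hot encoded, as the mse loss requires
# Each input row holds D features followed by the sample id, hence D+1 columns.
ix = Input(shape=(D+1,), dtype='float32', name='indexed-feat')
x, index = utils.separate_index(ix) # features, sample_id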

Note that PSGD (the SGD optimizer for the primal kernel method) needs, at initialization, a tensor that records the sample ID. Therefore we preprocess each sample by appending its sample ID after its feature vector (ix). The KernelEmbedding layer is a non-trainable layer that maps the input feature vector (x) to kernel features (kfeat) using a given kernel function (from kernels.py):

# kernel: a kernel function from kernels.py (e.g. the Gaussian kernel), defined beforehand
kfeat = KernelEmbedding(kernel, x_train,
                        input_shape=(D,))(x)

Since kernel least squares is essentially linear least squares on the kernel features, we add a trainable Dense (linear) layer on top of the kernel features to predict the corresponding labels.

y = Dense(num_classes, input_shape=(n,),
          activation='linear',
          kernel_initializer='zeros',
          use_bias=False)(kfeat)
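To make this reduction concrete, the NumPy sketch below stands in for the two layers. `kernel_embedding` computes kernel features against the training points; we assume KernelEmbedding does the equivalent, as suggested by the Dense layer's `input_shape=(n,)`. The bias-free linear layer's least-squares weights are then obtained in closed form with `np.linalg.lstsq`. The Gaussian kernel, its bandwidth convention and the toy grid data are illustrative assumptions; SGD, PSGD and EigenPro approximate such a solution iteratively instead.

```python
import numpy as np

def gaussian_kernel(X, Z, s):
    """exp(-||x - z||^2 / (2 s^2)) for all pairs of rows of X and Z (assumed convention)."""
    d2 = (X ** 2).sum(axis=1)[:, None] + (Z ** 2).sum(axis=1)[None, :] - 2.0 * X @ Z.T
    return np.exp(-np.maximum(d2, 0.0) / (2.0 * s ** 2))

def kernel_embedding(x, centers, s):
    """Hypothetical stand-in for KernelEmbedding: maps (batch, D) inputs to (batch, n) features."""
    return gaussian_kernel(x, centers, s)

# Toy data: 36 well-separated points on a 6 x 6 grid with one-hot labels for 3 classes.
rng = np.random.default_rng(0)
x_train = np.stack(np.meshgrid(np.arange(6.0), np.arange(6.0)), axis=-1).reshape(-1, 2)
labels = rng.integers(0, 3, size=len(x_train))
y_train = np.eye(3)[labels]

kfeat = kernel_embedding(x_train, x_train, s=0.4)   # (36, 36) kernel features
# A bias-free linear layer predicts kfeat @ W; least squares fits W in closed form.
W, *_ = np.linalg.lstsq(kfeat, y_train, rcond=None)
pred = kfeat @ W
train_acc = np.mean(pred.argmax(axis=1) == labels)
```

The closed form needs an n x n kernel matrix and a dense solve, which is what becomes impractical for large n and motivates the iterative approach.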

The Keras model can then be created from the input tensor (ix) and the prediction tensor (y). Calling its compile(...) method specifies the loss function and optimizer for training, as well as the metrics for evaluation.

model = Model(ix, y)
model.compile(loss='mse',
              optimizer=PSGD(pred_t=y, index_t=index, eta=5.),
              metrics=['accuracy'])
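Why PSGD needs sample IDs can be seen from SGD in the RKHS. Writing the model as f = Σ_j α_j k(x_j, ·), the functional gradient of a minibatch squared loss involves only the kernel functions of the sampled points, so each update changes only their coefficients, which are located through the sample IDs. The NumPy sketch below shows this standard technique; that PSGD works exactly this way is our assumption, its step-size convention (the factor 2/m is written out) need not match PSGD's `eta`, and the toy data and step-size choice are illustrative.

```python
import numpy as np

def gaussian_kernel(X, Z, s):
    """exp(-||x - z||^2 / (2 s^2)) for all pairs of rows of X and Z (assumed convention)."""
    d2 = (X ** 2).sum(axis=1)[:, None] + (Z ** 2).sum(axis=1)[None, :] - 2.0 * X @ Z.T
    return np.exp(-np.maximum(d2, 0.0) / (2.0 * s ** 2))

def rkhs_sgd_epoch(alpha, K, Y, batch_size, eta, rng):
    """One epoch of SGD in the RKHS for kernel least squares.

    The functional gradient of (1/m) sum_{i in B} (f(x_i) - y_i)^2 is
    (2/m) sum_{i in B} (f(x_i) - y_i) k(x_i, .), so only the coefficients of the
    sampled points, looked up by their sample ids, are updated.
    """
    n = K.shape[0]
    for ids in np.array_split(rng.permutation(n), max(1, n // batch_size)):
        residual = K[ids] @ alpha - Y[ids]          # f(x_i) - y_i for i in the batch
        alpha[ids] -= eta * (2.0 / len(ids)) * residual
    return alpha

rng = np.random.default_rng(0)
x_train = np.stack(np.meshgrid(np.arange(6.0), np.arange(6.0)), axis=-1).reshape(-1, 2)
y_train = np.eye(3)[rng.integers(0, 3, size=len(x_train))]
K = gaussian_kernel(x_train, x_train, s=0.4)

# Illustrative step size: eta * (2 / m) * lambda_max(K) = 1 keeps each block update stable.
batch_size = 9
eta = batch_size / (2.0 * np.linalg.eigvalsh(K).max())
alpha = np.zeros_like(y_train)
start = np.linalg.norm(K @ alpha - y_train)
for epoch in range(30):
    alpha = rkhs_sgd_epoch(alpha, K, y_train, batch_size, eta, rng)
final = np.linalg.norm(K @ alpha - y_train)
```

The admissible step is limited by the top eigenvalue of K; this is the limit that the EigenPro preconditioner relaxes.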

The training can be performed by calling the method fit(...),

model.fit(utils.add_index(x_train), y_train,
          batch_size=256, epochs=10, verbose=0,
          validation_data=(utils.add_index(x_test), y_test))

This runs 10 epochs with mini-batches of size 256. Note that utils.add_index(...) appends the sample ID to each sample's feature vector.
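The sample-ID bookkeeping can be mimicked in NumPy as below. `add_index_np` and `separate_index_np` are hypothetical re-implementations for illustration: the repository's utils.separate_index operates on a Keras tensor, and its exact ID assignment may differ. Since the IDs share the float32 input column layout declared by the Input layer, they stay exact only up to 2^24 = 16,777,216, which comfortably covers MNIST's 60,000 training samples.

```python
import numpy as np

def add_index_np(X):
    """Hypothetical: append each row's sample id (0 .. n-1) as a last float32 column."""
    ids = np.arange(X.shape[0], dtype=np.float32).reshape(-1, 1)
    return np.hstack([X.astype(np.float32), ids])

def separate_index_np(IX):
    """Hypothetical inverse of add_index_np: returns (features, integer sample ids)."""
    return IX[:, :-1], IX[:, -1].astype(np.int64)

# Usage: round trip on a small float32 feature matrix.
X = np.random.default_rng(0).standard_normal((5, 3)).astype(np.float32)
IX = add_index_np(X)
feats, ids = separate_index_np(IX)
```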

To evaluate the training result, we can call the method evaluate(...),

scores = model.evaluate(utils.add_index(x_test), y_test, verbose=0)

where scores[0] is the L2 loss (MSE) and scores[1] is the accuracy on the test set.

Using the EigenPro iteration

The EigenPro iteration can be used through a Keras Model. It is integrated into two optimizers, SGD and PSGD. The former works with a finite-dimensional feature map such as random Fourier features; the latter works in the RKHS associated with a kernel function. Note that the latter requires appending a sample ID (used during training) to each data sample.

By default, the optimizers use standard (stochastic) gradient descent. To enable the EigenPro iteration, pass the parameter eigenpro_f to the optimizer, for example:

PSGD(... , eta=scale*eta, eigenpro_f=f)

where scale is the eigenvalue ratio $\lambda_1(H) / \lambda_{k+1}(H)$ used to increase the step size, and f is the EigenPro preconditioner (more precisely, I - P; see the Intro section). Both can be calculated using utils.py:

f, scale = utils.asm_eigenpro_f(... , in_rkhs=True)

Here the flag in_rkhs indicates whether the calculation is for PSGD (infinite-dimensional RKHS) or SGD (finite-dimensional vector space). The function uses truncated randomized SVD (for small datasets) or Nyström-based SVD (for large datasets) to calculate the approximate top eigensystem of the covariance.
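The eigensystem step for the finite-dimensional (SGD) case can be sketched with a standard randomized SVD (the randomized range finder of Halko, Martinsson and Tropp, with power iterations). This NumPy code is illustrative only: it is not utils.asm_eigenpro_f, it covers neither the Nyström variant nor the RKHS case, and the helper names and the `oversample` and `n_power_iter` defaults are assumptions. It returns `f` as the map g ↦ (I - P)g together with `scale` = λ1(H)/λ_{k+1}(H), the roles described above for `eigenpro_f` and `scale`.

```python
import numpy as np

def top_eigensystem_rsvd(X, q, oversample=10, n_power_iter=4, seed=0):
    """Approximate the top-q eigenpairs of H = X^T X / n by randomized SVD of X / sqrt(n).

    Returns eigenvalues in descending order and the eigenvectors as columns.
    """
    n, d = X.shape
    A = X / np.sqrt(n)
    rng = np.random.default_rng(seed)
    # Randomized range finder: orthonormal basis for the range of A @ Omega.
    Q, _ = np.linalg.qr(A @ rng.standard_normal((d, q + oversample)))
    for _ in range(n_power_iter):          # power iterations sharpen the spectrum
        Q, _ = np.linalg.qr(A.T @ Q)
        Q, _ = np.linalg.qr(A @ Q)
    # SVD of the small projected matrix; right singular vectors approximate eigenvectors of H.
    _, sigma, Vt = np.linalg.svd(Q.T @ A, full_matrices=False)
    return sigma[:q] ** 2, Vt[:q].T

def eigenpro_f_and_scale(X, k, tau=1.0):
    """Return f(g) = (I - P) g and scale = lambda_1 / lambda_{k+1} (0-based: lam[k])."""
    lam, E = top_eigensystem_rsvd(X, k + 1)   # lambda_{k+1} is needed as well
    Ek = E[:, :k]
    coef = 1.0 - tau * lam[k] / lam[:k]

    def f(g):
        proj = Ek.T @ g                        # coordinates in the top-k eigenbasis
        proj = proj * (coef[:, None] if proj.ndim == 2 else coef)
        return Ek @ proj

    return f, lam[0] / lam[k]
```

Applying f matrix-free costs O(dk) per vector instead of forming the d x d matrix P, which matters when d is large.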

Note that the optimizer should be connected to a Keras model,

model.compile(loss='mse', optimizer=PSGD(...), metrics=['accuracy'])

Once the optimizer is initialized for a model, training the model uses the EigenPro iteration:

model.fit(x_train, y_train)

Reference experimental results

Classification Error (MNIST)

In these experiments, EigenPro (Primal) achieves a classification error of 1.22% in only 10 epochs. For comparison, Pegasos reaches the same error after 80 epochs. Although EigenPro (Random) and RF/DSGD use 6 × 10^4 random features, the same as the number of training points, the methods using random features generally perform worse. Specifically, after 20 epochs RF/DSGD has a test error of 1.75%, while Pegasos reaches 1.63%.

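| #Epochs | EigenPro (Primal) train | EigenPro (Primal) test | Pegasos (Primal) train | Pegasos (Primal) test | EigenPro (Random) train | EigenPro (Random) test | RF/DSGD (Random) train | RF/DSGD (Random) test |
|---|---|---|---|---|---|---|---|---|
| 1 | 0.43% | 1.75% | 4.01% | 4.35% | 0.39% | 1.88% | 4.00% | 4.35% |
| 5 | 0.02% | 1.26% | 1.58% | 2.32% | 0.05% | 1.48% | 1.70% | 2.51% |
| 10 | 0.0% | 1.22% | 0.89% | 1.91% | 0.01% | 1.49% | 0.98% | 2.09% |
| 20 | 0.0% | 1.23% | 0.40% | 1.63% | 0.0% | 1.48% | 0.48% | 1.75% |
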

Training Time per Epoch

| Computing resource | EigenPro (Primal) | Pegasos (Primal) | EigenPro (Random) | RF/DSGD (Random) |
|---|---|---|---|---|
| One GTX Titan X (Maxwell) | 5.0 s | 4.6 s | 2.4 s | 2.0 s |
| One GTX Titan Xp (Pascal) | 3.0 s | 2.7 s | 1.6 s | 1.4 s |

EigenPro Preprocessing Time

In our experiments we construct the EigenPro preconditioner by computing the top 160 approximate eigenvectors for a subsample matrix with 4800 points using Randomized SVD (RSVD).

| Computing resource | RSVD time (k = 160, m = 4800) |
|---|---|
| One GTX Titan X (Maxwell) | 18.1 s |
| One GTX Titan Xp (Pascal) | 17.4 s |

About

EigenPro iteration in Tensorflow (Keras)
