
tf-bgd

Video:

Vimeo - https://vimeo.com/297651842

YouTube - https://youtu.be/fa-xLXTzZ8I

Bayesian Gradient Descent Algorithm Model for TensorFlow


Python and TensorFlow implementation of the Bayesian Gradient Descent (BGD) algorithm and model

Based on the paper "Bayesian Gradient Descent: Online Variational Bayes Learning with Increased Robustness to Catastrophic Forgetting and Weight Pruning" by Chen Zeno, Itay Golan, Elad Hoffer, Daniel Soudry

Paper PDF: https://arxiv.org/abs/1803.10123

Theoretical Background

The basic assumption is that in each step, the previous posterior distribution is used as the new prior distribution and that the parametric distribution is approximately a Diagonal Gaussian, that is, all the parameters of the weight vector $\theta$ are independent.

We define the following:

  • $\epsilon$ - a Random Variable (RV) sampled from $N(0, 1)$
  • $\theta$ - the weights whose posterior distribution we wish to find
  • $\phi = (\mu, \sigma)$ - the parameters which serve as a condition for the distribution of $\theta$
  • $\mu$ - the mean of the weights' distribution, initially sampled from a zero-mean Gaussian (standard weight initialization)
  • $\sigma$ - the STD (square root of the variance) of the weights' distribution, initially set to a small constant $\sigma_0$
  • $K$ - the number of sub-networks
  • $\eta$ - a hyper-parameter to compensate for the accumulated error (tunable)
  • $L(\theta)$ - the loss function
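
To make the parametrization concrete, here is a minimal NumPy sketch of the initialization and of sampling the weights of the $K$ sub-networks. The sizes and the exact initialization of $\mu$ below are illustrative assumptions, not the repository's exact code:

```python
import numpy as np

n = 100          # number of parameters (weights + biases) per sub-network, illustrative
K = 10           # number of sub-networks
sigma_0 = 0.002  # initial STD, a small constant (the example's default)

mu = np.random.normal(0.0, 0.1, size=n)    # assumed Gaussian initialization of the means
sigma = np.full(n, sigma_0)                # all STDs start at sigma_0

# One noise sample per sub-network; weights are theta^(k) = mu + eps^(k) * sigma
eps = np.random.normal(0.0, 1.0, size=(K, n))
theta = mu + eps * sigma                   # shape (K, n), one weight vector per sub-network
```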

Algorithm Sketch:

  • Initialize: $\mu$, $\sigma$ ($\mu$ randomly initialized, $\sigma = \sigma_0$)

  • For each sub-network k: sample $\epsilon^{(k)} \sim N(0, I)$ and set $\theta^{(k)} = \mu + \epsilon^{(k)} \odot \sigma$

  • Repeat:

    1. For each sub-network k: sample $\epsilon^{(k)} \sim N(0, I)$, compute gradients: $\frac{\partial L(\theta^{(k)})}{\partial \theta_i}$
    2. Set $\mu_i \leftarrow \mu_i - \eta \sigma_i^2 E_\epsilon\left[\frac{\partial L(\theta)}{\partial \theta_i}\right]$
    3. Set $\sigma_i \leftarrow \sigma_i \sqrt{1 + \left(\frac{1}{2}\sigma_i E_\epsilon\left[\frac{\partial L(\theta)}{\partial \theta_i}\epsilon_i\right]\right)^2} - \frac{1}{2}\sigma_i^2 E_\epsilon\left[\frac{\partial L(\theta)}{\partial \theta_i}\epsilon_i\right]$
    4. Set $\theta^{(k)} = \mu + \epsilon^{(k)} \odot \sigma$ for each k (sub-network)
  • Until convergence criterion is met

  • Note: i is the $i$-th component of the vector, that is, if we have n parameters (weights, biases) for each sub-network, then for each parameter we have $\mu_i$ and $\sigma_i$

The expectations are estimated using the Monte Carlo method:

$E_\epsilon\left[\frac{\partial L(\theta)}{\partial \theta_i}\right] \approx \frac{1}{K}\sum_{k=1}^{K}\frac{\partial L(\theta^{(k)})}{\partial \theta_i}$

$E_\epsilon\left[\frac{\partial L(\theta)}{\partial \theta_i}\epsilon_i\right] \approx \frac{1}{K}\sum_{k=1}^{K}\frac{\partial L(\theta^{(k)})}{\partial \theta_i}\epsilon_i^{(k)}$
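
The following is a minimal NumPy sketch of a single BGD step, illustrating the update rules and Monte Carlo estimates above. It is not the repository's TensorFlow implementation; the gradients here are random placeholders standing in for the back-propagated gradients of each sub-network:

```python
import numpy as np

def bgd_step(mu, sigma, grad_loss, eps, eta):
    """One BGD update for a vector of parameters.

    mu, sigma : current mean / STD of the weights, shape (n,)
    grad_loss : dL/dtheta for each sub-network, shape (K, n)
    eps       : noise samples used to build theta^(k) = mu + eps^(k) * sigma, shape (K, n)
    eta       : hyper-parameter compensating for the accumulated error
    """
    # Monte Carlo estimates of the two expectations
    e_grad = grad_loss.mean(axis=0)              # E[dL/dtheta_i]
    e_grad_eps = (grad_loss * eps).mean(axis=0)  # E[dL/dtheta_i * eps_i]

    # Element-wise update rules
    mu_new = mu - eta * (sigma ** 2) * e_grad
    sigma_new = sigma * np.sqrt(1.0 + (0.5 * sigma * e_grad_eps) ** 2) \
                - 0.5 * (sigma ** 2) * e_grad_eps
    return mu_new, sigma_new

# Illustration with random placeholder gradients
K, n = 10, 5
mu, sigma = np.zeros(n), np.full(n, 0.002)
eps = np.random.normal(size=(K, n))
grads = np.random.normal(size=(K, n))            # stand-in for dL/dtheta of each sub-network
mu, sigma = bgd_step(mu, sigma, grads, eps, eta=50.0)
```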

Loss Function Derivation for Regression Problems

$L(\theta) = -\log\left(P(D|\theta)\right)$

Recall that from our Gaussian noise assumption, we derived that the target (label) $t$ is also Gaussian distributed, such that: $P(t|x, \theta) = N\left(t \mid y(x, \theta), \beta^{-1}\right)$ where $\beta$ is the precision (the inverse variance). Assuming that the dataset is IID, we get the following: $P(D|\theta) = \prod_{n=1}^{N} N\left(t_n \mid y(x_n, \theta), \beta^{-1}\right)$ Taking the negative logarithm, we get: $-\log P(D|\theta) = \frac{\beta}{2}\sum_{n=1}^{N}\left(y(x_n, \theta) - t_n\right)^2 - \frac{N}{2}\log\beta + \frac{N}{2}\log(2\pi)$ Maximizing the log-likelihood is equivalent to minimizing the sum: $\sum_{n=1}^{N}\left(y(x_n, \theta) - t_n\right)^2$ with respect to $\theta$ (looks similar to MSE, without the normalization), which is why we use reduce_sum in the code and not reduce_mean.

Note: we denote $D$ as a general expression for the data; in our case $P(D|\theta)$ is the probability of the target conditioned on the input and the weights. Pay attention that $L(\theta)$ is the (negative) log of the probability, i.e. the log of an expression in $[0, 1]$, thus the loss itself is not bounded. The probability is a Gaussian (which is of course, bounded).
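
In TensorFlow 1.x terms, this is simply a reduce_sum over the squared errors. A minimal sketch (the tensor names below are placeholders, not the repository's variable names):

```python
import tensorflow as tf

y_true = tf.placeholder(tf.float32, shape=[None, 1])   # targets t_n
y_pred = tf.placeholder(tf.float32, shape=[None, 1])   # network output y(x_n, theta)

# Sum of squared errors (no 1/N normalization), matching the derivation above
sse_loss = tf.reduce_sum(tf.square(y_pred - y_true))
```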

Regression using BGD

We wish to test the algorithm by learning $y = x^3$ from noisy samples $t = x^3 + \epsilon_{noise}$, where $\epsilon_{noise}$ is zero-mean Gaussian noise. We'll take 20 training examples and perform 40 epochs.
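
For intuition, a minimal sketch of how such a toy dataset can be generated. The noise standard deviation and the uniform sampling of the inputs are assumptions for illustration (see bgd_regression_example.py for the actual data generation); the training region [-4,4] and the 2000 test points in [-6,6] follow the example description below:

```python
import numpy as np

N_TRAIN = 20                                              # 20 training examples
x_train = np.random.uniform(-4, 4, size=(N_TRAIN, 1))     # training region [-4, 4]
t_train = x_train ** 3 + np.random.normal(0.0, 3.0,       # assumed noise std of 3
                                           size=(N_TRAIN, 1))

# Test inputs: 2000 evenly spaced points in [-6, 6], partly outside the training region
x_test = np.linspace(-6, 6, 2000).reshape(-1, 1)
```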

Network Parameters:

  • Sub-Networks (K) = 10
  • Hidden Layers (per Sub-Network): 1
  • Neurons per Layer: 100
  • Loss: SSE (Sum of Squared Errors)
  • Optimizer: BGD (weights are updated using BGD, unbiased Monte-Carlo gradients)

Prerequisites

Library Version
Python 3.6.6 (Anaconda)
tensorflow 1.10.0
sklearn 0.20.0
numpy 1.14.5
matplotlib 3.0.0

Basic Usage

Using the model is simple; there are multiple examples in the repository. Basic methods (a short end-to-end sketch follows the list):

  • from bgd_model import BgdModel
  • model = BgdModel(config, 'train')
  • batch_acc_train = model.train(sess, X_batch, Y_batch)
  • batch_acc_test = model.calc_accuracy(sess, X_test, y_test)
  • model.save(sess, checkpoint_path, global_step=model.global_step)
  • model.restore(session, FLAGS.model_path)
  • results['predictions'] = model.predict(sess, inputs)
  • upper_confidence, lower_confidence = model.calc_confidence(sess, inputs)
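
A rough end-to-end sketch tying these calls together. The config object, the batching loop, the checkpoint path, the prediction inputs, and the explicit variable initialization are placeholders/assumptions; see bgd_regression_example.py for the actual flow:

```python
import tensorflow as tf
from bgd_model import BgdModel

# 'config' holds the model hyper-parameters (K, eta, sigma_0, layers, neurons, ...);
# how it is built is defined in bgd_regression_example.py -- treated here as given.
model = BgdModel(config, 'train')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())      # assumption: the model does not self-initialize
    for epoch in range(40):                          # 40 epochs, as in the regression example
        for X_batch, Y_batch in batches:             # your batching logic
            batch_acc_train = model.train(sess, X_batch, Y_batch)

    model.save(sess, checkpoint_path, global_step=model.global_step)

    # Prediction and uncertainty estimates on new inputs
    predictions = model.predict(sess, inputs)
    upper_confidence, lower_confidence = model.calc_confidence(sess, inputs)
```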

Files in the repository

File name Purpose
bgd_model.py Includes the class for the BGD model from which you import
bgd_regression_example.py Usage example: simple regression as mentioned above
bgd_train.ipynb Jupyter Notebook with detailed explanation, derivations and graphs

Main Example App Usage:

This little example will train a regression model as described in the background.

The testing (prediction) is performed on 2000 points in [-6,6], which includes samples outside the training region ([-4,4], 20 points). It will also output the maximum uncertainty (maximum standard deviation of the output); we want higher uncertainty in uncharted regions, demonstrating the flexibility of the network (the reddish zones in the graph).

You should use the bgd_regression_example.py file with the following arguments:

Argument Description
-h, --help shows arguments description
-w, --write_log save log for tensorboard (error graphs, and the NN)
-u, --reset start training from scratch, deletes previous checkpoints
-k, --num_sub_nets number of sub networks (K parameter), default: 10
-e, --epochs number of epochs to run, default: 40
-b, --batch_size batch size for training, default: 1
-n, --neurons number of hidden units, default: 100
-l, --layers number of layers in the network, default: 1
-t, --eta eta parameter ('learning rate'), default: 50.0
-g, --sigma sigma_0 parameter, default: 0.002
-f, --save_freq frequency to save checkpoints of the model, default: 200
-r, --decay_rate decay rate of eta (exponential scheduling), default: 1/10
-y, --decay_steps decay steps of eta (exponential scheduling), default: 10000

Training and Testing

Examples to run bgd_regression_example.py:

  • Note: if there are checkpoints in the /model/ dir, and the model parameters are the same, training will automatically resume from the latest checkpoint (you can choose the exact checkpoint number by editing the checkpoint file in the /model/ dir with your favorite text editor).

python bgd_regression_example.py -k 10 -e 40 -b 1 -n 150 -l 1 -t 300.0 -g 0.005

python bgd_regression_example.py -u -w -k 15 -e 80 -b 5 -n 200 -l 2 -t 50.0 -g 0.003

Model's checkpoints are saved in /model/ dir.

GPU

If you have tensorflow-gpu you can run the example (the session uses tf.GPUOptions(allow_growth=True)), but make sure to choose the correct device:

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" (so the IDs match nvidia-smi)

os.environ["CUDA_VISIBLE_DEVICES"] = "2" ("0, 1" for multiple)

Tensorboard

You can easily use TensorBoard when running bgd_regression_example.py. Run it with the -w flag (to save a log file); this creates a tf_logs directory. To run TensorBoard:

cd /path/to/dir/with/bgd_regression_example.py

tensorboard --logdir=./tf_logs
