This repository contains a PyTorch implementation of the curve-finding methods and the Weight Adjustment (WA) ensembling procedure from the paper "Low-loss connection of weight vectors: distribution-based approaches" by Ivan Anokhin and Dmitry Yarotsky (ICML 2020).
Please cite our work if you find it useful in your research:
@article{anokhin2020low,
title={Low-loss connection of weight vectors: distribution-based approaches},
author={Anokhin, Ivan and Yarotsky, Dmitry},
journal={arXiv preprint arXiv:2008.00741},
year={2020}
}
Before use, go to the project directory, install the requirements, and export PYTHONPATH:
cd distribution_connector
pip install -r requirements.txt
export PYTHONPATH=$(pwd)
The code in this repository implements the curve-finding procedure for the various methods for dense ReLU nets and VGG16, and the ensembling procedure with Weight Adjustment (WA) as described in the paper.
To run the curve-finding procedure or the ensembling procedure, you first need to train two or more networks that will serve as the endpoints of the curve or as input to the WA ensembling procedure. You can train the endpoints using the following command:
python3 train.py --dir=<DIR> \
--dataset=<DATASET> \
--data_path=<PATH> \
--model=<MODEL> \
--epochs=<EPOCHS> \
--lr_init=<LR_INIT> \
--wd=<WD> \
--seed=<SEED>
Parameters:
- DIR: path to the training directory where checkpoints will be stored
- DATASET: dataset name [MNIST/CIFAR10]
- PATH: path to the data directory
- MODEL: DNN model name:
  - for MNIST:
    - LinearOneLayer
  - for CIFAR10:
    - LinearOneLayer100, LinearOneLayer500, LinearOneLayer1000, LinearOneLayer2000
    - Linear3NoBias, Linear5NoBias, Linear7NoBias
    - VGG16
    - PreResNet110
- EPOCHS: number of training epochs
- LR_INIT: initial learning rate
- WD: weight decay
- SEED: random seed; use different seeds to get different endpoints
For example, use the following commands to train LinearOneLayer on MNIST and LinearOneLayer100, Linear3NoBias, VGG16 on CIFAR10:
#LinearOneLayer
python3 train.py --dir=checkpoints/LinearOneLayer/chp1 --dataset=MNIST --data_path=data --model=LinearOneLayer --epochs=30 --seed=1 --cuda
#LinearOneLayer100
python3 train.py --dir=checkpoints/LinearOneLayer100/chp1 --dataset=CIFAR10 --data_path=data --model=LinearOneLayer100 --epochs=400 --seed=1 --cuda
#Linear3NoBias
python3 train.py --dir=checkpoints/Linear3NoBias/chp1 --dataset=CIFAR10 --data_path=data --model=Linear3NoBias --epochs=400 --seed=1 --cuda
#VGG16
python3 train.py --dir=checkpoints/VGG16/chp1 --dataset=CIFAR10 --data_path=data --model=VGG16 --epochs=200 --seed=1 --cuda
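Since the later steps need at least two endpoints, the training command has to be repeated with different seeds. The following is a minimal sketch of how those commands could be generated. The helper build_train_command is hypothetical and not part of the repository, and the convention that seed N writes to the chpN directory is an assumption based on the example paths in this README.

```python
import shlex


def build_train_command(model, dataset, epochs, seed, data_path="data", cuda=True):
    """Return the train.py argument list for one endpoint (hypothetical helper)."""
    cmd = [
        "python3", "train.py",
        # Assumption: one checkpoint directory per seed, named chp<seed>.
        f"--dir=checkpoints/{model}/chp{seed}",
        f"--dataset={dataset}",
        f"--data_path={data_path}",
        f"--model={model}",
        f"--epochs={epochs}",
        f"--seed={seed}",
    ]
    if cuda:
        cmd.append("--cuda")
    return cmd


# Two endpoints for Linear3NoBias on CIFAR10 that differ only in seed and directory.
commands = [build_train_command("Linear3NoBias", "CIFAR10", 400, seed) for seed in (1, 2)]
for cmd in commands:
    # Print the shell form; use subprocess.run(cmd, check=True) to launch it instead.
    print(shlex.join(cmd))
```

Keeping each command as an argument list makes it easy to either print it for inspection or pass it to subprocess.run from the project directory.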
To evaluate the methods that connect the endpoints, you can use the following command:
python3 eval_curve.py --dir=<DIR> \
--point_finder=<POINTFINDER> \
--method=<METHOD> \
--end_time=<ENDTIME> \
--dataset=<DATASET> \
--data_path=<PATH> \
--model=<MODEL> \
--start=<START> \
--end=<END> \
--num_points=<NUM_POINTS>
Parameters:
- POINTFINDER: algorithm that proposes the samples of the distributions to connect and may perform an additional routine to preserve the output of the network [PointFinderWithBias/PointFinderInverseWithBias/PointFinderTransportation/PointFinderInverseWithBiasOT/PointFinderSimultaneous/PointFinderStepWiseButterfly/PointFinderStepWiseInverse/PointFinderStepWiseTransportation/PointFinderStepWiseInverseOT]
- METHOD: method that connects the samples proposed by POINTFINDER [lin_connect/arc_connect]; lin_connect and arc_connect refer to Eq. 5 and Eq. 6 in the paper, respectively
- START: path to the first checkpoint saved by train.py
- END: path to the second checkpoint saved by train.py
- NUM_POINTS: number of points along the curve to use for evaluation
- ENDTIME: time (in the parametrization of the curve) at which the curve reaches the endpoint; it depends on POINTFINDER and MODEL
POINTFINDER and METHOD together determine the curve-finding procedures we examine in the paper.
In Table 1 of the paper:
- PointFinderWithBias lin_connect refers to Linear
- PointFinderWithBias arc_connect refers to Arc
- PointFinderInverseWithBias lin_connect refers to Linear + Weight Adjustment
- PointFinderInverseWithBias arc_connect refers to Arc + Weight Adjustment
- PointFinderTransportation lin_connect refers to OT
- PointFinderInverseWithBiasOT lin_connect refers to OT + Weight Adjustment
In Table 2 of the paper:
- PointFinderSimultaneous lin_connect refers to Linear
- PointFinderSimultaneous arc_connect refers to Arc
- PointFinderStepWiseButterfly lin_connect refers to Linear + B-fly
- PointFinderStepWiseButterfly arc_connect refers to Arc + B-fly
- PointFinderStepWiseInverse lin_connect refers to Linear + WA
- PointFinderStepWiseInverse arc_connect refers to Arc + WA
- PointFinderStepWiseTransportation lin_connect refers to OT + B-fly
- PointFinderStepWiseInverseOT lin_connect refers to OT + WA
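As a rough illustration of the two METHOD options, the sketch below interpolates between two weight vectors. It assumes that lin_connect corresponds to the straight-line interpolation (1 - t) a + t b and that arc_connect corresponds to the quarter-circle interpolation cos(pi t / 2) a + sin(pi t / 2) b. These forms are assumptions made here for illustration, so please check Eq. 5 and Eq. 6 in the paper for the exact definitions. The names lin_point and arc_point are hypothetical and are not the repository's API.

```python
import math


def lin_point(a, b, t):
    """Straight-line interpolation between vectors a and b at t in [0, 1] (assumed form)."""
    return [(1.0 - t) * x + t * y for x, y in zip(a, b)]


def arc_point(a, b, t):
    """Quarter-circle interpolation between vectors a and b at t in [0, 1] (assumed form)."""
    c = math.cos(math.pi * t / 2)
    s = math.sin(math.pi * t / 2)
    return [c * x + s * y for x, y in zip(a, b)]


def norm(v):
    """Euclidean norm of a vector given as a list of floats."""
    return math.sqrt(sum(x * x for x in v))


# Two orthogonal unit vectors stand in for the weights of the two endpoints.
a = [1.0, 0.0]
b = [0.0, 1.0]
mid_lin = lin_point(a, b, 0.5)
mid_arc = arc_point(a, b, 0.5)
print("linear midpoint norm:", norm(mid_lin))
print("arc midpoint norm:", norm(mid_arc))
```

For these two orthogonal unit vectors, the linear midpoint has a smaller norm than either endpoint, while the arc midpoint keeps unit norm.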
eval_curve.py outputs the statistics on train and test loss and error along the curve. It also saves a .npz file containing more detailed statistics at <DIR>.
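The names of the arrays stored in the saved .npz file are not listed here, so one way to inspect it is to enumerate its contents with NumPy. The sketch below does this on a stand-in file created on the spot. The file name curve.npz and the keys ts and te_err are placeholders chosen for the demo, not the names that eval_curve.py actually writes.

```python
import tempfile
from pathlib import Path

import numpy as np


def summarize_npz(path):
    """Return {array_name: shape} for every array stored in an .npz file."""
    with np.load(path) as data:
        return {name: data[name].shape for name in data.files}


# Demo on a stand-in file with placeholder keys (not the real eval_curve.py output).
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "curve.npz"
    np.savez(demo_path, ts=np.linspace(0.0, 1.0, 21), te_err=np.zeros(21))
    summary = summarize_npz(demo_path)
    for name, shape in summary.items():
        print(name, shape)
```

For a real run, the same function can be pointed at whichever .npz file appears in the directory passed as --dir.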
For example, use the following commands to evaluate the paths on CIFAR10:
#PointFinderWithBias lin_connect for LinearOneLayer100 model
python3 eval_curve.py --dir=experiments/eval/LinearOneLayer100/PointFinderWithBias --point_finder=PointFinderWithBias --method=lin_connect --model=LinearOneLayer100 --end_time=1 --data_path=data --num_points=21 --start=checkpoints/LinearOneLayer100/chp1/checkpoint-400.pt --end=checkpoints/LinearOneLayer100/chp2/checkpoint-400.pt --cuda
#PointFinderInverseWithBias arc_connect for LinearOneLayer100 model
python3 eval_curve.py --dir=experiments/eval/LinearOneLayer100/PointFinderInverseWithBias --point_finder=PointFinderInverseWithBias --method=arc_connect --model=LinearOneLayer100 --end_time=2 --data_path=data --num_points=21 --start=checkpoints/LinearOneLayer100/chp1/checkpoint-400.pt --end=checkpoints/LinearOneLayer100/chp2/checkpoint-400.pt --cuda
#PointFinderTransportation lin_connect for LinearOneLayer100 model
python3 eval_curve.py --dir=experiments/eval/LinearOneLayer100/PointFinderTransportation --point_finder=PointFinderTransportation --method=lin_connect --model=LinearOneLayer100 --end_time=1 --data_path=data --num_points=21 --start=checkpoints/LinearOneLayer100/chp1/checkpoint-400.pt --end=checkpoints/LinearOneLayer100/chp2/checkpoint-400.pt --cuda
#PointFinderInverseWithBiasOT lin_connect for LinearOneLayer100 model
python3 eval_curve.py --dir=experiments/eval/LinearOneLayer100/PointFinderInverseWithBiasOT --point_finder=PointFinderInverseWithBiasOT --method=lin_connect --model=LinearOneLayer100 --end_time=2 --data_path=data --num_points=21 --start=checkpoints/LinearOneLayer100/chp1/checkpoint-400.pt --end=checkpoints/LinearOneLayer100/chp2/checkpoint-400.pt --cuda
#PointFinderSimultaneous lin_connect for Linear3NoBias model
python3 eval_curve.py --dir=experiments/eval/Linear3NoBias/PointFinderSimultaneous --point_finder=PointFinderSimultaneous --method=lin_connect --model=Linear3NoBias --end_time=1 --data_path=data --num_points=21 --start=checkpoints/Linear3NoBias/chp1/checkpoint-400.pt --end=checkpoints/Linear3NoBias/chp2/checkpoint-400.pt --cuda
#PointFinderStepWiseButterfly arc_connect for Linear3NoBias model
python3 eval_curve.py --dir=experiments/eval/Linear3NoBias/PointFinderStepWiseButterfly --point_finder=PointFinderStepWiseButterfly --method=arc_connect --model=Linear3NoBias --end_time=2 --data_path=data --num_points=21 --start=checkpoints/Linear3NoBias/chp1/checkpoint-400.pt --end=checkpoints/Linear3NoBias/chp2/checkpoint-400.pt --cuda
#PointFinderStepWiseInverse lin_connect for Linear3NoBias model
python3 eval_curve.py --dir=experiments/eval/Linear3NoBias/PointFinderStepWiseInverse --point_finder=PointFinderStepWiseInverse --method=lin_connect --model=Linear3NoBias --end_time=3 --data_path=data --num_points=31 --start=checkpoints/Linear3NoBias/chp1/checkpoint-400.pt --end=checkpoints/Linear3NoBias/chp2/checkpoint-400.pt --cuda
#PointFinderStepWiseTransportation lin_connect for Linear3NoBias model
python3 eval_curve.py --dir=experiments/eval/Linear3NoBias/PointFinderStepWiseTransportation --point_finder=PointFinderStepWiseTransportation --method=lin_connect --model=Linear3NoBias --end_time=2 --data_path=data --num_points=21 --start=checkpoints/Linear3NoBias/chp1/checkpoint-400.pt --end=checkpoints/Linear3NoBias/chp2/checkpoint-400.pt --cuda
#PointFinderStepWiseInverseOT lin_connect for Linear3NoBias model
python3 eval_curve.py --dir=experiments/eval/Linear3NoBias/PointFinderStepWiseInverseOT --point_finder=PointFinderStepWiseInverseOT --method=lin_connect --model=Linear3NoBias --end_time=3 --data_path=data --num_points=31 --start=checkpoints/Linear3NoBias/chp1/checkpoint-400.pt --end=checkpoints/Linear3NoBias/chp2/checkpoint-400.pt --cuda
#PointFinderStepWiseButterflyConvWBiasOT lin_connect for VGG16
python3 eval_curve.py --dir=experiments/eval/VGG16lin/PointFinderStepWiseButterflyConvWBiasOT/12 --point_finder=PointFinderStepWiseButterflyConvWBiasOT --method=lin_connect --model=VGG16 --end_time=15 --data_path=data --num_points=61 --start=checkpoints/VGG16/chp1/checkpoint-400.pt --end=checkpoints/VGG16/chp2/checkpoint-400.pt
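The examples above pair each --end_time with a --num_points value. A minimal sketch of how the two could relate is given below, under the assumption (not stated in this README) that the evaluation times are spaced evenly over [0, END_TIME] with both ends included. The function time_grid is hypothetical.

```python
def time_grid(end_time, num_points):
    """Evenly spaced times on [0, end_time], endpoints included (assumed convention)."""
    if num_points < 2:
        raise ValueError("num_points must be at least 2")
    return [end_time * i / (num_points - 1) for i in range(num_points)]


# Values taken from the example commands: (end_time, num_points).
for end_time, num_points in [(1, 21), (2, 21), (3, 31), (15, 61)]:
    grid = time_grid(end_time, num_points)
    print(f"end_time={end_time} num_points={num_points} step={grid[1] - grid[0]:.3f}")
```

Under this assumption, raising END_TIME without raising NUM_POINTS makes the evaluation grid coarser.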
To evaluate the results of ensembling with Weight Adjustment, you can use the following command:
python3 eval_ensemble.py --dir=<DIR> \
--data_path=<PATH> \
--model=<MODEL> \
--name=<NAME> \
--layer=<LAYER> \
--layer_ind=<LAYERIND> \
--model_paths=<MPATHS>
Parameters:
- NAME: substring contained in the names of all checkpoints you want to ensemble. For example, specify NAME=400 if you want to ensemble checkpoints trained for 400 epochs.
- LAYER: index of the layer in the PyTorch network implementation after which the Weight Adjustment procedure is performed
- LAYERIND: index of the layer in parameter space on which the Weight Adjustment procedure is performed
- MPATHS: path to the directory where the checkpoints for ensembling are stored
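To preview which checkpoints a given NAME would match under MPATHS, a substring filter over the checkpoint files can be sketched as follows. This is an illustration only. The function select_checkpoints is hypothetical, and the recursive search for .pt files is an assumption about how the directory is scanned. The chpN/checkpoint-E.pt layout follows the example paths in this README.

```python
import tempfile
from pathlib import Path


def select_checkpoints(model_paths, name):
    """Return sorted .pt files under model_paths whose file name contains `name`."""
    root = Path(model_paths)
    # Assumption: checkpoints may sit in per-run subdirectories, so search recursively.
    return sorted(p for p in root.rglob("*.pt") if name in p.name)


# Demo on an empty stand-in directory tree that mimics chp1/ and chp2/.
with tempfile.TemporaryDirectory() as tmp:
    for run in ("chp1", "chp2"):
        run_dir = Path(tmp) / run
        run_dir.mkdir()
        for epoch in (200, 400):
            (run_dir / f"checkpoint-{epoch}.pt").touch()
    selected = [p.relative_to(tmp).as_posix() for p in select_checkpoints(tmp, "400")]
    print(selected)
```

Because the match is a plain substring test, a short NAME such as 40 would also match checkpoint-40.pt and checkpoint-400.pt, so a value that identifies the intended epoch unambiguously is safer.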
For example, use the following commands to evaluate the WA(n) Ensembling (please see Section 6 in the paper for WA(n)):
#Linear3NoBias WA(1)
python3 eval_ensemble.py --dir=experiments/eval_ensemble/Linear3NoBias/ --data_path=data --model=Linear3NoBias --name=400 --layer=1 --layer_ind=2 --model_paths=checkpoints/Linear3NoBias/
#Linear3NoBias WA(2)
python3 eval_ensemble.py --dir=experiments/eval_ensemble/Linear3NoBias/ --data_path=data --model=Linear3NoBias --name=400 --layer=0 --layer_ind=1 --model_paths=checkpoints/Linear3NoBias/
#Linear5NoBias WA(1)
python3 eval_ensemble.py --dir=experiments/eval_ensemble/Linear5NoBias/ --data_path=data --model=Linear5NoBias --name=400 --layer=3 --layer_ind=4 --model_paths=checkpoints/curves/Linear5NoBias/
#Linear7NoBias WA(1)
python3 eval_ensemble.py --dir=experiments/eval_ensemble/Linear7NoBias/ --data_path=data --model=Linear7NoBias --name=400 --layer=5 --layer_ind=6 --model_paths=checkpoints/Linear7NoBias/
#Linear7NoBias WA(3)
python3 eval_ensemble.py --dir=experiments/eval_ensemble/Linear7NoBias/ --data_path=data --model=Linear7NoBias --name=400 --layer=3 --layer_ind=4 --model_paths=checkpoints/Linear7NoBias/
#VGG16 WA(9)
#python3 eval_ensemble.py --dir=experiments/eval_ensemble/VGG16cifar100w9/ --data_path=data --model=VGG16 --name=200 --layer=9 --layer_ind=-14 --model_paths=checkpoints/cifar100/VGG16 --dataset=CIFAR100
#VGG16 WA(10)
#python3 eval_ensemble.py --dir=experiments/eval_ensemble/VGG16cifar100w10/ --data_path=data --model=VGG16 --name=200 --layer=10 --layer_ind=-12 --model_paths=checkpoints/cifar100/VGG16 --dataset=CIFAR100
#VGG16 WA(3)
#python3 eval_ensemble.py --dir=experiments/eval_ensemble/VGG16cifar100w3/ --data_path=data --model=VGG16 --name=200 --layer=3 --layer_ind=-26 --model_paths=checkpoints/cifar100/VGG16 --dataset=CIFAR100
eval_ensemble.py outputs the statistics on ensembling. It also saves a .npz file and a .png plot containing more details at <DIR>.
- Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs by Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry Vetrov and Andrew Gordon Wilson
- Essentially No Barriers in Neural Network Energy Landscape by Felix Draxler, Kambis Veschgini, Manfred Salmhofer, Fred A. Hamprecht
- Topology and Geometry of Half-Rectified Network Optimization by C. Daniel Freeman, Joan Bruna
- Averaging Weights Leads to Wider Optima and Better Generalization by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson