Systematic-Investigation-of-Sparse-Perturbed-Sharpness-Aware-Minimization-Optimizer

This is the official implementation of the paper "Systematic Investigation of Sparse Perturbed Sharpness-Aware Minimization Optimizer".

Installation

Clone this repo
git clone git@github.com:Mi-Peng/Systematic-Investigation-of-Sparse-Perturbed-Sharpness-Aware-Minimization-Optimizer.git
Create a virtual environment (e.g. Anaconda3)
conda create -n ssam python=3.8 -y
conda activate ssam
Install the necessary packages
  1. PyTorch

Install PyTorch following the official installation instructions.

conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch -y
  2. cuSPARSELt

Details can be found in cusparseLt.md.

  3. Install other packages
pip install einops
  4. Install wandb (optional)

Wandb makes it easy to track experiments and to manage and version data. It is optional; the code runs without wandb.

pip install wandb
  5. Dataset preparation

We use CIFAR10, CIFAR100 and ImageNet in this repo.

For the CIFAR datasets, no preparation is needed; PyTorch downloads them automatically.

For ImageNet, we use the standard ImageNet dataset, which can be found at http://image-net.org/. Your ImageNet file structure should look like:

$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
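
Before launching an ImageNet run, it can help to confirm that the data directory matches the tree above. The following is a minimal sketch using only the Python standard library. The helper name `check_imagenet_layout` is hypothetical and not part of this repo, and it only checks that `train` and `val` exist and contain the same class folders.

```python
from pathlib import Path


def check_imagenet_layout(root):
    """Return a list of problems found under `root` (empty list means OK).

    Hypothetical helper, not part of this repo: it only verifies that
    `train/` and `val/` exist and hold the same set of class folders.
    """
    root = Path(root)
    problems = []
    classes = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split_dir}")
            continue
        # Each immediate sub-directory is treated as one class.
        classes[split] = {p.name for p in split_dir.iterdir() if p.is_dir()}
        if not classes[split]:
            problems.append(f"no class folders in: {split_dir}")
    if len(classes) == 2 and classes["train"] != classes["val"]:
        diff = sorted(classes["train"] ^ classes["val"])
        problems.append(f"class folders differ between train and val: {diff}")
    return problems
```

Calling `check_imagenet_layout("data/imagenet")` returns an empty list when both splits are present with matching class folders; otherwise each returned string describes one mismatch.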

Configuration

Details are in configs/defaulf_cfg.py.

  • --dataset: Currently supported choices include: CIFAR10_base, CIFAR10_cutout, CIFAR100_base, CIFAR100_cutout and ImageNet_base.

  • --model: Currently supported choices include: resnet18, wideresnet28x10, ... (see more in models/__init__.py).

  • --opt: How to update parameters: sgd for SGD, sam-sgd for SAM within SGD, ssamf-sgd for Fisher-SparseSAM within SGD.

  • --pattern: Pattern of the masks. Currently supported choices include: unstructured, structured, nm.

  • --n_structured and --m_structured: Set n and m in the nm pattern (only used when --pattern is nm).

  • --implicit: Whether to use the mask to compute the sparse perturbation implicitly. This requires also passing --samconv or --culinear to transform the backpropagation.

  • --samconv: Transform the convolution layers for implicit sparse perturbation (for ResNet).

  • --culinear: Transform the linear layers for implicit sparse perturbation (for vit_testspmm).
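
To illustrate what an nm mask with `--n_structured 2 --m_structured 4` might look like, here is a minimal pure-Python sketch. It assumes the common N:M convention in which, within every group of m consecutive entries, the n entries with the largest scores are kept. Whether this repo follows exactly that convention, and which score it uses to rank entries, is not stated in this README, so `nm_mask` and its scoring input are hypothetical.

```python
def nm_mask(scores, n=2, m=4):
    """Build a 0/1 mask keeping the n largest scores in each group of m.

    Hypothetical illustration of an N:M pattern, not this repo's code.
    """
    if len(scores) % m != 0:
        raise ValueError("len(scores) must be a multiple of m")
    if not 0 <= n <= m:
        raise ValueError("n must satisfy 0 <= n <= m")
    mask = [0] * len(scores)
    for start in range(0, len(scores), m):
        group = range(start, start + m)
        # Indices of the n highest-scoring entries in this group
        # (ties are broken in favour of the earlier index).
        keep = sorted(group, key=lambda i: (-scores[i], i))[:n]
        for i in keep:
            mask[i] = 1
    return mask


scores = [0.9, 0.1, 0.5, 0.3, 0.2, 0.8, 0.7, 0.6]
print(nm_mask(scores, n=2, m=4))  # [1, 0, 1, 0, 0, 1, 1, 0]
```

Every group of four positions ends up with exactly two ones, which is the regular layout that structured-sparsity kernels generally rely on.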

Training

Training model on CIFAR10 with SGD (Taking ResNet18 as an example)
python train.py \
  --model resnet18 \
  --dataset CIFAR10_cutout --datadir [Path2Data] \
  --opt sgd --lr 0.05 --weight_decay 5e-4 \
  --seed 1234 --wandb
Training model on CIFAR100 with SAM (Taking ResNet18 as an example)
python train.py \
  --model resnet18 \
  --dataset CIFAR100_cutout --datadir [Path2Data] \
  --opt sam-sgd --lr 0.05 --weight_decay 1e-3 --rho 0.2 \
  --seed 1234 --wandb
Training model on CIFAR100 with SSAM, unstructured mask, explicit sparse perturbation
python train.py \
  --model resnet18 \
  --dataset CIFAR100_cutout --datadir [Path2Data] \
  --opt ssamf-sgd --lr 0.05 --weight_decay 1e-3 --rho 0.2 \
  --pattern unstructured --sparsity 0.5 --num_samples 128 --update_freq 1 \
  --seed 1234 --wandb
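
As a rough illustration of what an explicit sparse perturbation does, the sketch below runs one SAM-style step on a toy quadratic loss in pure Python. It assumes the usual SAM perturbation `rho * g / ||g||` multiplied element-wise by a fixed 0/1 mask. The function `sparse_sam_step`, the toy loss, and the hand-written mask are hypothetical; the repo's actual optimizer, its Fisher-based mask selection, and weight decay are not reproduced here.

```python
import math


def sparse_sam_step(w, grad_fn, mask, lr=0.05, rho=0.2):
    """One SAM-style update whose perturbation is restricted by a 0/1 mask."""
    g = grad_fn(w)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12  # avoid division by zero
    # Full perturbation rho * g / ||g||, then zeroed wherever mask == 0.
    eps = [rho * gi / norm * mi for gi, mi in zip(g, mask)]
    w_adv = [wi + ei for wi, ei in zip(w, eps)]
    # The gradient at the perturbed point drives the actual SGD update.
    g_adv = grad_fn(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


def quad_grad(w):
    """Gradient of the toy loss f(w) = 0.5 * sum(w_i ** 2)."""
    return list(w)


w0 = [1.0, 2.0]
plain = sparse_sam_step(w0, quad_grad, mask=[0, 0])   # reduces to plain SGD
sparse = sparse_sam_step(w0, quad_grad, mask=[1, 0])  # perturb only w[0]
print(plain, sparse)
```

With an all-zero mask the step is ordinary SGD, and with `mask=[1, 0]` only the first coordinate sees the sharpness-aware perturbation.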
Training model on CIFAR100 with SSAM, structured mask, explicit sparse perturbation
python train.py \
  --model resnet18 \
  --dataset CIFAR100_cutout --datadir [Path2Data] \
  --opt ssamf-sgd --lr 0.05 --weight_decay 1e-3 --rho 0.2 \
  --pattern structured --sparsity 0.5 --num_samples 128 --update_freq 1 \
  --seed 1234 --wandb
Training model on CIFAR100 with SSAM, N:M mask, explicit sparse perturbation
python train.py \
  --model resnet18 \
  --dataset CIFAR100_cutout --datadir [Path2Data] \
  --opt ssamf-sgd --lr 0.05 --weight_decay 1e-3 --rho 0.2 \
  --pattern nm --n_structured 2 --m_structured 4 --num_samples 128 --update_freq 1 \
  --seed 1234 --wandb
Training model on CIFAR100 with SSAM, structured mask, implicit sparse perturbation
python train.py \
  --model resnet18 \
  --dataset CIFAR100_cutout --datadir [Path2Data] \
  --opt sam-sgd --lr 0.05 --weight_decay 1e-3 --rho 0.2 \
  --pattern structured --sparsity 0.5 --num_samples 128 --update_freq 1 --implicit --samconv \
  --seed 1234 --wandb
Training model on CIFAR100 with SSAM, N:M mask, implicit sparse perturbation
python train.py \
  --model resnet18 \
  --dataset CIFAR100_cutout --datadir [Path2Data] \
  --opt sam-sgd --lr 0.05 --weight_decay 1e-3 --rho 0.2 \
  --pattern nm --n_structured 2 --m_structured 4 --num_samples 128 --update_freq 1 --implicit --samconv \
  --seed 1234 --wandb
Test cuSPARSELt for ViT on CIFAR100 with SSAM, N:M mask, implicit sparse perturbation
python train.py \
  --model vit_testspmm --patch_size 1 --log_freq 1 \
  --dataset CIFAR100_cutout --datadir [Path2Data] \
  --opt sam-sgd --lr 0.05 --weight_decay 1e-3 --rho 0.2 \
  --pattern nm --n_structured 2 --m_structured 4 --num_samples 128 --update_freq 1 --implicit --culinear \
  --seed 1234 --wandb
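
The implicit examples above always pair `--implicit` with `--samconv` or `--culinear`, as the Configuration section requires. The sketch below shows one way such a constraint could be checked with `argparse`. It is a hypothetical illustration covering only four of the flags; it is not the parsing code this repo uses.

```python
import argparse


def parse_args(argv):
    """Parse a few of the flags described above and validate their combination.

    Hypothetical sketch: only a subset of flags, no defaults from the repo.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pattern", choices=["unstructured", "structured", "nm"])
    parser.add_argument("--implicit", action="store_true")
    parser.add_argument("--samconv", action="store_true")
    parser.add_argument("--culinear", action="store_true")
    args = parser.parse_args(argv)
    # Implicit sparse perturbation needs a transformed layer type.
    if args.implicit and not (args.samconv or args.culinear):
        parser.error("--implicit requires --samconv or --culinear")
    return args


args = parse_args(["--pattern", "nm", "--implicit", "--samconv"])
print(args.pattern, args.implicit, args.samconv, args.culinear)
```

Passing `--implicit` on its own makes `parser.error` print a usage message and exit, which surfaces the mistake before any training starts.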