Skip to content

MediaBrain-SJTU/BCL

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

51 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Contrastive Learning with Boosted Memorization

Keywords: Long-Tailed Recognition, Self-Supervised Learning, Memorization Effect

Contrastive Learning with Boosted Memorization

ICML 2022

@inproceedings{zhou2022contrastive,
  title={Contrastive Learning with Boosted Memorization},
  author={Zhou, Zhihan and Yao, Jiangchao and Wang, Yan-Feng and Han, Bo and Zhang, Ya},
  booktitle={International Conference on Machine Learning},
  pages={27367--27377},
  year={2022},
  organization={PMLR}
}

Abstract: Self-supervised learning has achieved a great success in the representation learning of visual and textual data. However, the current methods are mainly validated on the well-curated datasets, which do not exhibit the real-world long-tailed distribution. Recent attempts to consider self-supervised long-tailed learning are made by rebalancing in the loss perspective or the model perspective, resembling the paradigms in the supervised long-tailed learning. Nevertheless, without the aid of labels, these explorations have not shown the expected significant promise due to the limitation in tail sample discovery or the heuristic structure design. Different from previous works, we explore this direction from an alternative perspective, i.e., the data perspective, and propose a novel Boosted Contrastive Learning (BCL) method. Specifically, BCL leverages the memorization effect of deep neural networks to automatically drive the information discrepancy of the sample views in contrastive learning, which is more efficient to enhance the long-tailed learning in the label-unaware context. Extensive experiments on a range of benchmark datasets demonstrate the effectiveness of BCL over several state-of-the-art methods.

Get Started

Environment

  • Python (3.7.10)
  • Pytorch (1.7.1)
  • torchvision (0.8.2)
  • CUDA
  • Numpy

File Structure

After the preparation work, the whole project should have the following structure:

./Boosted-Contrastive-Learning
├── README.md
├── data                            # datasets and augmentations
│   ├── memoboosted_cifar100.py
│   ├── cifar100.py                   
│   ├── augmentations.py
│   └── randaug.py
├── models                          # models and backbones
│   ├── simclr.py
│   ├── sdclr.py
│   ├── resnet.py
│   ├── resnet_prune_multibn.py
│   └── utils.py
├── losses                          # losses
│   └── nt_xent.py   
├── split                           # data split
│   ├── cifar100                        
│   └── cifar100_imbSub_with_subsets
├── eval_cifar.py                   # linear probing evaluation code
├── test.py                         # testing code
├── train.py                        # training code
├── train_sdclr.py                  # training code for sdclr
└── utils.py                        # utils

Quick Preview

A code snippet of the BCL is shown below.

train_datasets = memoboosted_CIFAR100(train_idx_list, args, root=args.data_folder, train=True)

# initialize momentum loss
shadow = torch.zeros(dataset_total_num).cuda()
momentum_loss = torch.zeros(args.epochs,dataset_total_num).cuda()

shadow, momentum_loss = train(train_loader, model, optimizer, scheduler, epoch, log, shadow, momentum_loss, args=args)
train_datasets.update_momentum_weight(momentum_loss, epoch)

During the training phase, track the momentum loss.

if epoch>1:
    new_average = (1.0 - args.momentum_loss_beta) * loss[batch_idx].clone().detach() + args.momentum_loss_beta * shadow[index[batch_idx]]
else:
    new_average = loss[batch_idx].clone().detach()
    
shadow[index[batch_idx]] = new_average
momentum_loss[epoch-1,index[batch_idx]] = new_average

Training

To train model on CIFAR-100-LT, simply run:

  • SimCLR
python train.py SimCLR --lr 0.5 --epochs 2000 --temperature 0.2 --weight_decay 5e-4 --data_folder ${data_folder} --trainSplit cifar100_imbSub_with_subsets/cifar100_split1_D_i.npy 
  • BCL-I
python train.py BCL_I --bcl --rand_k 1 --lr 0.5 --epochs 2000 --temperature 0.2 --weight_decay 5e-4 --data_folder ${data_folder} --trainSplit cifar100_imbSub_with_subsets/cifar100_split1_D_i.npy 
  • SDCLR
python train_sdclr.py SDCLR --lr 0.5 --epochs 2000 --temperature 0.2 --weight_decay 1e-4 --data_folder ${data_folder} --trainSplit cifar100_imbSub_with_subsets/cifar100_split1_D_i.npy 
  • BCL-D
python train_sdclr.py BCL_D --bcl --rand_k 2 --lr 0.5 --epochs 2000 --temperature 0.2 --weight_decay 1e-4 --data_folder ${data_folder} --trainSplit cifar100_imbSub_with_subsets/cifar100_split1_D_i.npy 

Pretrained checkpoints will be saved in 'checkpoints/'.

Evaluating

To evalutate the pretrained model, simply run:

  • SimCLR, BCL-I
python test.py --checkpoint ${checkpoint_pretrain} --test_fullshot --test_100shot --test_50shot --data_folder ${data_folder}
  • SDCLR, BCL-D
python test.py --checkpoint ${checkpoint_pretrain} --prune --test_fullshot --test_100shot --test_50shot --data_folder ${data_folder}

The code will output the results of full-shot/100-shot/50-shot linear probing evaluation.

Results and Pretrained Models

We provide the full-shot/100-shot/50-shot results(demo) pretrained on 'cifar100_split1_D_i.npy' with the corresponding checkpoint weights.

Method Full-shot 100-shot 50-shot Model
SimCLR 50.7 46.3 42.4 ResNet18
SDCLR 55.0 49.7 45.6 ResNet18
BCL-I 55.7 50.1 45.8 ResNet18
BCL-D 58.7 52.6 48.7 ResNet18

After downloading the checkpoints, you could run evaluation by the instructions in the evaluating section.

Extensions

Steps to Implement Your Own Model

  • Add your model to ./models and load the model in train.py.
  • Implement functions(./losses) specfic to your models in train.py.

Steps to Implement Other Datasets

  • Create long-tailed splits of the datasets and add to ./split.
  • Implement the dataset (e.g. memoboosted_cifar100.py).

Acknowledgement

We borrow some codes from SDCLR, RandAugment and W-MSE.

Releases

No releases published

Packages

No packages published

Languages