kanshichao/CBML

Contrastive Bayesian Analysis for Deep Metric Learning

This code is mainly for reproducing the results reported in our paper Contrastive Bayesian Analysis for Deep Metric Learning, submitted to TPAMI. Beyond this purpose, we will continue to maintain the project and provide tools for both supervised and unsupervised metric learning research, aiming to integrate various loss functions and backbones to facilitate academic research on deep metric learning. The project currently contains the GoogleNet, BN-Inception, ResNet18, ResNet34, ResNet50, ResNet101 and ResNet152 backbones, and the following losses: cbml_loss (with log, square root and constant variants), crossentropy_loss, ms_loss, rank_loss, softtriple_loss, margin_loss, adv_loss, proxynca_loss, npair_loss, angular_loss, contrastive_loss, triplet_loss, cluster_loss, histogram_loss and center_loss, as well as combinations of multiple losses.

Abstract

Recent methods for deep metric learning have focused on designing different contrastive loss functions between positive and negative pairs of samples so that the learned feature embedding is able to pull positive samples of the same class closer and push negative samples from different classes away from each other. In this work, we recognize that there is a significant semantic gap between features at intermediate feature layers and class label decision at the final output layer. To bridge this gap, we develop a contrastive Bayesian analysis to characterize and model the posterior probabilities of image labels conditioned on their metric similarity in a contrastive learning setting. This contrastive Bayesian analysis leads to a new loss function for deep metric learning. To improve the generalization capability of the proposed method to new classes, we further extend the contrastive Bayesian loss with a metric variance constraint. Our experimental results and ablation studies demonstrate that the proposed contrastive Bayesian metric learning method significantly improves the performance of deep metric learning, outperforming existing methods by a large margin.
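To make the "pull positives closer, push negatives apart" idea concrete, here is a minimal sketch of the classic pairwise contrastive loss that the abstract refers to as prior work. It is not the CBML loss from the paper and is not taken from this repository; the function name `contrastive_loss`, the `margin` parameter and the use of plain Python lists are assumptions made for illustration only.

```python
import math

def contrastive_loss(embeddings, labels, margin=1.0):
    """Classic pairwise contrastive loss (illustrative baseline, NOT cbml_loss).

    Same-class pairs are penalised by their squared distance (pulled together);
    different-class pairs are penalised only when they are closer than
    `margin` (pushed apart). The result is averaged over all pairs.
    """
    total, num_pairs = 0.0, 0
    n = len(embeddings)
    for i in range(n):
        for j in range(i + 1, n):
            dist = math.dist(embeddings[i], embeddings[j])
            if labels[i] == labels[j]:
                total += dist ** 2                      # positive pair
            else:
                total += max(0.0, margin - dist) ** 2   # negative pair
            num_pairs += 1
    return total / num_pairs if num_pairs else 0.0

# Two tight, well-separated classes incur no loss under this definition.
well_separated = contrastive_loss(
    [(0.0, 0.0), (0.0, 0.0), (3.0, 4.0)], [0, 0, 1], margin=1.0
)
print(well_separated)
```

A negative pair that sits inside the margin, for example `contrastive_loss([(0.0, 0.0), (0.5, 0.0)], [0, 1])`, is penalised by the squared shortfall to the margin.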

Performance compared with SOTA methods on CUB-200-2011 for 512-dimensional embeddings

  • GoogleNet Backbone
| Recall@K | 1 | 2 | 4 | 8 |
|---|---|---|---|---|
| Contrastive | 26.4 | 37.7 | 49.8 | 62.3 |
| Triplet | 36.1 | 48.6 | 59.3 | 70.0 |
| LiftedStruct | 47.2 | 58.9 | 70.2 | 80.2 |
| Binomial Deviance | 52.8 | 64.4 | 74.7 | 83.9 |
| Histogram Loss | 50.3 | 61.9 | 72.6 | 82.3 |
| HDC | 53.6 | 65.7 | 77.0 | 85.6 |
| Angular Loss | 54.7 | 66.3 | 76.0 | 83.9 |
| BIER | 55.3 | 67.2 | 76.9 | 85.1 |
| A-BIER | 57.5 | 68.7 | 78.3 | 82.6 |
| Ours CBML-const-GoogleNet | 62.8 | 73.9 | 83.2 | 89.8 |
| Ours CBML-sqrt-GoogleNet | 63.1 | 74.7 | 83.1 | 89.8 |
| Ours CBML-log-GoogleNet | 63.8 | 74.8 | 83.6 | 90.3 |
  • BN-Inception Backbone
| Recall@K | 1 | 2 | 4 | 8 |
|---|---|---|---|---|
| Ranked List (H) | 57.4 | 69.7 | 79.2 | 86.9 |
| Ranked List (L,M,H) | 61.3 | 72.7 | 82.7 | 89.4 |
| SoftTriple | 65.4 | 76.4 | 84.5 | 90.4 |
| DeML | 65.4 | 75.3 | 83.7 | 89.5 |
| MS | 65.7 | 77.0 | 86.3 | 91.2 |
| Contrastive+HORDE | 66.8 | 77.4 | 85.1 | 91.0 |
| Ours CBML-const-BN-Inception | 68.3 | 78.5 | 86.9 | 92.1 |
| Ours CBML-sqrt-BN-Inception | 69.5 | 79.5 | 86.7 | 91.8 |
| Ours CBML-log-BN-Inception | 69.5 | 79.4 | 87.0 | 92.4 |
  • ResNet50 Backbone
| Recall@K | 1 | 2 | 4 | 8 |
|---|---|---|---|---|
| Divide-Conquer | 65.9 | 76.6 | 84.4 | 90.6 |
| MIC+Margin | 66.1 | 76.8 | 85.6 | - |
| TML | 62.5 | 73.9 | 83.0 | 89.4 |
| Ours CBML-const-ResNet50 | 69.2 | 79.3 | 86.3 | 91.6 |
| Ours CBML-sqrt-ResNet50 | 70.0 | 79.9 | 87.0 | 92.0 |
| Ours CBML-log-ResNet50 | 69.9 | 80.4 | 87.2 | 92.5 |

Performance compared with SOTA methods on CUB-200-2011 for 64-dimensional embeddings

  • ResNet50 Backbone
| Recall@K | 1 | 2 | 4 | 8 |
|---|---|---|---|---|
| N-Pair | 53.2 | 65.3 | 76.0 | 84.8 |
| ProxyNCA | 55.5 | 67.7 | 78.2 | 86.2 |
| EPSHN | 57.3 | 68.9 | 79.3 | 87.2 |
| MS | 57.4 | 69.8 | 80.0 | 87.8 |
| Ours CBML-const-ResNet50 | 65.0 | 76.2 | 84.9 | 90.6 |
| Ours CBML-sqrt-ResNet50 | 65.0 | 76.0 | 84.1 | 90.3 |
| Ours CBML-log-ResNet50 | 64.3 | 75.7 | 84.1 | 90.1 |
  • ResNet18 Backbone
| Recall@K | 1 | 2 | 4 | 8 |
|---|---|---|---|---|
| N-Pair | 52.4 | 65.7 | 76.8 | 84.6 |
| ProxyNCA | 51.5 | 63.8 | 74.6 | 84.0 |
| EPSHN | 54.2 | 66.6 | 77.4 | 86.0 |
| Ours CBML-const-ResNet18 | 58.0 | 69.6 | 80.0 | 87.5 |
| Ours CBML-sqrt-ResNet18 | 59.4 | 70.5 | 80.4 | 88.0 |
| Ours CBML-log-ResNet18 | 61.3 | 72.6 | 81.9 | 88.7 |
  • GoogleNet Backbone
| Recall@K | 1 | 2 | 4 | 8 |
|---|---|---|---|---|
| Triplet | 42.6 | 55.0 | 66.4 | 77.2 |
| N-Pair | 45.4 | 58.4 | 69.5 | 79.5 |
| ProxyNCA | 49.2 | 61.9 | 67.9 | 72.4 |
| EPSHN | 51.7 | 64.1 | 75.3 | 83.9 |
| Ours CBML-const-GoogleNet | 56.8 | 69.5 | 79.5 | 87.9 |
| Ours CBML-sqrt-GoogleNet | 57.7 | 69.7 | 80.5 | 88.3 |
| Ours CBML-log-GoogleNet | 59.3 | 70.7 | 80.6 | 88.1 |
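
For readers unfamiliar with the Recall@K metric reported above, the sketch below shows how it is commonly computed in image retrieval: a query counts as a hit at K if at least one of its K nearest neighbours (excluding itself) shares its class label. This is an illustrative, pure-Python version and not the evaluation code of this repository; the function name `recall_at_k` and the choice of Euclidean distance are assumptions, and the actual evaluation may use a different similarity measure.

```python
import math

def recall_at_k(embeddings, labels, ks=(1, 2, 4, 8)):
    """Fraction of queries with at least one same-class sample among
    their k nearest neighbours (the query itself is excluded)."""
    n = len(embeddings)
    hits = {k: 0 for k in ks}
    for i in range(n):
        # Rank every other sample by Euclidean distance to query i.
        ranked = sorted(
            (j for j in range(n) if j != i),
            key=lambda j: math.dist(embeddings[i], embeddings[j]),
        )
        for k in ks:
            if any(labels[j] == labels[i] for j in ranked[:k]):
                hits[k] += 1
    return {k: hits[k] / n for k in ks}

# Toy data: the last sample has label 1 but lies next to the label-0 cluster,
# so it is only retrieved correctly once K is large enough.
toy_embeddings = [(0.0, 0.0), (0.1, 0.0), (1.0, 1.0), (1.1, 1.0), (0.2, 0.1)]
toy_labels = [0, 0, 1, 1, 1]
scores = recall_at_k(toy_embeddings, toy_labels, ks=(1, 2, 4))
print(scores)
```

The tables report these fractions as percentages.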

Prepare the data and the pretrained model

The following script prepares the CUB dataset for training: it downloads the data to the ./resource/datasets/ folder and then builds the data lists (train.txt and test.txt):

./scripts/prepare_cub.sh

To reproduce the results of our paper, download the ImageNet-pretrained GoogleNet, BN-Inception and ResNet50 models and put them in the folder ~/.cache/torch/checkpoints/.
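
Before training, it can help to confirm that the pretrained weights are where the code expects them. The following is a small sketch that lists checkpoint files in that folder; the helper name `list_checkpoints` and the `.pth`/`.pt` suffixes are assumptions for illustration, since the exact file names are not specified here.

```python
from pathlib import Path

def list_checkpoints(directory="~/.cache/torch/checkpoints",
                     suffixes=(".pth", ".pt")):
    """Return the sorted names of checkpoint files found in `directory`.

    The suffixes are an assumption; adjust them to match the files
    you actually downloaded. A missing folder yields an empty list.
    """
    folder = Path(directory).expanduser()
    if not folder.is_dir():
        return []
    return sorted(
        p.name for p in folder.iterdir()
        if p.is_file() and p.suffix in suffixes
    )

found = list_checkpoints()
print(found if found else "No checkpoint files found yet.")
```

An empty result means the folder is missing or contains no files with the assumed suffixes.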

Installation

sudo pip3 install -r requirements.txt
sudo python3 setup.py develop build

Train and Test on CUB-200-2011 with CBML-Loss based on the BN-Inception backbone

./scripts/run_cub_bninception.sh

Trained models will be saved in the ./output-bninception-cub/ folder if using the default config.

Train and Test on CUB-200-2011 with CBML-Loss based on the ResNet50 backbone

./scripts/run_cub_resnet50.sh

Trained models will be saved in the ./output-resnet50-cub/ folder if using the default config.

Train and Test on CUB-200-2011 with CBML-Loss based on the GoogleNet backbone

./scripts/run_cub_googlenet.sh

Trained models will be saved in the ./output-googlenet-cub/ folder if using the default config.

Unsupervised training and test

Code will be released at a later date.

Citation

If you use this method or this code in your research, please cite as:

@article{Shichao-2022,
title={Contrastive Bayesian Analysis for Deep Metric Learning},
author={Shichao Kan and Zhiquan He and Yigang Cen and Yang Li and Mladenovic Vladimir and Zhihai He},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
pages={},
year={2022}
}

Acknowledgments

This code is built on the MS-Loss framework. We are grateful to the authors of the MS paper for releasing their code for academic research / non-commercial use. We also thank the authors of the following helpful implementations: histogram, proxynca, n-pair and angular, siamese-triplet, clustering.

License

This code is released for academic research / non-commercial use only. If you wish to use it for commercial purposes, please contact Shichao Kan by email at kanshichao10281078@126.com.

About

Contrastive Bayesian Analysis for Deep Metric Learning and an Integrated Deep Metric Learning Toolbox Based on Pytorch
