Long-Tailed Partial Label Learning via Dynamic Rebalancing


by Feng Hong, Jiangchao Yao, Zhihan Zhou, Ya Zhang, and Yanfeng Wang at SJTU and Shanghai AI Lab.

International Conference on Learning Representations (ICLR), 2023.

This repository is the official PyTorch implementation of RECORDS.

Citation

If you find our work inspiring or use our codebase in your research, please consider giving a star ⭐ and a citation.

@inproceedings{hong2023long,
  title={Long-Tailed Partial Label Learning via Dynamic Rebalancing},
  author={Hong, Feng and Yao, Jiangchao and Zhou, Zhihan and Zhang, Ya and Wang, Yanfeng},
  booktitle={{ICLR}},
  year={2023}
}

Overview

  • We delve into a more practical but under-explored LT-PLL scenario and identify several challenges in it that cannot be addressed, and may even cause failure, by straightforwardly combining current long-tailed learning and partial label learning methods.
  • We propose RECORDS, a novel method for LT-PLL that performs dynamic adjustment to rebalance the training without requiring any prior about the class distribution. Theoretical and empirical analyses show that the dynamic parametric class distribution asymmetrically approaches the oracle class distribution while being more friendly to label disambiguation.
  • Our method is orthogonal to existing PLL methods and can easily be plugged into them in an end-to-end manner.

Get Started

Environment

The project is tested under the following environment settings:

  • OS: Ubuntu 18.04.5
  • GPU: NVIDIA GeForce RTX 3090
  • Python: 3.7.10
  • PyTorch: 1.7.1
  • Torchvision: 0.8.2
  • Cudatoolkit: 11.0.221
  • Numpy: 1.21.2

File Structure

After the preparation work, the whole project should have the following structure:

./RECORDS-LTPLL
├── README.md
├── models              # models
│   ├── resnet.py
├── utils               # utils: datasets, losses, etc.
│   ├── cifar10.py
│   ├── cifar100.py
│   ├── imbalance_cifar.py
│   ├── randaugment.py
│   ├── utils_algo.py
│   ├── utils_loss.py
├── utils_solar         # utils for SoLar
│   ├── data.py
│   ├── resnet.py
│   ├── general.py
├── train.py            # train for CORR (+ RECORDS)
└── train_solar.py      # train for SoLar

Quick Preview

For a self-training PLL loss:

# loss function (per-batch forward); labels holds the candidate label sets
loss = caculate_loss(logits, labels, self.confidence[index, :])
if update_target:
    # label disambiguation: refresh the soft targets for this batch
    self.confidence[index, :] = update_confidence(logits, self.confidence[index, :])
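
For concreteness, the two helpers above might look roughly like the following. This is a minimal sketch of a PRODEN-style self-training objective, assuming self.confidence is initialized uniformly over each sample's candidate set (and is zero elsewhere); the repository's actual implementations live in ./utils/utils_loss.py and may differ.

import torch
import torch.nn.functional as F

def caculate_loss(logits, labels, confidence):
    # Soft cross-entropy against the current disambiguation targets.
    # `confidence` is already zero outside the candidate set, so `labels`
    # (the candidate-label mask) is not needed in this simplified version.
    log_prob = F.log_softmax(logits, dim=1)
    return -(confidence * log_prob).sum(dim=1).mean()

def update_confidence(logits, confidence):
    # Re-weight the candidate labels by the model's predictions and renormalize.
    with torch.no_grad():
        prob = F.softmax(logits, dim=1) * (confidence > 0).float()
        return prob / prob.sum(dim=1, keepdim=True).clamp_min(1e-9)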

We can easily add RECORDS to the loss function:

# loss function (per-batch forward)
loss = caculate_loss(logits, labels, self.confidence[index, :])
# momentum update of the running mean feature (RECORDS)
if self.feat_mean is None:
    self.feat_mean = 0.1 * feat.detach().mean(0)
else:
    self.feat_mean = 0.9 * self.feat_mean + 0.1 * feat.detach().mean(0)
if update_target:
    # debias the logits with the estimated class prior, then disambiguate
    bias = model.module.fc(self.feat_mean.unsqueeze(0)).detach()
    bias = F.softmax(bias, dim=1)
    logits_rebalanced = logits - torch.log(bias + 1e-9)
    self.confidence[index, :] = update_confidence(logits_rebalanced, self.confidence[index, :])
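
Packaged as a standalone helper, the RECORDS update and debiasing step looks roughly like this. This is a sketch only, assuming a model whose last linear layer is fc and whose backbone features feat are available each batch; the class name and the momentum argument are illustrative and not part of this repository.

import torch
import torch.nn.functional as F

class RecordsDebiaser:
    # Maintains an exponential moving average (EMA) of the mean feature and
    # uses the classifier head to estimate the dynamic class prior for debiasing.
    def __init__(self, momentum=0.9):
        self.momentum = momentum
        self.feat_mean = None  # running mean of batch-averaged features

    @torch.no_grad()
    def update(self, feat):
        batch_mean = feat.detach().mean(0)
        if self.feat_mean is None:
            self.feat_mean = (1 - self.momentum) * batch_mean
        else:
            self.feat_mean = self.momentum * self.feat_mean + (1 - self.momentum) * batch_mean

    @torch.no_grad()
    def debias(self, logits, fc):
        # Estimated class prior = softmax of the classifier applied to the EMA feature.
        bias = F.softmax(fc(self.feat_mean.unsqueeze(0)), dim=1)
        return logits - torch.log(bias + 1e-9)

Note that, as in the snippet above, the rebalanced logits are only used to refresh the disambiguation targets; the training loss itself is still computed on the raw logits.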

Data Preparation

CIFAR

For the CIFAR dataset, no additional data preparation is required. The first run will automatically download CIFAR to "./data".

PASCAL VOC 2007

Download the PLL version of PASCAL VOC 2007 and extract it to "./data/VOC2017/". [Download (Google Drive)]

Running

Run CORR [1] on CIFAR-10-LT with $q=0.3$ and imbalance ratio $\rho = 0.01$

CUDA_VISIBLE_DEVICES=0 python -u train.py --exp_dir experiment/CORR-CIFAR-10 --dataset cifar10_im --num_class 10 --dist_url 'tcp://localhost:10000' --multiprocessing_distributed --world_size 1 --rank 0 --seed 123 --arch resnet18 --upd_start 1 --lr 0.01 --wd 1e-3 --cosine --epochs 800 --print_freq 100 --partial_rate 0.3 --imb_factor 0.01

Run CORR + RECORDS on CIFAR-10-LT with $q=0.3$ and imbalance ratio $\rho = 0.01$

CUDA_VISIBLE_DEVICES=0 python -u train.py --exp_dir experiment/CORR-CIFAR-10 --dataset cifar10_im --num_class 10 --dist_url 'tcp://localhost:10001' --multiprocessing_distributed --world_size 1 --rank 0 --seed 123 --arch resnet18 --upd_start 1 --lr 0.01 --wd 1e-3 --cosine --epochs 800 --print_freq 100 --partial_rate 0.3 --imb_factor 0.01 --records

Note: --records applies RECORDS on top of the PLL baseline.

Run CORR + RECORDS on CIFAR-100-LT-NU with $q=0.03$ and imbalance ratio $\rho = 0.01$

CUDA_VISIBLE_DEVICES=0 python -u train.py --exp_dir experiment/CORR-CIFAR-100 --dataset cifar100_im --num_class 100 --dist_url 'tcp://localhost:10002' --multiprocessing_distributed --world_size 1 --rank 0 --seed 123 --arch resnet18 --upd_start 1 --lr 0.01 --wd 1e-3 --cosine --epochs 800 --print_freq 100 --partial_rate 0.03 --imb_factor 0.01 --records --hierarchical

Note: --hierarchical means using the non-uniform version of the dataset, i.e., CIFAR-100-LT-NU.

Run SoLar [2] (w/ Mixup) on CIFAR-10-LT with $q=0.3$ and imbalance ratio $\rho = 0.01$

CUDA_VISIBLE_DEVICES=0 python -u train_solar.py --exp_dir experiment/SoLar-CIFAR-100 --dataset cifar10 --num_class 10 --partial_rate 0.3 --imb_type exp --imb_ratio 100 --est_epochs 100 --rho_range 0.2,0.6 --gamma 0.1,0.01 --epochs 800 --lr 0.01 --wd 1e-3 --cosine --seed 123

Note: SoLar is a concurrent LT-PLL work published at NeurIPS 2022. It improves the label disambiguation process in LT-PLL through an optimal transport technique. In contrast, RECORDS addresses the LT-PLL problem from the perspective of rebalancing, in a lightweight and effective manner.

Note: on CIFAR-100-LT, change these parameters to: --est_epochs 20 --rho_range 0.2,0.5 --gamma 0.05,0.01.

Run CORR + RECORDS (w/ Mixup) on CIFAR-10-LT with $q=0.3$ and imbalance ratio $\rho = 0.01$

CUDA_VISIBLE_DEVICES=0 python -u train.py --exp_dir experiment/CORR-CIFAR-10 --dataset cifar10_im --num_class 10 --dist_url 'tcp://localhost:10003' --multiprocessing_distributed --world_size 1 --rank 0 --seed 123 --arch resnet18 --upd_start 1 --lr 0.01 --wd 1e-3 --cosine --epochs 800 --print_freq 100 --partial_rate 0.3 --imb_factor 0.01 --records --mixup

Note: --mixup means to use Mixup.

Results

CIFAR-10-LT

| Method | $\rho=50$, $q=0.3$ | $\rho=50$, $q=0.5$ | $\rho=50$, $q=0.7$ | $\rho=100$, $q=0.3$ | $\rho=100$, $q=0.5$ | $\rho=100$, $q=0.7$ |
| --- | --- | --- | --- | --- | --- | --- |
| CORR | 76.12 | 56.45 | 41.56 | 66.38 | 50.09 | 38.11 |
| CORR + Oracle-LA [3] | 36.27 | 17.61 | 12.77 | 29.97 | 15.80 | 11.75 |
| CORR + RECORDS | 82.57 | 80.28 | 67.24 | 77.66 | 72.90 | 57.46 |
| SoLar (w/ Mixup) | 83.88 | 76.55 | 54.61 | 75.38 | 70.63 | 53.15 |
| CORR + RECORDS (w/ Mixup) | 84.25 | 82.5 | 71.24 | 79.79 | 74.07 | 62.25 |

CIFAR-100-LT

| Method | $\rho=50$, $q=0.03$ | $\rho=50$, $q=0.05$ | $\rho=50$, $q=0.07$ | $\rho=100$, $q=0.03$ | $\rho=100$, $q=0.05$ | $\rho=100$, $q=0.07$ |
| --- | --- | --- | --- | --- | --- | --- |
| CORR | 42.29 | 38.03 | 36.59 | 38.39 | 34.09 | 31.05 |
| CORR + Oracle-LA | 22.56 | 5.59 | 3.12 | 11.37 | 3.32 | 1.98 |
| CORR + RECORDS | 48.06 | 45.56 | 42.51 | 42.25 | 40.59 | 38.65 |
| SoLar (w/ Mixup) | 47.93 | 46.85 | 45.1 | 42.51 | 41.71 | 39.15 |
| CORR + RECORDS (w/ Mixup) | 52.08 | 50.58 | 47.91 | 46.57 | 45.22 | 44.73 |

Extensions

To Implement Your Own Model

  • Add your model to "./models" and load it in train.py.
  • Implement the loss functions specific to your model (cf. ./utils/utils_loss.py) and use them in train.py; a sketch of a custom backbone follows this list.
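
As a rough illustration only, a custom backbone could look like the sketch below; the file name, class name, and interface here are hypothetical, and the interface actually expected by train.py is defined by the models in ./models (e.g., resnet.py).

# ./models/my_net.py -- hypothetical custom backbone (illustrative only)
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self, num_class=10, feat_dim=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(64, feat_dim),
        )
        self.fc = nn.Linear(feat_dim, num_class)  # linear head reused by RECORDS for debiasing

    def forward(self, x):
        feat = self.encoder(x)
        # Return features alongside logits so the EMA feature update can use them.
        return self.fc(feat), feat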

To Implement Other Datasets

  • Create the PLL version of the dataset and add it to "./data".
  • Implement the dataset loader (e.g., following ./utils/cifar10.py); a sketch of candidate-label generation is given below.
  • Load your data in train.py.
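
The key dataset-specific step is turning ground-truth labels into candidate label sets. Below is a minimal sketch of uniform candidate-label generation with ambiguity $q$; the function name is illustrative, and ./utils/cifar10.py contains the version actually used in this repository.

import numpy as np

def generate_uniform_candidate_labels(true_labels, num_class, q):
    # Each incorrect class enters the candidate set independently with probability q;
    # the ground-truth label is always a candidate.
    n = len(true_labels)
    candidates = (np.random.rand(n, num_class) < q).astype(np.float32)
    candidates[np.arange(n), true_labels] = 1.0
    return candidates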

Acknowledgements

We borrow some code from PiCO, LDAM-DRW, PRODEN, SADE, and SoLar.

References

[1] D.-D. Wu, D.-B. Wang, and M.-L. Zhang. Revisiting consistency regularization for deep partial label learning. ICML, 2022.

[2] H. Wang, M. Xia, Y. Li, et al. SoLar: Sinkhorn label refinery for imbalanced partial-label learning. NeurIPS, 2022.

[3] A. K. Menon, S. Jayasumana, A. S. Rawat, et al. Long-tail learning via logit adjustment. ICLR, 2021.

Contact

If you have any problems with this code, please feel free to contact feng.hong@sjtu.edu.cn.
