NJUyued/PRG4SSL-MNAR


PRG4SSL-MNAR

This repo is the official PyTorch implementation of our paper:

Towards Semi-supervised Learning with Non-random Missing Labels
Authors: Yue Duan, Zhen Zhao, Lei Qi, Lei Wang, Luping Zhou and Yinghuan Shi

  • Quick links: [arXiv | Published paper | Poster | Zhihu | Code download]

  • Latest news:

    • We wrote a detailed introduction to this work on Zhihu.
    • Our paper has been accepted by the International Conference on Computer Vision (ICCV) 2023 🎉🎉. Thanks to all users.
  • Related works:

    • 📍 [MOST RELEVANT] Interested in robust SSL in the MNAR setting with mismatched distributions? 👉 Check out our ECCV'22 paper RDA [arXiv | Repo].
    • 🆕 [LATEST] Interested in SSL for fine-grained visual classification (SS-FGVC)? 👉 Check out our AAAI'24 paper SoC [arXiv | Repo].
    • Interested in conventional SSL or further applications of complementary labels in SSL? 👉 Check out our TNNLS paper MutexMatch [arXiv | Repo].

Introduction

Semi-supervised learning (SSL) tackles the missing-label problem by making effective use of unlabeled data. While existing SSL methods focus on the traditional setting, a practical and challenging scenario called label Missing Not At Random (MNAR) is usually ignored. In MNAR, the labeled and unlabeled data follow different class distributions, which biases label imputation and degrades the performance of SSL models. In this work, we devise class transition tracking based Pseudo-Rectifying Guidance (PRG) for MNAR. We explore class-level guidance information obtained by a Markov random walk, which is modeled on a dynamically created graph built over the class tracking matrix. PRG unifies the history of each class transition caused by the pseudo-rectifying procedure to activate the model's enthusiasm for neglected classes, so that the quality of pseudo-labels on both popular and rare classes in MNAR is improved.
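To make the idea of a class tracking matrix and a Markov random walk more concrete, here is a minimal, dependency-free sketch. It is not the repository's implementation: the exact way PRG builds the graph, uses --alpha, and rectifies pseudo-labels is defined in the paper. In this sketch, `ClassTransitionTracker` and `random_walk_step` are hypothetical names, `nb` mirrors the --nb flag (number of tracked batches), `alpha` is simply assumed to add self-loop weight to each class, and "rectification" is assumed to be a single random-walk step p' = pH.

```python
from collections import deque


class ClassTransitionTracker:
    """Hypothetical sketch: count pseudo-label class transitions over the last `nb` batches."""

    def __init__(self, num_classes, nb):
        self.num_classes = num_classes
        # One k x k count matrix per tracked batch; the oldest is dropped automatically.
        self.history = deque(maxlen=nb)

    def update(self, prev_labels, new_labels):
        """Record, for each sample, the move from its previous to its new pseudo-label."""
        k = self.num_classes
        counts = [[0] * k for _ in range(k)]
        for i, j in zip(prev_labels, new_labels):
            counts[i][j] += 1
        self.history.append(counts)

    def transition_matrix(self, alpha=1.0):
        """Row-normalize the summed counts into a Markov transition matrix H."""
        k = self.num_classes
        H = []
        for i in range(k):
            row = [float(sum(c[i][j] for c in self.history)) for j in range(k)]
            row[i] += alpha  # assumed role of alpha: extra self-loop mass per class
            total = sum(row)
            # Fall back to a pure self-loop if class i has no mass at all (e.g. alpha == 0).
            H.append([v / total for v in row] if total > 0 else [float(j == i) for j in range(k)])
        return H


def random_walk_step(probs, H):
    """Propagate a class-probability vector one step along H (p' = p H), then renormalize."""
    k = len(probs)
    out = [sum(probs[i] * H[i][j] for i in range(k)) for j in range(k)]
    total = sum(out)
    return [v / total for v in out]


if __name__ == "__main__":
    tracker = ClassTransitionTracker(num_classes=3, nb=2)
    # Sample 0 moved from class 0 to 1, sample 1 stayed in 0, sample 2 stayed in 1.
    tracker.update(prev_labels=[0, 0, 1], new_labels=[1, 0, 1])
    H = tracker.transition_matrix(alpha=1.0)
    print(random_walk_step([1.0, 0.0, 0.0], H))  # approximately [0.667, 0.333, 0.0]
```

Using a bounded `deque` keeps only the most recent `nb` batches, so old transitions stop influencing the graph without any explicit decay term.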

Requirements

  • numpy==1.21.6
  • pandas==1.3.2
  • Pillow==10.0.0
  • scikit_learn==1.3.0
  • torch==1.8.0
  • torchvision==0.9.0

How to Train

Important Args

  • --last: Set this flag to use the model of $\textrm{PRG}^{\textrm{Last}}$.
  • --alpha: Class invariance coefficient. Defaults to --alpha 1; when --last is set, please use --alpha 3.
  • --nb: Number of tracked batches.
  • --mismatch [none/prg/cadr/darp/darp_reversed] : Select the MNAR protocol. none means the conventional balanced setting. See Sec. 4 in our paper for the details of MNAR protocols.
  • --n0 : With --mismatch prg, this is the imbalance ratio $N_0$ for labeled data; with --mismatch [darp/darp_reversed], it is the imbalance ratio $\gamma_l$ for labeled data.
  • --gamma : With --mismatch cadr, this is the imbalance ratio $\gamma$ for labeled data; with --mismatch prg, it is the imbalance ratio $\gamma$ for unlabeled data; with --mismatch [darp/darp_reversed], it is the imbalance ratio $\gamma_u$ for unlabeled data.
  • --num_labels : Number of labeled samples used in the conventional balanced setting.
  • --net : By default, Wide ResNet (WRN-28-2) is used for experiments. To train with another backbone, set --net [resnet18/preresnet/cnn13], corresponding to ResNet-18, PreAct ResNet and CNN-13, respectively.
  • --dataset [cifar10/cifar100/miniimage] and --data_dir : Your dataset name and path.
  • --num_eval_iter : Number of iterations between evaluations. Although the evaluation also reports the accuracy of pseudo-labels on unlabeled data, this is only for monitoring the training process; no label information of the unlabeled data is used during training.
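Because --n0 and --gamma change meaning depending on --mismatch, a small lookup can make the mapping explicit. This sketch only transcribes the argument list above; `RATIO_MEANING` and `describe_ratios` are hypothetical names, not part of train_prg.py, and rejecting a flag the list does not describe for a given protocol is a design choice of this sketch, not necessarily how the script behaves.

```python
# Meaning of --n0 / --gamma under each --mismatch value, transcribed from the list above.
RATIO_MEANING = {
    "none": {},
    "prg": {"n0": "N_0 (labeled)", "gamma": "gamma (unlabeled)"},
    "cadr": {"gamma": "gamma (labeled)"},
    "darp": {"n0": "gamma_l (labeled)", "gamma": "gamma_u (unlabeled)"},
    "darp_reversed": {"n0": "gamma_l (labeled)", "gamma": "gamma_u (unlabeled)"},
}


def describe_ratios(mismatch, **flags):
    """Map the given --n0/--gamma values to the ratio they control under `mismatch`."""
    meanings = RATIO_MEANING[mismatch]
    unexpected = sorted(name for name in flags if name not in meanings)
    if unexpected:
        raise ValueError(f"--mismatch {mismatch} does not document: {unexpected}")
    return {meanings[name]: value for name, value in flags.items()}


if __name__ == "__main__":
    # DARP's protocol with gamma_l = 100 and gamma_u = 1, as in the examples.
    print(describe_ratios("darp", n0=100, gamma=1))
```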

Training with Single GPU

We recommend training on a single GPU to best reproduce our results. Multi-GPU training is supported, but all our reported results were obtained with single-GPU training.

python train_prg.py --world-size 1 --rank 0 --gpu [0/1/...] @@@other args@@@

Training with Multi-GPUs

  • Using DataParallel
python train_prg.py --world-size 1 --rank 0 @@@other args@@@
  • Using DistributedDataParallel with single node
python train_prg.py --world-size 1 --rank 0 --multiprocessing-distributed @@@other args@@@

Examples of Running

By default, the model and dist&index.txt are saved in \--save_dir\--save_name. The file dist&index.txt records the detailed MNAR settings. The code treats the whole run as a single epoch of 2**20 iterations. For CIFAR-100, set --widen_factor 8 to use WRN-28-8 (WRN-28-2 is used for CIFAR-10); for mini-ImageNet, set --net resnet18.
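Since the example commands differ only in a few dataset-specific flags, a small helper can assemble them for sweeps. This is a hypothetical convenience script, not part of the repository: `build_command`, `BASE_ARGS`, and `DATASET_ARGS` are illustrative names, and it only uses flags that appear in the example commands in this README. It builds the argument list without running anything.

```python
import shlex

# Flags shared by every example command in this README.
BASE_ARGS = {
    "--world-size": "1",
    "--rank": "0",
    "--lr_decay": "cos",
    "--seed": "1",
    "--num_eval_iter": "1000",
}

# Backbone settings per dataset: WRN-28-2 (default) for CIFAR-10,
# WRN-28-8 for CIFAR-100, and ResNet-18 for mini-ImageNet.
DATASET_ARGS = {
    "cifar10": {"--num_classes": "10"},
    "cifar100": {"--num_classes": "100", "--widen_factor": "8"},
    "miniimage": {"--num_classes": "100", "--net": "resnet18"},
}

TOTAL_ITERS = 2 ** 20  # training length stated above (one "epoch")


def build_command(dataset, gpu=0, **extra):
    """Return the train_prg.py argv for `dataset`; extra kwargs become --key value pairs."""
    args = dict(BASE_ARGS)
    args["--save_name"] = dataset
    args["--dataset"] = dataset
    args.update(DATASET_ARGS[dataset])
    for key, value in extra.items():
        args["--" + key] = str(value)
    args["--gpu"] = str(gpu)
    cmd = ["python", "train_prg.py", "--overwrite"]
    for key, value in args.items():
        cmd += [key, value]
    return cmd


if __name__ == "__main__":
    # CADR's protocol on CIFAR-100 with gamma = 50 (compare with the example below).
    print(shlex.join(build_command("cifar100", num_labels=400, mismatch="cadr", gamma=50)))
```

With --num_eval_iter 1000, the 2**20 = 1,048,576 training iterations contain 1048 full evaluation intervals.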

MNAR Settings

CADR's protocol in Tab. 1

  • CIFAR-10 with $\gamma=20$
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch cadr --gamma 20 --gpu 0
  • CIFAR-100 with $\gamma=50$
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar100 --dataset cifar100 --num_classes 100 --num_labels 400 --mismatch cadr --gamma 50 --gpu 0 --widen_factor 8
  • mini-ImageNet with $\gamma=50$
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name miniimage --dataset miniimage --num_classes 100 --num_labels 1000 --mismatch cadr --gamma 50 --gpu 0 --net resnet18 

Our protocol in Tab. 2

  • CIFAR-10 with 40 labels and $N_0=10$
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch prg --n0 10 --gpu 0
  • CIFAR-100 with 400 labels and $N_0=40$
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar100 --dataset cifar100 --num_classes 100 --num_labels 400 --mismatch prg --n0 40 --gpu 0 --widen_factor 8
  • mini-ImageNet with 1000 labels and $N_0=40$
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name miniimage --dataset miniimage --num_classes 100 --num_labels 1000 --mismatch prg --n0 40 --gpu 0 --net resnet18 

Our protocol in Fig. 6(a)

  • CIFAR-10 with 40 labels, $N_0=10$ and $\gamma=5$
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch prg --n0 10 --gamma 5 --gpu 0

Our protocol in Tab. 10

  • CIFAR-10 with 40 labels and $\gamma=20$
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch prg --gamma 20 --gpu 0

DARP's protocol in Fig. 6(a)

  • CIFAR-10 with $\gamma_l=100$ and $\gamma_u=1$
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --mismatch darp --n0 100 --gamma 1 --gpu 0
  • CIFAR-10 with $\gamma_l=100$ and $\gamma_u=100$ (reversed)
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --mismatch darp_reversed --n0 100 --gamma 100 --gpu 0
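Under DARP's protocol, $\gamma_l$ and $\gamma_u$ describe long-tailed class distributions of the labeled and unlabeled data. A common way to turn such a ratio into per-class counts in DARP-style long-tailed benchmarks is the exponential profile $N_k = N_{\max} \cdot \gamma^{-k/(K-1)}$ for $k = 0, \dots, K-1$. The sketch below assumes that profile and assumes "reversed" means flipping the class order; the head-class sizes 1500 and 3000 are hypothetical, and the repository's own sampling code may differ.

```python
def long_tailed_counts(n_max, gamma, num_classes, reverse=False):
    """Per-class counts N_k = n_max * gamma ** (-k / (K - 1)), rounded to integers."""
    if num_classes < 2:
        return [n_max] * num_classes
    counts = [round(n_max * gamma ** (-k / (num_classes - 1))) for k in range(num_classes)]
    # Assumed "reversed" variant: flip the class order so the tail classes become the head.
    return counts[::-1] if reverse else counts


if __name__ == "__main__":
    labeled = long_tailed_counts(1500, gamma=100, num_classes=10)  # gamma_l = 100
    unlabeled = long_tailed_counts(3000, gamma=1, num_classes=10)  # gamma_u = 1 (balanced)
    unlabeled_rev = long_tailed_counts(3000, gamma=100, num_classes=10, reverse=True)
    print(labeled, unlabeled, unlabeled_rev, sep="\n")
```

The largest class always keeps `n_max` samples and the smallest has `n_max / gamma`, so the ratio between them equals the imbalance ratio.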

Conventional Setting

Matched and balanced distribution in Tab. 11

  • CIFAR-10 with 40 labels
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40  --gpu 0

Resume Training and Evaluation

To resume training, add --resume --load_path @your_weight_path.

For evaluation, run

python eval_prg.py --load_path @your_weight_path --dataset [cifar10/cifar100/miniimage] --data_dir @your_dataset_path --num_classes @number_of_classes

By default, the WideResNet-28-2 backbone is used for CIFAR-10. Use --widen_factor 8 (i.e., WideResNet-28-8) for CIFAR-100 and --net resnet18 for mini-ImageNet.

Results (e.g., seed=1)

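| Dataset | Labels | $N_0$ / $\gamma_l$ | $\gamma$ / $\gamma_u$ | Acc (%) | Setting | Method | Weight |
|---|---|---|---|---|---|---|---|
| CIFAR-10 | 40 | - | - | 94.05 | Conventional setting | PRG | here |
| CIFAR-10 | 250 | - | - | 94.36 | Conventional setting | PRG | here |
| CIFAR-10 | 4000 | - | - | 95.48 | Conventional setting | PRG | here |
| CIFAR-10 | 40 | - | - | 93.79 | Conventional setting | PRG^Last | here |
| CIFAR-10 | 250 | - | - | 94.76 | Conventional setting | PRG^Last | here |
| CIFAR-10 | 4000 | - | - | 95.75 | Conventional setting | PRG^Last | here |
| CIFAR-10 | - | - | 20 | 94.04 | CADR's protocol | PRG | here |
| CIFAR-10 | - | - | 50 | 93.78 | CADR's protocol | PRG | here |
| CIFAR-10 | - | - | 100 | 94.51 | CADR's protocol | PRG | here |
| CIFAR-10 | - | - | 20 | 94.74 | CADR's protocol | PRG^Last | here |
| CIFAR-10 | - | - | 50 | 94.74 | CADR's protocol | PRG^Last | here |
| CIFAR-10 | - | - | 100 | 94.75 | CADR's protocol | PRG^Last | here |
| CIFAR-10 | 40 | 10 | - | 93.81 | Our protocol | PRG | here |
| CIFAR-10 | 40 | 20 | - | 93.39 | Our protocol | PRG | here |
| CIFAR-10 | 40 | 10 | 2 | 90.25 | Our protocol | PRG | here |
| CIFAR-10 | 40 | 10 | 5 | 82.84 | Our protocol | PRG | here |
| CIFAR-10 | 100 | 40 | 5 | 79.58 | Our protocol | PRG | here |
| CIFAR-10 | 100 | 40 | 10 | 78.61 | Our protocol | PRG | here |
| CIFAR-10 | 250 | 100 | - | 93.76 | Our protocol | PRG | here |
| CIFAR-10 | 250 | 200 | - | 91.65 | Our protocol | PRG | here |
| CIFAR-10 | 40 | 10 | - | 91.59 | Our protocol | PRG^Last | here |
| CIFAR-10 | 40 | 20 | - | 80.31 | Our protocol | PRG^Last | here |
| CIFAR-10 | 250 | 100 | - | 91.36 | Our protocol | PRG^Last | here |
| CIFAR-10 | 250 | 200 | - | 62.16 | Our protocol | PRG^Last | here |
| CIFAR-10 | - | 100 | 1 | 94.41 | DARP's protocol | PRG | here |
| CIFAR-10 | - | 100 | 50 | 78.28 | DARP's protocol | PRG | here |
| CIFAR-10 | - | 100 | 150 | 75.21 | DARP's protocol | PRG | here |
| CIFAR-10 | - | 100 | 100 | 80.86 | DARP's protocol (reversed) | PRG | here |
| CIFAR-100 | 400 | - | - | 48.70 | Conventional setting | PRG | here |
| CIFAR-100 | 2500 | - | - | 69.81 | Conventional setting | PRG | here |
| CIFAR-100 | 10000 | - | - | 76.91 | Conventional setting | PRG | here |
| CIFAR-100 | 400 | - | - | 48.66 | Conventional setting | PRG^Last | here |
| CIFAR-100 | 2500 | - | - | 70.03 | Conventional setting | PRG^Last | here |
| CIFAR-100 | 10000 | - | - | 76.93 | Conventional setting | PRG^Last | here |
| CIFAR-100 | - | - | 50 | 58.57 | CADR's protocol | PRG | here |
| CIFAR-100 | - | - | 100 | 62.28 | CADR's protocol | PRG | here |
| CIFAR-100 | - | - | 200 | 59.33 | CADR's protocol | PRG | here |
| CIFAR-100 | - | - | 50 | 60.32 | CADR's protocol | PRG^Last | here |
| CIFAR-100 | - | - | 100 | 62.13 | CADR's protocol | PRG^Last | here |
| CIFAR-100 | - | - | 200 | 58.70 | CADR's protocol | PRG^Last | here |
| CIFAR-100 | 2500 | 100 | - | 57.56 | Our protocol | PRG | here |
| CIFAR-100 | 2500 | 200 | - | 51.21 | Our protocol | PRG | here |
| CIFAR-100 | 2500 | 100 | - | 59.40 | Our protocol | PRG^Last | here |
| CIFAR-100 | 2500 | 200 | - | 42.09 | Our protocol | PRG^Last | here |
| mini-ImageNet | 1000 | - | - | 45.74 | Conventional setting | PRG | here |
| mini-ImageNet | 1000 | - | - | 48.63 | Conventional setting | PRG^Last | here |
| mini-ImageNet | - | - | 50 | 43.74 | CADR's protocol | PRG | here |
| mini-ImageNet | - | - | 100 | 43.74 | CADR's protocol | PRG | here |
| mini-ImageNet | - | - | 50 | 42.22 | CADR's protocol | PRG^Last | here |
| mini-ImageNet | - | - | 100 | 43.74 | CADR's protocol | PRG^Last | here |
| mini-ImageNet | 1000 | 40 | - | 40.75 | Our protocol | PRG | here |
| mini-ImageNet | 1000 | 80 | - | 35.86 | Our protocol | PRG | here |
| mini-ImageNet | 1000 | 40 | - | 39.79 | Our protocol | PRG^Last | here |
| mini-ImageNet | 1000 | 80 | - | 32.64 | Our protocol | PRG^Last | here |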

Citation

Please cite our paper if you find PRG useful:

@inproceedings{duan2023towards,
  title={Towards Semi-supervised Learning with Non-random Missing Labels},
  author={Duan, Yue and Zhao, Zhen and Qi, Lei and Zhou, Luping and Wang, Lei and Shi, Yinghuan},
  booktitle={IEEE/CVF International Conference on Computer Vision},
  year={2023}
}

or

@article{duan2023towards,
  title={Towards Semi-supervised Learning with Non-random Missing Labels},
  author={Duan, Yue and Zhao, Zhen and Qi, Lei and Zhou, Luping and Wang, Lei and Shi, Yinghuan},
  journal={arXiv preprint arXiv:2308.08872},
  year={2023}
}
