
CVPR 2024: AllSpark: Reborn Labeled Features from Unlabeled in Transformer for Semi-Supervised Semantic Segmentation


[CVPR-2024] AllSpark: Reborn Labeled Features from Unlabeled in Transformer for Semi-Supervised Semantic Segmentation


This repository is the official implementation of AllSpark: Reborn Labeled Features from Unlabeled in Transformer for Semi-Supervised Semantic Segmentation, accepted at CVPR 2024.

The AllSpark is a powerful Cybertronian artifact in the Transformers film series. It was used to revive Optimus Prime in Transformers: Revenge of the Fallen, which aligns well with our core idea.


💥 Motivation

In this work, we find that simply porting existing semi-supervised segmentation methods to a pure-transformer framework is ineffective, for two reasons:

  • First, transformers have a weaker inductive bias than CNNs, so they rely heavily on large amounts of training data to perform well.

  • More critically, existing semi-supervised segmentation frameworks separate the training flows for labeled and unlabeled data, which aggravates the overfitting of transformers on the limited labeled data.

We therefore propose to intervene in and diversify the labeled data flow with unlabeled data in the feature domain, which improves generalization.
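To make "diversifying the labeled flow with unlabeled data in the feature domain" concrete, here is a minimal, hypothetical sketch, not the repository's AllSpark module: each labeled feature vector acts as a query and is rebuilt as an attention-weighted mix of unlabeled feature vectors, using plain scaled dot-product cross-attention with no learned projections. The function name `reborn_labeled_features` and the toy inputs are illustrative assumptions; the actual module may apply attention along a different axis and include further components.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(scores: List[float]) -> List[float]:
    # Subtract the maximum before exponentiating for numerical stability.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def reborn_labeled_features(labeled: Matrix, unlabeled: Matrix) -> Matrix:
    """Hypothetical sketch: rebuild every labeled feature from unlabeled ones.

    labeled:   n_l vectors of dimension d, used as queries.
    unlabeled: n_u vectors of dimension d, used as both keys and values.
    Returns n_l vectors, each a convex combination of the unlabeled vectors.
    """
    d = len(labeled[0])
    scale = 1.0 / math.sqrt(d)
    rebuilt = []
    for query in labeled:
        # Scaled dot-product similarity between this query and every key.
        scores = [scale * sum(q * k for q, k in zip(query, key)) for key in unlabeled]
        weights = softmax(scores)
        # Attention-weighted sum of the unlabeled vectors, one channel at a time.
        rebuilt.append(
            [sum(w * value[j] for w, value in zip(weights, unlabeled)) for j in range(d)]
        )
    return rebuilt


if __name__ == "__main__":
    labeled = [[1.0, 0.0], [0.0, 1.0]]
    unlabeled = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
    for row in reborn_labeled_features(labeled, unlabeled):
        print([round(x, 3) for x in row])
```

Because every rebuilt vector is a convex combination of unlabeled features, the labeled branch can no longer fit only its own limited features, which is one way to read the intuition behind the generalization argument above.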


πŸ› οΈ Usage

‼️ IMPORTANT: This is not the final version of the code. Some mistakes were introduced while reorganizing it, and a corrected version will be released soon. We apologize for any inconvenience.

1. Environment

First, clone this repo:

git clone https://github.com/xmed-lab/AllSpark.git
cd AllSpark/

Then, create a new environment and install the requirements:

conda create -n allspark python=3.7
conda activate allspark
pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu116
pip install tensorboard
pip install six
pip install pyyaml
pip install -U openmim
mim install mmcv==1.6.2
pip install einops
pip install timm
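Before moving on, it can help to confirm that the pinned versions actually landed in the environment. The script below is a minimal sketch under assumptions: the file name is up to you, the pins are copied from the commands above, and the distribution name `mmcv` is assumed (mmcv 1.x was also published as `mmcv-full`). Because Python 3.7 does not ship `importlib.metadata` (it was added in 3.8), the script falls back to `pkg_resources` from setuptools.

```python
from typing import Callable, Dict, Optional

try:  # Python >= 3.8 ships importlib.metadata
    from importlib.metadata import PackageNotFoundError, version as _version
except ImportError:  # Python 3.7 (as in the conda env above): use setuptools
    from pkg_resources import DistributionNotFound as PackageNotFoundError
    from pkg_resources import get_distribution

    def _version(name: str) -> str:
        return get_distribution(name).version

# Pins copied from the install commands above; "mmcv" as the distribution
# name is an assumption.
PINS: Dict[str, str] = {
    "torch": "1.12.0+cu116",
    "torchvision": "0.13.0+cu116",
    "torchaudio": "0.12.0",
    "mmcv": "1.6.2",
}


def installed_version(name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return _version(name)
    except PackageNotFoundError:
        return None


def version_matches(pinned: str, installed: Optional[str]) -> bool:
    if installed is None:
        return False
    if "+" in pinned:
        # The pin carries a local label such as "+cu116": require an exact
        # match so a CPU-only build is not mistaken for the CUDA 11.6 build.
        return installed == pinned
    # Otherwise ignore any local label on the installed version.
    return installed.split("+", 1)[0] == pinned


def find_mismatches(
    pins: Dict[str, str], lookup: Callable[[str], Optional[str]]
) -> Dict[str, Optional[str]]:
    """Map each mismatching package to its installed version (None if missing)."""
    mismatches: Dict[str, Optional[str]] = {}
    for name, pinned in pins.items():
        found = lookup(name)
        if not version_matches(pinned, found):
            mismatches[name] = found
    return mismatches


if __name__ == "__main__":
    bad = find_mismatches(PINS, installed_version)
    for name, found in bad.items():
        print(f"{name}: expected {PINS[name]}, found {found}")
    print("environment OK" if not bad else f"{len(bad)} mismatch(es)")
```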

2. Data Preparation & Pre-trained Weights

2.1 Pascal VOC 2012 Dataset

Download the dataset with wget:

wget https://hkustconnect-my.sharepoint.com/:u:/g/personal/hwanggr_connect_ust_hk/EcgD_nffqThPvSVXQz6-8T0B3K9BeUiJLkY_J-NvGscBVA\?e\=2b0MdI\&download\=1 -O pascal.zip
unzip pascal.zip

2.2 Cityscapes Dataset

Download the dataset with wget:

wget https://hkustconnect-my.sharepoint.com/:u:/g/personal/hwanggr_connect_ust_hk/EWoa_9YSu6RHlDpRw_eZiPUBjcY0ZU6ZpRCEG0Xp03WFxg\?e\=LtHLyB\&download\=1 -O cityscapes.zip
unzip cityscapes.zip

2.3 COCO Dataset

Download the dataset with wget:

wget https://hkustconnect-my.sharepoint.com/:u:/g/personal/hwanggr_connect_ust_hk/EXCErskA_WFLgGTqOMgHcAABiwH_ncy7IBg7jMYn963BpA\?e\=SQTCWg\&download\=1 -O coco.zip
unzip coco.zip

The resulting file structure should look like this:

├── VOC2012
    ├── JPEGImages
    └── SegmentationClass
    
├── cityscapes
    ├── leftImg8bit
    └── gtFine
    
├── coco
    ├── train2017
    ├── val2017
    └── masks
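A quick way to verify this layout is to check for each expected sub-directory. The helper below is hypothetical and not part of the repository; it assumes all three datasets sit under one common root passed on the command line, whereas the real paths are whatever your training configuration expects.

```python
import os
from typing import Dict, List

# Expected sub-directories, transcribed from the tree above.
EXPECTED_LAYOUT: Dict[str, List[str]] = {
    "VOC2012": ["JPEGImages", "SegmentationClass"],
    "cityscapes": ["leftImg8bit", "gtFine"],
    "coco": ["train2017", "val2017", "masks"],
}


def missing_dirs(root: str, layout: Dict[str, List[str]] = EXPECTED_LAYOUT) -> List[str]:
    """Return every expected dataset sub-directory that does not exist under root."""
    missing = []
    for dataset, subdirs in layout.items():
        for sub in subdirs:
            path = os.path.join(root, dataset, sub)
            if not os.path.isdir(path):
                missing.append(path)
    return missing


if __name__ == "__main__":
    import sys

    # Hypothetical usage: python check_layout.py /path/to/data_root
    data_root = sys.argv[1] if len(sys.argv) > 1 else "."
    gaps = missing_dirs(data_root)
    for path in gaps:
        print("missing:", path)
    print("layout OK" if not gaps else f"{len(gaps)} directories missing")
```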

Next, download the following pre-trained weights (the MiT backbones from SegFormer):

├── ./pretrained_weights
    ├── mit_b2.pth
    ├── mit_b3.pth
    ├── mit_b4.pth
    └── mit_b5.pth

For example, to download MiT-B5:

mkdir -p pretrained_weights
wget https://hkustconnect-my.sharepoint.com/:u:/g/personal/hwanggr_connect_ust_hk/ET0iubvDmcBGnE43-nPQopMBw9oVLsrynjISyFeGwqXQpw\?e\=9wXgso\&download\=1 -O ./pretrained_weights/mit_b5.pth
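The helper below is a hypothetical sanity check for the downloaded checkpoints. It only looks for the four file names listed above in `./pretrained_weights` and flags files that are missing, empty, or that begin like an HTML page, which is one possible symptom of a share link that did not return the file itself. It does not validate the checkpoint contents; loading it with `torch.load(path, map_location="cpu")` is the more thorough test. If you only train with MiT-B5, the other three files will be reported as missing, which is expected.

```python
import os
from typing import List

# File names copied from the tree above.
WEIGHT_NAMES = ["mit_b2.pth", "mit_b3.pth", "mit_b4.pth", "mit_b5.pth"]


def looks_like_html(path: str) -> bool:
    """True if the file starts (after whitespace) with an HTML doctype or tag."""
    with open(path, "rb") as f:
        head = f.read(512).lstrip().lower()
    return head.startswith(b"<!doctype html") or head.startswith(b"<html")


def check_weights(weight_dir: str = "./pretrained_weights",
                  names: List[str] = WEIGHT_NAMES) -> List[str]:
    """Return one human-readable issue string per problematic checkpoint file."""
    issues = []
    for name in names:
        path = os.path.join(weight_dir, name)
        if not os.path.isfile(path):
            issues.append(f"{name}: missing")
        elif os.path.getsize(path) == 0:
            issues.append(f"{name}: empty file")
        elif looks_like_html(path):
            issues.append(f"{name}: looks like an HTML page, not a checkpoint")
    return issues


if __name__ == "__main__":
    problems = check_weights()
    for line in problems:
        print(line)
    print("weights OK" if not problems else f"{len(problems)} issue(s)")
```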

3. Training & Evaluation

# use torch.distributed.launch
sh scripts/train.sh <num_gpu> <port>
# to fully reproduce our results, set <num_gpu> to 4 on all three datasets
# otherwise, adjust the learning rate accordingly

# or use slurm
# sh scripts/slurm_train.sh <num_gpu> <port> <partition>

To train on other datasets or splits, modify the dataset and split settings in scripts/train.sh.
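The training notes leave the learning-rate adjustment for other GPU counts open. One common heuristic, assumed here rather than taken from the authors, is linear scaling: keep the ratio of learning rate to total batch size (GPUs × per-GPU batch) constant relative to the 4-GPU reference. The function name and arguments are hypothetical; read the base learning rate and per-GPU batch size from the config you are using, and expect that other schedule settings may also need tuning.

```python
from typing import Optional


def scaled_lr(base_lr: float, num_gpus: int, per_gpu_batch: int,
              ref_num_gpus: int = 4, ref_per_gpu_batch: Optional[int] = None) -> float:
    """Linear-scaling heuristic: keep learning rate / total batch size constant.

    base_lr is the learning rate tuned for ref_num_gpus GPUs (4 per this README).
    If ref_per_gpu_batch is omitted, the per-GPU batch size is assumed unchanged.
    """
    if ref_per_gpu_batch is None:
        ref_per_gpu_batch = per_gpu_batch
    if min(num_gpus, per_gpu_batch, ref_num_gpus, ref_per_gpu_batch) <= 0:
        raise ValueError("GPU counts and batch sizes must be positive")
    ref_total = ref_num_gpus * ref_per_gpu_batch
    new_total = num_gpus * per_gpu_batch
    return base_lr * new_total / ref_total
```

For example, halving the GPU count from 4 to 2 at a fixed per-GPU batch size halves the learning rate.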

4. Results

Model weights and training logs will be released soon. All results below are mIoU (%).

4.1 PASCAL VOC 2012 original

Splits                1/16          1/8     1/4           1/2           Full
Weights of AllSpark   76.07         78.41   79.77         80.75         82.12
Reproduced            76.06 (log)   78.41   79.93 (log)   80.70 (log)   82.56 (log)
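The split columns give the fraction of the training set used as labeled data, with the rest treated as unlabeled. As a worked example, assuming the 1,464-image VOC 2012 training set for the "original" setting and assuming fractions are rounded up, 1/16 corresponds to ⌈1464 / 16⌉ = ⌈91.5⌉ = 92 labeled images. The helper below is a hypothetical sketch of that arithmetic; the split files used by the training scripts define the actual subsets.

```python
import math
from fractions import Fraction


def labeled_count(num_train: int, split: str) -> int:
    """Labeled images implied by a split label such as "1/16" or "Full".

    Rounding up is an assumption; the repository's split files are authoritative.
    """
    if split.strip().lower() == "full":
        return num_train
    # Fraction keeps the arithmetic exact, so 1464 * 1/16 is exactly 91.5.
    return math.ceil(num_train * Fraction(split.strip()))


if __name__ == "__main__":
    for split in ["1/16", "1/8", "1/4", "1/2", "Full"]:
        print(split, labeled_count(1464, split))
```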

4.2 PASCAL VOC 2012 augmented

Splits 1/16 1/8 1/4 1/2
Weights of AllSpark 78.32 79.98 80.42 81.14

4.3 Cityscapes

Splits 1/16 1/8 1/4 1/2
Weights of AllSpark 78.33 79.24 80.56 81.39

4.4 COCO

Splits                1/512         1/256         1/128         1/64
Weights of AllSpark   34.10 (log)   41.65 (log)   45.48 (log)   49.56 (log)

Citation

If you find this project useful, please consider citing:

@inproceedings{allspark,
  title={AllSpark: Reborn Labeled Features from Unlabeled in Transformer for Semi-Supervised Semantic Segmentation},
  author={Wang, Haonan and Zhang, Qixiang and Li, Yi and Li, Xiaomeng},
  booktitle={CVPR},
  year={2024}
}

Acknowledgement

AllSpark is built upon UniMatch and SegFormer. We thank their authors for making the source code publicly available.