
KentoNishi/Augmentation-for-LNL


Augmentation-for-LNL


Code for Augmentation Strategies for Learning with Noisy Labels (CVPR 2021).

Authors: Kento Nishi*, Yi Ding*, Alex Rich, Tobias Höllerer [*: equal contribution]

Abstract

Imperfect labels are ubiquitous in real-world datasets. Several recent successful methods for training deep neural networks (DNNs) robust to label noise have used two primary techniques: filtering samples based on loss during a warm-up phase to curate an initial set of cleanly labeled samples, and using the output of a network as a pseudo-label for subsequent loss calculations. In this paper, we evaluate different augmentation strategies for algorithms tackling the "learning with noisy labels" problem. We propose and examine multiple augmentation strategies and evaluate them using synthetic datasets based on CIFAR-10 and CIFAR-100, as well as on the real-world dataset Clothing1M. Due to several commonalities in these algorithms, we find that using one set of augmentations for loss modeling tasks and another set for learning is the most effective, improving results on the state-of-the-art and other previous methods. Furthermore, we find that applying augmentation during the warm-up period can negatively impact the loss convergence behavior of correctly versus incorrectly labeled samples. We introduce this augmentation strategy to the state-of-the-art technique and demonstrate that we can improve performance across all evaluated noise levels. In particular, we improve accuracy on the CIFAR-10 benchmark at 90% symmetric noise by more than 15% in absolute accuracy, and we also improve performance on the real-world dataset Clothing1M.
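The loss-modeling step mentioned in the abstract can be illustrated on its own. DivideMix, which this repository forks, fits a two-component Gaussian mixture to the per-sample training losses. It then treats each sample's posterior under the lower-mean component as the probability that its label is clean.

The sketch below is a minimal, dependency-free EM version of that idea, not the repository's implementation. The function names, the min-max normalization, the initialization, the variance floor, and the iteration count are all assumptions. Under the "one set of augmentations for loss modeling, another for learning" strategy, the losses passed to this step and the inputs used for gradient updates would come from differently augmented views.

```python
import math


def _responsibilities(x, weights, means, variances):
    """E-step: posterior probability of each mixture component for every sample."""
    resp = []
    for xi in x:
        dens = [
            w * math.exp(-((xi - m) ** 2) / (2 * v)) / math.sqrt(2 * math.pi * v)
            for w, m, v in zip(weights, means, variances)
        ]
        total = sum(dens) or 1e-300  # guard against all-zero densities
        resp.append([d / total for d in dens])
    return resp


def clean_probabilities(losses, iters=20, var_floor=1e-3):
    """Hypothetical helper: fit a two-component 1-D Gaussian mixture to per-sample
    losses with EM and return each sample's posterior under the lower-mean
    ("clean") component."""
    lo, hi = min(losses), max(losses)
    span = (hi - lo) or 1.0
    x = [(loss - lo) / span for loss in losses]  # min-max normalise to [0, 1]
    weights, means, variances = [0.5, 0.5], [0.0, 1.0], [0.1, 0.1]
    for _ in range(iters):
        resp = _responsibilities(x, weights, means, variances)
        # M-step: re-estimate mixing weights, means, and (floored) variances.
        for k in range(2):
            nk = sum(r[k] for r in resp) or 1e-12
            weights[k] = nk / len(x)
            means[k] = sum(r[k] * xi for r, xi in zip(resp, x)) / nk
            variances[k] = (
                sum(r[k] * (xi - means[k]) ** 2 for r, xi in zip(resp, x)) / nk
                + var_floor
            )
    resp = _responsibilities(x, weights, means, variances)
    clean = min(range(2), key=lambda k: means[k])  # low-loss component = "clean"
    return [r[clean] for r in resp]


# Synthetic losses: four low-loss (likely clean) and three high-loss samples.
losses = [0.10, 0.12, 0.09, 0.11, 2.0, 2.1, 1.9]
probs = clean_probabilities(losses)
print([p > 0.5 for p in probs])  # [True, True, True, True, False, False, False]
```

The 0.5 cut-off in the usage line is an assumed threshold for splitting samples into a labeled (clean) set and an unlabeled (noisy) set.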


View on arXiv / View PDF / Download Paper Source / Download Source Code

Watch CVPR Video

Benchmarks

All Benchmarks

Key

| Annotation | Meaning |
|---|---|
| Small | Worse than or equivalent to previous state-of-the-art |
| Normal | Better than previous state-of-the-art |
| Bold | Best in task/category |

CIFAR-10

| Model | Metric | 20% sym | 50% sym | 80% sym | 90% sym | 40% asym |
|---|---|---|---|---|---|---|
| Runtime-W (Vanilla DivideMix) | Highest | 96.100% | 94.600% | 93.200% | 76.000% | 93.400% |
| | Last 10 | 95.700% | 94.400% | 92.900% | 75.400% | 92.100% |
| Raw | Highest | 85.940% | | | 27.580% | |
| | Last 10 | 83.230% | | | 23.915% | |
| Expansion.Weak | Highest | 90.860% | | | 31.220% | |
| | Last 10 | 89.948% | | | 10.000% | |
| Expansion.Strong | Highest | 90.560% | | | 35.100% | |
| | Last 10 | 89.514% | | | 34.228% | |
| AugDesc-WW | Highest | 96.270% | | | 36.050% | |
| | Last 10 | 96.084% | | | 23.503% | |
| Runtime-S | Highest | 96.540% | | | 70.470% | |
| | Last 10 | 96.327% | | | 70.223% | |
| AugDesc-SS | Highest | 96.470% | | | 81.770% | |
| | Last 10 | 96.193% | | | 81.540% | |
| AugDesc-WS.RandAug.n1m6 | Highest | 96.280% | | | 89.750% | |
| | Last 10 | 96.006% | | | 89.629% | |
| AugDesc-WS.SAW | Highest | 96.350% | 95.640% | 93.720% | 35.330% | 94.390% |
| | Last 10 | 96.138% | 95.417% | 93.563% | 10.000% | 94.078% |
| AugDesc-WS (WAW) | Highest | 96.330% | 95.360% | 93.770% | 91.880% | 94.640% |
| | Last 10 | 96.168% | 95.134% | 93.641% | 91.760% | 94.258% |
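Each model is reported with two metrics, which the tables do not define. A plausible reading follows DivideMix's reporting convention: "Highest" is the best test accuracy over all epochs, and "Last 10" is the mean test accuracy over the final 10 epochs.

The sketch below computes both metrics under that assumption. The helper name `summarize_run` and its input format (a plain list of per-epoch accuracies) are hypothetical.

```python
from statistics import mean


def summarize_run(epoch_accuracies, last_n=10):
    """Return (highest, last_n_mean) for a list of per-epoch test accuracies in %.

    Assumed definitions: "Highest" is the best accuracy over all epochs, and
    "Last 10" is the mean accuracy over the final `last_n` epochs.
    """
    if len(epoch_accuracies) < last_n:
        raise ValueError(f"need at least {last_n} epochs, got {len(epoch_accuracies)}")
    highest = max(epoch_accuracies)
    last_mean = mean(epoch_accuracies[-last_n:])
    return highest, last_mean


# Synthetic 12-epoch log (80.0, 81.0, ..., 91.0), not real training results.
accuracies = [80.0 + epoch for epoch in range(12)]
print(summarize_run(accuracies))  # (91.0, 86.5)
```

Under this reading, "Last 10" can never exceed "Highest". A large gap between the two, as in the 90% sym runs that end at 10.000%, signals a run that peaked and then collapsed.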

CIFAR-100

| Model | Metric | 20% sym | 50% sym | 80% sym | 90% sym |
|---|---|---|---|---|---|
| Runtime-W (Vanilla DivideMix) | Highest | 77.300% | 74.600% | 60.200% | 31.500% |
| | Last 10 | 76.900% | 74.200% | 59.600% | 31.000% |
| Raw | Highest | 52.240% | | | 7.990% |
| | Last 10 | 39.176% | | | 2.979% |
| Expansion.Weak | Highest | 57.110% | | | 7.300% |
| | Last 10 | 53.288% | | | 2.223% |
| Expansion.Strong | Highest | 55.150% | | | 7.540% |
| | Last 10 | 54.369% | | | 3.242% |
| AugDesc-WW | Highest | 78.900% | | | 30.330% |
| | Last 10 | 78.437% | | | 29.876% |
| Runtime-S | Highest | 79.890% | | | 40.520% |
| | Last 10 | 79.395% | | | 40.343% |
| AugDesc-SS | Highest | 79.790% | | | 38.850% |
| | Last 10 | 79.511% | | | 38.553% |
| AugDesc-WS.RandAug.n1m6 | Highest | 78.060% | | | 36.890% |
| | Last 10 | 77.826% | | | 36.672% |
| AugDesc-WS.SAW | Highest | 79.610% | 77.640% | 61.830% | 17.570% |
| | Last 10 | 79.464% | 77.522% | 61.632% | 15.050% |
| AugDesc-WS (WAW) | Highest | 79.500% | 77.240% | 66.360% | 41.200% |
| | Last 10 | 79.216% | 77.010% | 66.046% | 40.895% |

Clothing1M

| Model | Accuracy |
|---|---|
| Runtime-W (Vanilla DivideMix) | 74.760% |
| AugDesc-WS (WAW) | 74.720% |
| AugDesc-WS.SAW | 75.109% |

Summary Metrics

CIFAR-10

| Model | Metric | 20% sym | 50% sym | 80% sym | 90% sym | 40% asym |
|---|---|---|---|---|---|---|
| SOTA | Highest | 96.100% | 94.600% | 93.200% | 76.000% | 93.400% |
| | Last 10 | 95.700% | 94.400% | 92.900% | 75.400% | 92.100% |
| Ours | Highest | 96.540% | 95.640% | 93.770% | 91.880% | 94.640% |
| | Last 10 | 96.327% | 95.417% | 93.641% | 91.760% | 94.258% |
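The abstract's claim of a more-than-15-point gain at 90% symmetric noise can be checked directly against the CIFAR-10 summary rows. The short sketch below takes per-column differences of the Highest values. The names `COLUMNS`, `SOTA_HIGHEST`, `OURS_HIGHEST`, and `absolute_gains` are illustrative.

```python
# CIFAR-10 summary rows copied from the table above (Highest metric, in %).
COLUMNS = ["20% sym", "50% sym", "80% sym", "90% sym", "40% asym"]
SOTA_HIGHEST = [96.100, 94.600, 93.200, 76.000, 93.400]
OURS_HIGHEST = [96.540, 95.640, 93.770, 91.880, 94.640]


def absolute_gains(ours, baseline, ndigits=2):
    """Per-column improvement in percentage points, rounded for display."""
    return [round(o - b, ndigits) for o, b in zip(ours, baseline)]


gains = absolute_gains(OURS_HIGHEST, SOTA_HIGHEST)
for column, gain in zip(COLUMNS, gains):
    print(f"{column}: +{gain:.2f} points")
# The 90% sym line prints "+15.88 points".
```

The gain is an absolute difference in accuracy (percentage points), not a relative improvement.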

CIFAR-100

| Model | Metric | 20% sym | 50% sym | 80% sym | 90% sym |
|---|---|---|---|---|---|
| SOTA | Highest | 77.300% | 74.600% | 60.200% | 31.500% |
| | Last 10 | 76.900% | 74.200% | 59.600% | 31.000% |
| Ours | Highest | 79.890% | 77.640% | 66.360% | 41.200% |
| | Last 10 | 79.511% | 77.522% | 66.046% | 40.895% |
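The text does not say how the "Ours" rows are derived. Every value in them does, however, match one simple rule: for each column, take the best result over all variants except Runtime-W (vanilla DivideMix), separately for Highest and for Last 10. As a consequence, a column's two values can come from different variants. On CIFAR-100 at 20% sym, for example, Highest comes from Runtime-S (79.890%) and Last 10 from AugDesc-SS (79.511%).

The sketch below applies that reading to the full-table CIFAR-100 Highest rows. The dictionary layout and the `best_per_column` name are illustrative, not part of the repository.

```python
# CIFAR-100 "Highest" rows copied from the full table above, excluding
# Runtime-W (vanilla DivideMix). None marks a cell that was not reported.
# Columns: 20% sym, 50% sym, 80% sym, 90% sym.
CIFAR100_HIGHEST = {
    "Raw": [52.240, None, None, 7.990],
    "Expansion.Weak": [57.110, None, None, 7.300],
    "Expansion.Strong": [55.150, None, None, 7.540],
    "AugDesc-WW": [78.900, None, None, 30.330],
    "Runtime-S": [79.890, None, None, 40.520],
    "AugDesc-SS": [79.790, None, None, 38.850],
    "AugDesc-WS.RandAug.n1m6": [78.060, None, None, 36.890],
    "AugDesc-WS.SAW": [79.610, 77.640, 61.830, 17.570],
    "AugDesc-WS (WAW)": [79.500, 77.240, 66.360, 41.200],
}


def best_per_column(rows):
    """Column-wise maximum over all rows, skipping unreported (None) cells."""
    n_cols = len(next(iter(rows.values())))
    best = []
    for col in range(n_cols):
        reported = [vals[col] for vals in rows.values() if vals[col] is not None]
        best.append(max(reported) if reported else None)
    return best


print(best_per_column(CIFAR100_HIGHEST))  # [79.89, 77.64, 66.36, 41.2]
```

The printed values equal the "Ours / Highest" row of the CIFAR-100 summary table.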

Clothing1M

| Model | Accuracy |
|---|---|
| SOTA | 74.760% |
| Ours | 75.109% |

Training Locally

The source code relies heavily on CUDA. Make sure that you have the newest version of PyTorch and a compatible version of CUDA installed; older versions may produce inconsistent performance.

Download PyTorch / Download CUDA

Other requirements are included in requirements.txt.

Reproducibility

At particularly high noise ratios (e.g., 90% symmetric noise on CIFAR-10), results may vary across training runs. We are aware of this issue and are exploring ways to obtain more consistent results. We will publish any findings (consistently performant configurations, improved procedures, etc.) both in this repository and in continuations of this work.

All training configurations and parameters are controlled via the presets.json file. Configurations can contain arbitrarily nested subconfigurations, and a setting specified in a subconfiguration always overrides the value inherited from its parent.

To train locally, first add your local machine to presets.json:

{
    // ... inside the root scope
    "machines": { // list of machines
        "localPC": { // name for your local PC, can be anything
            "checkpoint_path": "./localPC_checkpoints"
        }
    },
    "configs": {
        "c10": { // cifar-10 dataset
            "machines": { // list of machines
                "localPC": { // local PC name
                    "data_path": "/path/to/your/dataset"
                    // path to dataset (python) downloaded from:
                    // https://www.cs.toronto.edu/~kriz/cifar.html
                }
                // ... keep all other machines unchanged
            }
            // ... keep all other config values unchanged
        }
        // ... keep all other configs unchanged
    }
    // ... keep all other global values unchanged
}

A "preset" is a specific configuration branch. For example, if you would like to run train_cifar.py with the preset root -> c100 -> 90sym -> AugDesc-WS on your machine named localPC, you can run the following command:

python train_cifar.py --preset c100.90sym.AugDesc-WS --machine localPC

The script will begin training the preset specified by the --preset argument. Progress is saved in the appropriate directory under your specified checkpoint_path. If the --machine flag is omitted, the training script looks for the dataset at the data_path inherited from the parent configurations.
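The inheritance rules can be sketched as a walk from the root down the dotted preset path. At each level, the sketch merges in that level's own settings and then, if a machine was given, that level's machine block.

This is a minimal model under stated assumptions, not the repository's loader. It assumes every level keeps its children under a "configs" key and its per-machine overrides under a "machines" key, as the root and "c10" do in the snippet above. The function names and the example values ("noise_ratio", "/shared/cifar-100") are hypothetical.

```python
import copy


def deep_merge(parent, child):
    """Return a copy of `parent` updated with `child`; child values win, nested dicts merge."""
    merged = copy.deepcopy(parent)
    for key, value in child.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def resolve_preset(presets, preset, machine=None):
    """Resolve a dotted preset such as "c100.90sym.AugDesc-WS" into flat settings.

    Assumed layout: children live under "configs", per-machine overrides under
    "machines". Deeper levels override shallower ones.
    """
    node, settings = presets, {}
    path = preset.split(".")
    for depth in range(len(path) + 1):
        if depth > 0:  # depth 0 is the root scope itself
            node = node["configs"][path[depth - 1]]
        own = {k: v for k, v in node.items() if k not in ("configs", "machines")}
        settings = deep_merge(settings, own)
        if machine is not None:
            settings = deep_merge(settings, node.get("machines", {}).get(machine, {}))
    return settings


# Hypothetical presets; "noise_ratio" and the shared data_path are made-up values.
example_presets = {
    "machines": {"localPC": {"checkpoint_path": "./localPC_checkpoints"}},
    "configs": {
        "c100": {
            "data_path": "/shared/cifar-100",
            "machines": {"localPC": {"data_path": "/path/to/your/dataset"}},
            "configs": {
                "90sym": {"noise_ratio": 0.9, "configs": {"AugDesc-WS": {}}},
            },
        }
    },
}

print(resolve_preset(example_presets, "c100.90sym.AugDesc-WS", machine="localPC"))
# {'checkpoint_path': './localPC_checkpoints', 'data_path': '/path/to/your/dataset', 'noise_ratio': 0.9}
```

Without a machine, the same call falls back to the inherited "/shared/cifar-100" data_path. In this sketch, a deeper configuration can still override a machine setting made at a shallower level. The repository may order these overrides differently.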

Here are some abbreviations used in our presets.json:

| Abbreviation | Meaning |
|---|---|
| c10 | CIFAR-10 |
| c100 | CIFAR-100 |
| c1m | Clothing1M |
| sym | Symmetric Noise |
| asym | Asymmetric Noise |
| SAW | Strongly Augmented Warmup |
| WAW | Weakly Augmented Warmup |
| RandAug | RandAugment |
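A preset is written as dot-separated configuration names, and several of those names use the abbreviations above. The helper below, `describe_preset` (a hypothetical name, not part of the repository), expands them for readability.

It assumes noise presets follow the `<ratio><sym|asym>` pattern seen in `90sym`. Unrecognized segments, such as strategy names like `AugDesc-WS`, are left unchanged. The preset strings in the examples are illustrative.

```python
import re

# Abbreviations from the table above (noise types handled separately below).
ABBREVIATIONS = {
    "c10": "CIFAR-10",
    "c100": "CIFAR-100",
    "c1m": "Clothing1M",
    "SAW": "Strongly Augmented Warmup",
    "WAW": "Weakly Augmented Warmup",
    "RandAug": "RandAugment",
}
NOISE_TYPES = {"sym": "Symmetric Noise", "asym": "Asymmetric Noise"}


def describe_preset(preset):
    """Expand the abbreviations in a dotted preset name, e.g. "c100.90sym.AugDesc-WS"."""
    parts = []
    for segment in preset.split("."):
        noise = re.fullmatch(r"(\d+)(sym|asym)", segment)  # e.g. "90sym", "40asym"
        if noise:
            parts.append(f"{noise.group(1)}% {NOISE_TYPES[noise.group(2)]}")
        else:
            parts.append(ABBREVIATIONS.get(segment, segment))
    return " -> ".join(parts)


print(describe_preset("c100.90sym.AugDesc-WS"))
# CIFAR-100 -> 90% Symmetric Noise -> AugDesc-WS
```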

Citations

Please cite the following:

@InProceedings{Nishi_2021_CVPR,
    author    = {Nishi, Kento and Ding, Yi and Rich, Alex and H{\"o}llerer, Tobias},
    title     = {Augmentation Strategies for Learning With Noisy Labels},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2021},
    pages     = {8022--8031}
}

Extras

Additional, uncleaned code for plotting, training, and other tasks can be found in the Aug-for-LNL-Extras repository.

Additional Info

This repository is a fork of the official DivideMix implementation.