
Text-AutoAugment (TAA)

This repository contains the code for our paper Text AutoAugment: Learning Compositional Augmentation Policy for Text Classification (EMNLP 2021 main conference).


Updates

  • [22.02.23]: We added an example showing how to use TAA with your custom (local) dataset.
  • [21.10.27]: We made taa installable as a package and adapted it to huggingface/transformers. You can now search for an augmentation policy on a Hugging Face dataset with just two lines of code.


Overview

  1. We present a learnable and compositional framework for data augmentation. Our proposed algorithm automatically searches for the optimal compositional policy, which improves the diversity and quality of augmented samples.

  2. In low-resource and class-imbalanced regimes of six benchmark datasets, TAA significantly improves the generalization ability of deep neural networks like BERT and effectively boosts text classification performance.
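To make the idea of a compositional policy more concrete, here is a toy sketch in plain Python. It is not TAA's implementation: the two operations (random_swap, random_delete), the meaning given to probability and magnitude, and all function names are hypothetical. It only illustrates the structure: a policy is a set of sub-policies (cf. num_policy in the configfile), each sub-policy chains a few operations applied with their own probability and magnitude (cf. num_op), and each text is augmented several times (cf. n_aug).

```python
import random
from typing import Callable, List, Tuple

# Hypothetical operation signature: (tokens, magnitude in [0, 1], rng) -> new tokens.
Op = Callable[[List[str], float, random.Random], List[str]]


def random_swap(tokens: List[str], magnitude: float, rng: random.Random) -> List[str]:
    """Swap roughly magnitude * len(tokens) random pairs of tokens."""
    out = tokens[:]
    n_swaps = max(1, int(magnitude * len(out))) if len(out) > 1 else 0
    for _ in range(n_swaps):
        i, j = rng.sample(range(len(out)), 2)
        out[i], out[j] = out[j], out[i]
    return out


def random_delete(tokens: List[str], magnitude: float, rng: random.Random) -> List[str]:
    """Drop each token with probability `magnitude`, keeping at least one token."""
    kept = [t for t in tokens if rng.random() >= magnitude]
    return kept if kept else tokens[:1]


# A sub-policy is an ordered list of (operation, probability, magnitude) triples;
# a policy is a list of sub-policies.
SubPolicy = List[Tuple[Op, float, float]]


def apply_sub_policy(text: str, sub_policy: SubPolicy, rng: random.Random) -> str:
    tokens = text.split()
    for op, prob, magnitude in sub_policy:
        # Each operation in the chain fires independently with its own probability.
        if rng.random() < prob:
            tokens = op(tokens, magnitude, rng)
    return " ".join(tokens)


def augment(text: str, policy: List[SubPolicy], n_aug: int, seed: int = 0) -> List[str]:
    """Produce n_aug augmented copies, each from a randomly chosen sub-policy."""
    rng = random.Random(seed)
    return [apply_sub_policy(text, rng.choice(policy), rng) for _ in range(n_aug)]


# Illustrative policy: 2 sub-policies with 2 operations each.
# These probabilities and magnitudes are made up, not searched values.
policy = [
    [(random_swap, 0.5, 0.2), (random_delete, 0.3, 0.1)],
    [(random_delete, 0.8, 0.2), (random_swap, 0.4, 0.3)],
]
samples = augment("the movie was surprisingly good", policy, n_aug=4)
```

In TAA, the operations, probabilities, and magnitudes are what the search optimizes; here they are fixed by hand only to show how the pieces compose.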

Getting Started

Prepare environment

Install PyTorch and a few additional dependencies, then install this repo as a Python package. Note that the CUDA build of the PyTorch wheel (cu102, i.e., CUDA 10.2, in the command below) should match the CUDA version on your machine.

```bash
# Clone this repo
git clone https://github.com/lancopku/text-autoaugment.git
cd text-autoaugment

# Create a conda environment
conda create -n taa python=3.6
conda activate taa

# Install dependencies
pip install torch==1.10.1+cu102 -f https://download.pytorch.org/whl/cu102/torch_stable.html
pip install git+https://github.com/wbaek/theconf
pip install git+https://github.com/ildoonet/pystopwatch2.git
pip install -r requirements.txt

# Install this library in development mode (no need to re-install after modifying the source code)
python setup.py develop

# Download the required NLTK data
python -c "import nltk; nltk.download('wordnet'); nltk.download('averaged_perceptron_tagger'); nltk.download('omw-1.4')"
```

Please make sure your PyTorch installation supports GPUs. You can check with python -c "import torch; print(torch.cuda.is_available())", which should print True.
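If you want to verify the whole setup at once, the sketch below checks both GPU support and the NLTK data downloaded during installation. check_environment is a hypothetical helper, not part of taa, and the NLTK resource paths are assumed to be the standard locations used by nltk.download(). Missing packages are reported as False rather than raising.

```python
import importlib.util


def check_environment() -> dict:
    """Report whether PyTorch sees a GPU and whether the NLTK data is present."""
    status = {"torch_cuda": False, "nltk_data": False}

    if importlib.util.find_spec("torch") is not None:
        import torch
        status["torch_cuda"] = torch.cuda.is_available()

    if importlib.util.find_spec("nltk") is not None:
        import nltk
        # Assumed resource paths for the wordnet, tagger and omw-1.4 downloads.
        resources = ["corpora/wordnet", "taggers/averaged_perceptron_tagger", "corpora/omw-1.4"]
        try:
            for resource in resources:
                nltk.data.find(resource)  # raises LookupError if not found
            status["nltk_data"] = True
        except LookupError:
            status["nltk_data"] = False

    return status


if __name__ == "__main__":
    print(check_environment())
```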

Use TAA with Huggingface

1. Get an augmented training dataset with a TAA policy

Option 1: Search for the optimal policy

You can search for the optimal policy on classification datasets supported by huggingface/datasets:

```python
from taa.search_and_augment import search_and_augment

# return the augmented train dataset in the form of torch.utils.data.Dataset
augmented_train_dataset = search_and_augment(configfile="/path/to/your/config.yaml")
```

The configfile (a YAML file) contains all the arguments, including paths, the model, the dataset, and optimization hyper-parameters. To run the code successfully, please set the following arguments carefully:

  • model:

    • type: backbone model
  • dataset:

    • path: Path or name of the dataset
    • name: Name of the dataset configuration
    • data_dir: data_dir of the dataset configuration
    • data_files: Path(s) to source data file(s)

    ATTENTION: All the arguments above are passed to the load_dataset() function in huggingface/datasets. Please refer to its documentation for details.

    • text_key: Key used to get the text from a data instance (a dict in huggingface/datasets; see this IMDB example)
  • abspath: Your working directory

  • aug: Pre-searched policy. Now we support IMDB, SST5, TREC, YELP2 and YELP5. See archive.py.

  • per_device_train_batch_size: Batch size per device for training

  • per_device_eval_batch_size: Batch size per device for evaluation

  • epoch: Number of training epochs

  • lr: Learning rate

  • max_seq_length: Maximum input sequence length

  • n_aug: Augment each text sample n_aug times

  • num_op: Number of operations per sub-policy

  • num_policy: Number of sub-policies per policy

  • method: Search method (taa)

  • topN: Number of top sub-policies ensembled into the final policy

  • ir: Imbalance rate

  • seed: Random seed

  • trail: Trial index under the current random seed

  • train:

    • npc: Number of examples per class in the training dataset
  • valid:

    • npc: Number of examples per class in the validation dataset
  • test:

    • npc: Number of examples per class in the test dataset
  • num_search: Number of optimization iterations

  • num_gpus: Number of GPUs used by Ray

  • num_cpus: Number of CPUs used by Ray
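Before launching a search, it can help to check that a parsed configfile contains every key listed above. The sketch below is a hypothetical helper (missing_keys and DOCUMENTED_KEYS are not part of taa); it only checks that keys are present, not that their values are valid, and some keys (e.g., aug or dataset.data_dir) may be optional depending on your setup.

```python
from typing import Any, Dict, List

# Keys documented in the list above; "a.b" means key b nested under a.
DOCUMENTED_KEYS = [
    "model.type",
    "dataset.path", "dataset.name", "dataset.data_dir", "dataset.data_files", "dataset.text_key",
    "abspath", "aug",
    "per_device_train_batch_size", "per_device_eval_batch_size",
    "epoch", "lr", "max_seq_length",
    "n_aug", "num_op", "num_policy", "method", "topN", "ir", "seed", "trail",
    "train.npc", "valid.npc", "test.npc",
    "num_search", "num_gpus", "num_cpus",
]


def missing_keys(cfg: Dict[str, Any]) -> List[str]:
    """Return the documented keys that are absent from a parsed config dict."""
    missing = []
    for dotted in DOCUMENTED_KEYS:
        node: Any = cfg
        for part in dotted.split("."):
            # Walk down the nesting; stop as soon as a level is missing.
            if not isinstance(node, dict) or part not in node:
                missing.append(dotted)
                break
            node = node[part]
    return missing
```

For example, after loading your file with PyYAML's yaml.safe_load, an empty list means every documented key is present.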

configfile example 1: TAA for huggingface dataset

bert_sst2_example.yaml is a configfile example for the BERT model and the SST2 dataset. You can follow it to create your own configfile for other Hugging Face datasets.

For instance, if you only want to change the dataset from sst2 to imdb, just delete sst2 from the 'path' argument, change 'name' to imdb, and change 'text_key' to text. The result should look like bert_imdb_example.yaml.

configfile example 2: TAA for custom (local) dataset

bert_custom_data_example.yaml is a configfile example for the BERT model and a custom (local) dataset. The custom dataset should be in CSV format, with two columns named text and label. custom_data.csv is an example of such a dataset.
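If your data is not in CSV yet, Python's built-in csv module is enough to produce a file with the expected columns. The helper names and the file name my_custom_data.csv are hypothetical, and integer labels are an assumption; compare with custom_data.csv for the exact label format.

```python
import csv
from pathlib import Path
from typing import Iterable, Tuple


def write_custom_dataset(rows: Iterable[Tuple[str, int]], path: str) -> Path:
    """Write (text, label) pairs to a CSV whose header is exactly text,label."""
    out = Path(path)
    with out.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["text", "label"])
        writer.writerows(rows)  # texts containing commas are quoted automatically
    return out


def has_expected_columns(path: str) -> bool:
    """Check that the CSV header contains the text and label columns."""
    with open(path, newline="", encoding="utf-8") as f:
        header = next(csv.reader(f), [])
    return {"text", "label"} <= set(header)


csv_path = write_custom_dataset(
    [("a gripping, well-acted thriller", 1), ("dull and far too long", 0)],
    "my_custom_data.csv",
)
```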

WARNING: The policy optimization framework is based on Ray. By default, we use 4 GPUs and 40 CPUs for policy optimization. Make sure your computing resources meet this requirement; otherwise, you will need to create a new configuration file with adjusted num_gpus and num_cpus. Please also specify the GPUs to use, e.g., CUDA_VISIBLE_DEVICES=0,1,2,3, before running the code above. TPUs do not appear to be supported at the moment.
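If your machine has fewer resources, one option is to cap the requested values before writing them into num_gpus and num_cpus of your configfile. fit_ray_resources is a hypothetical helper, not part of taa. It assumes PyTorch is used to count GPUs (torch.cuda.device_count() respects CUDA_VISIBLE_DEVICES) and assumes zero GPUs when PyTorch is not installed.

```python
import importlib.util
import os
from typing import Tuple


def fit_ray_resources(num_gpus: int = 4, num_cpus: int = 40) -> Tuple[int, int]:
    """Cap the requested Ray resources at what this machine appears to have.

    Defaults mirror the 4 GPUs / 40 CPUs mentioned above.
    """
    available_gpus = 0
    if importlib.util.find_spec("torch") is not None:
        import torch
        available_gpus = torch.cuda.device_count()
    available_cpus = os.cpu_count() or 1
    return min(num_gpus, available_gpus), min(num_cpus, available_cpus)


# Restrict the visible devices before PyTorch or Ray initialise CUDA;
# setdefault keeps any value you already exported in the shell.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2,3")
gpus, cpus = fit_ray_resources()
```

A result of 0 GPUs means the search cannot run as configured on this machine.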

Option 2: Use our pre-searched policy

To train a model on a dataset augmented with our pre-searched policy, use the following (taking IMDB as an example):

```python
from taa.search_and_augment import augment_with_presearched_policy

# return the augmented train dataset in the form of torch.utils.data.Dataset
augmented_train_dataset = augment_with_presearched_policy(configfile="/path/to/your/config.yaml")
```

Now we support IMDB, SST5, TREC, YELP2 and YELP5. See archive.py for details.

This table lists the test accuracy (%) of the pre-searched TAA policies on the full datasets:

| Dataset | IMDB | SST-5 | TREC | YELP-2 | YELP-5 |
|---|---|---|---|---|---|
| No Aug | 88.77 | 52.29 | 96.40 | 95.85 | 65.55 |
| TAA | 89.37 | 52.55 | 97.07 | 96.04 | 65.73 |
| n_aug | 4 | 4 | 4 | 2 | 2 |
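For reference, the absolute improvement of TAA over no augmentation can be computed directly from the table (values copied as-is; no new results):

```python
# Test accuracy (%) on the full datasets: (No Aug, TAA), copied from the table.
results = {
    "IMDB": (88.77, 89.37),
    "SST-5": (52.29, 52.55),
    "TREC": (96.40, 97.07),
    "YELP-2": (95.85, 96.04),
    "YELP-5": (65.55, 65.73),
}

# Absolute improvement of TAA over no augmentation, in percentage points.
gains = {name: round(taa - base, 2) for name, (base, taa) in results.items()}

for name, gain in gains.items():
    print(f"{name}: +{gain:.2f}")
```

This gives +0.60 (IMDB), +0.26 (SST-5), +0.67 (TREC), +0.19 (YELP-2) and +0.18 (YELP-5) percentage points.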

More pre-searched policies and their performance will be released soon.

2. Fine-tune a new model on the augmented training dataset

Once you have augmented_train_dataset, you can pass it directly to the Hugging Face Trainer. Please refer to search_augment_train.py for details.
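For orientation, a minimal sketch of this step with the standard transformers Trainer API might look as follows. It assumes that augmented_train_dataset already yields tokenized examples with labels (as implied by passing it to the Trainer directly), that model_name matches the backbone in your configfile, and that the hyper-parameters shown are placeholders; fine_tune is a hypothetical helper, and search_augment_train.py remains the authoritative example.

```python
def fine_tune(augmented_train_dataset, num_labels: int,
              model_name: str = "bert-base-uncased",
              output_dir: str = "./taa_finetune"):
    """Fine-tune a sequence classifier on a TAA-augmented dataset (sketch)."""
    # Imported lazily so this module can be loaded without transformers installed.
    from transformers import (AutoModelForSequenceClassification, Trainer,
                              TrainingArguments)

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )
    args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=16,  # placeholder values, not tuned settings
        num_train_epochs=3,
        learning_rate=4e-5,
    )
    trainer = Trainer(model=model, args=args, train_dataset=augmented_train_dataset)
    trainer.train()
    return trainer
```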

Reproduce results in the paper

Please see examples/reproduce_experiment.py, and run script/huggingface_lowresource.sh or script/huggingface_imbalanced.sh.

Contact

If you have any questions related to the code or the paper, feel free to open an issue.

Acknowledgments

Our code builds on fast-autoaugment.

Citation

If you find this code useful for your research, please consider citing:

```bibtex
@inproceedings{ren2021taa,
    title = "Text {A}uto{A}ugment: Learning Compositional Augmentation Policy for Text Classification",
    author = "Ren, Shuhuai and Zhang, Jinchao and Li, Lei and Sun, Xu and Zhou, Jie",
    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
    year = "2021",
}
```

License

MIT