This repository has been archived by the owner on Nov 9, 2022. It is now read-only.

Att-Induction: Attention-based Induction Networks for Few-Shot Text Classification

Code for the paper Attention-based Induction Networks for Few-Shot Text Classification.

Introduction

Attention-based Induction Networks is a model for few-shot text classification that builds on Induction Networks.

Attention-based Induction Networks learn a different class representation for each query through multi-head self-attention: the induction module pays more attention to the instances and feature dimensions that are most useful for the current query. In addition, we use a pre-trained model instead of training an encoder from scratch, which captures more semantic information in few-shot learning scenarios. Experimental results show that, on three public datasets and one real-world dataset, this model significantly outperforms existing state-of-the-art approaches.
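The idea of a query-dependent class representation can be illustrated with a minimal sketch. It is not the code from this repository: the function name `query_aware_class_vector` is hypothetical, the heads are formed by simply slicing the embedding (no learned projections), and plain Python lists stand in for tensors. Each head scores the class's instances against the query and pools them with the resulting attention weights, so different queries yield different class vectors.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def query_aware_class_vector(query, instances, num_heads=2):
    """Pool one class's instance embeddings into a vector that depends on the query.

    Assumption: each head is a contiguous slice of the embedding, with no
    learned projection matrices (a real model would learn them).
    """
    dim = len(query)
    if dim % num_heads != 0:
        raise ValueError("embedding size must be divisible by num_heads")
    head_dim = dim // num_heads

    class_vector = []
    for head in range(num_heads):
        lo, hi = head * head_dim, (head + 1) * head_dim
        q = query[lo:hi]
        # Scaled dot-product score between the query and each instance, per head.
        scores = [
            sum(a * b for a, b in zip(q, inst[lo:hi])) / math.sqrt(head_dim)
            for inst in instances
        ]
        weights = softmax(scores)
        # Attention-weighted sum of the instances, restricted to this head's slice.
        for d in range(lo, hi):
            class_vector.append(sum(w * inst[d] for w, inst in zip(weights, instances)))
    return class_vector


instances = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]
vec_a = query_aware_class_vector([5.0, 0.0, 0.0, 5.0], instances)
vec_b = query_aware_class_vector([0.0, 5.0, 5.0, 0.0], instances)
```

Here `vec_a` leans toward the first instance and `vec_b` toward the second, even though both are built from the same two instances.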

Datasets

  • ARSC: Amazon Review Sentiment Classification. This dataset was proposed by Yu et al. in the NAACL 2018 paper Diverse few-shot text classification with multiple metrics. It is downloaded from DiverseFewShot_Amazon. We use the same settings as Geng et al.
  • HuffPost Headlines: This dataset is published on Kaggle as the News Category Dataset. Following Bao et al., we use a subset of the full dataset. The split is done in ./src/utils.py.
  • 20 Newsgroups: This dataset was originally collected by Lang. It is downloaded from Distributional-Signatures. The split is done in ./src/utils.py.
  • Controversial Issues: This real-world dataset consists of controversial issues raised during trials. We created it by selecting Labour Disputes (Disp-L) and Product Liability Disputes (Disp-PL).
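As a rough illustration of what a few-shot split can look like, here is a minimal sketch that partitions the classes (not the individual examples) into disjoint train, validation, and test groups. It assumes a class-disjoint split, which is common in few-shot learning; the function name `split_classes`, its parameters, and the class counts are hypothetical and do not describe the actual contents of ./src/utils.py.

```python
import random


def split_classes(labels, num_train, num_val, seed=0):
    """Partition the distinct class labels into disjoint train/val/test groups.

    Hypothetical helper: the remaining classes after train and val go to test.
    """
    classes = sorted(set(labels))
    if num_train + num_val >= len(classes):
        raise ValueError("no classes would be left for the test split")
    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    rng.shuffle(classes)
    return {
        "train": classes[:num_train],
        "val": classes[num_train:num_train + num_val],
        "test": classes[num_train + num_val:],
    }


# Toy labels; a real dataset would supply one label per example.
labels = ["POLITICS", "SPORTS", "TECH", "FOOD", "TRAVEL", "SPORTS", "TECH"]
splits = split_classes(labels, num_train=3, num_val=1)
```

Because the classes are disjoint across splits, the model is evaluated on categories it never saw during training.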

Usage

Requirements

Run pip install -r requirements.txt to install the following dependencies:

  • Python
  • PyTorch
  • Transformers
  • NumPy
  • pandas
  • Matplotlib

Training

Training scripts are placed in ./scripts/. To train, edit the training parameters in the relevant shell script and run it from the terminal. For example:

bash ./scripts/run_train_HuffPost.sh

You can use python3 train.py -h to see all available parameters.

Test

If --test_data is given during training, the test task always runs after training finishes. You can also run a standalone test by specifying --load_checkpoint and --only_test in the training script.
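The interaction between these three flags can be sketched as follows. This is an illustration rather than the actual train.py: the flag names come from the text above, but their types, their defaults, and the helper names `build_parser` and `plan_steps` are assumptions.

```python
import argparse


def build_parser():
    """Parser with only the three flags discussed above (types are assumed)."""
    parser = argparse.ArgumentParser(description="Sketch of the train/test switch")
    parser.add_argument("--test_data", default=None, help="path to test data")
    parser.add_argument("--load_checkpoint", default=None, help="checkpoint to load")
    parser.add_argument("--only_test", action="store_true", help="skip training")
    return parser


def plan_steps(args):
    """Return the ordered list of steps implied by the parsed flags."""
    if args.only_test:
        # A standalone test needs an already trained model to evaluate.
        if args.load_checkpoint is None:
            raise ValueError("--only_test requires --load_checkpoint")
        return ["load_checkpoint", "test"]
    steps = ["train"]
    if args.test_data is not None:
        # Testing runs automatically after training when test data is given.
        steps.append("test")
    return steps


train_then_test = plan_steps(build_parser().parse_args(["--test_data", "test.json"]))
standalone_test = plan_steps(
    build_parser().parse_args(["--only_test", "--load_checkpoint", "model.pt"])
)
```

The file names `test.json` and `model.pt` are placeholders; run python3 train.py -h for the real parameter list.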

Maintainers

@ShaneTian.

Citation

License

Apache License 2.0 © ShaneTian
