Triplet Attention: Rethinking the Similarity in Transformers

This is the PyTorch implementation of Triplet Attention from the KDD'21 paper: Triplet Attention: Rethinking the Similarity in Transformers.

Requirements

  • Python 3.6
  • numpy==1.17.3
  • scipy==1.1.0
  • pandas==0.25.1
  • torch==1.2.0
  • tqdm==4.36.1
  • matplotlib==3.1.1
  • tokenizers==0.10.3
  • ...

Dependencies can be installed using the following command:

pip install -r requirements.txt

Usage

We implement BERT-A3 and DistilBERT-A3 on top of Hugging Face Transformers, so you can use the BERT-A3 and DistilBERT-A3 models in the same way as the BERT and DistilBERT models in Hugging Face Transformers.

Build BERT-A3

from transformers import BertTokenizer, BertModel, BertConfig

config = BertConfig.from_pretrained('bert-base-uncased')

config.group_size = 2 # number of triplet attention head groups (each group uses 3 heads)
config.cross_type = 0 # cross product type (0: L cross product with permutation, 1: L*L cross product)
config.agg_type = 0 # aggregation type when using the L*L cross product
config.absolute_flag = 0 # whether to use the absolute value of the triplet attention (1: use abs)
config.random_flag = 0 # permutation type (0: multi permutation, 1: only random permutation)
config.permute_type = '1,2,3,4,5' # permutation type groups
config.permute_back = 0 # whether to apply the inverse permutation (0: apply the inverse permutation)
config.Tlayers = '0,1,2,8,9,10' # layers which use triplet attention heads
config.key2_flag = 0 # whether to use a key2 linear layer to get key_triplet2 (0: use the key2 layer)
config.head_choice = 0 # how to choose triplet attention heads (0: choose the last 3*group_size heads as triplet attention heads, 1: randomly choose 3*group_size heads as triplet attention heads)

bert_A3 = BertModel.from_pretrained('bert-base-uncased', config=config)
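
Build DistilBERT-A3 (sketch)

DistilBERT-A3 can presumably be built the same way through the DistilBERT classes. The sketch below is an assumption based on the BERT-A3 example above: it reuses the same custom config fields and assumes the modified DistilBERT code reads them; the Tlayers value is adjusted for DistilBERT's 6 layers.

from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig

config = DistilBertConfig.from_pretrained('distilbert-base-uncased')

# Assumed: DistilBERT-A3 reads the same custom fields as BERT-A3 (see above).
config.group_size = 2
config.cross_type = 0
config.agg_type = 0
config.absolute_flag = 0
config.random_flag = 0
config.permute_type = '1,2,3,4,5'
config.permute_back = 0
config.Tlayers = '0,1,2' # DistilBERT has 6 layers, so valid indices are 0-5
config.key2_flag = 0
config.head_choice = 0

distilbert_A3 = DistilBertModel.from_pretrained('distilbert-base-uncased', config=config)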

Use BERT-A3

from transformers import BertTokenizer, BertModel, BertConfig

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

config = BertConfig.from_pretrained('bert-base-uncased')
bert_A3 = BertModel.from_pretrained('bert-base-uncased', config=config)

inputs = tokenizer('hello world', return_tensors="pt")
outputs = bert_A3(**inputs)

You can refer to /transformers/models/bert/modeling_bert.py for more details.
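
As a quick sanity check, you can inspect the returned hidden states. This uses the standard Hugging Face Transformers output API rather than anything specific to this repository; with older transformers versions the model may return a tuple, in which case outputs[0] is the last hidden state.

import torch

inputs = tokenizer('hello world', return_tensors="pt")
with torch.no_grad():
    outputs = bert_A3(**inputs)

# Shape (batch_size, sequence_length, hidden_size),
# e.g. torch.Size([1, 4, 768]) for 'hello world' with bert-base-uncased.
print(outputs.last_hidden_state.shape)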

Train Commands

Commands for training and testing the BERT-A3 model on the GLUE RTE task:

python run_glue.py --model_name_or_path bert-base-uncased --task_name rte --do_train --do_eval --do_predict --max_seq_length 128 --per_device_train_batch_size 16 --per_device_eval_batch_size 16 --cross_type 0 --agg_type 0 --tlayers '0,1,2,3' --learning_rate 3e-5 --num_train_epochs 6 --key2_flag 0 --random_flag 0 --absolute_flag 0 --permute_back 0 --permute_type '0,1,3,5,6' --head_choice 0 --group_size 1 --overwrite_output_dir --output_dir ./run/
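
The same script should work for the other GLUE tasks by changing --task_name. The command below is a hedged example for MRPC that simply reuses the triplet-attention flags from the RTE command above; the flag values are carried over as-is, not tuned settings for MRPC.

python run_glue.py --model_name_or_path bert-base-uncased --task_name mrpc --do_train --do_eval --do_predict --max_seq_length 128 --per_device_train_batch_size 16 --per_device_eval_batch_size 16 --cross_type 0 --agg_type 0 --tlayers '0,1,2,3' --learning_rate 3e-5 --num_train_epochs 6 --key2_flag 0 --random_flag 0 --absolute_flag 0 --permute_back 0 --permute_type '0,1,3,5,6' --head_choice 0 --group_size 1 --overwrite_output_dir --output_dir ./run_mrpc/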

Citation

If you find this repository useful in your research, please consider citing the following paper:

@inproceedings{haoyietal-tripletAttention-2021,
  author    = {Haoyi Zhou and
               Jianxin Li and
               Jieqi Peng and
               Shuai Zhang and
               Shanghang Zhang},
  editor    = {Feida Zhu and
               Beng Chin Ooi and
               Chunyan Miao},
  title     = {Triplet Attention: Rethinking the Similarity in Transformers},
  booktitle = {The 27th {ACM} {SIGKDD} Conference on Knowledge Discovery and Data Mining, {KDD} 2021, Virtual Event},
  pages     = {2378--2388},
  publisher = {{ACM}},
  year      = {2021},
}
