UCSD-AI4H/SSReg

SSL-Reg

1. Introduction

This is a PyTorch implementation of our paper "Self-supervised Regularization for Text Classification".

2. Environment

The environments required by our code are listed in the requirements.txt file in each of the two folders.

3. SSL-Reg (MTP)

We provide an example of the data format for ChemProt. The other datasets are publicly available here.

To fine-tune RoBERTa with SSL-Reg using masked token prediction (MTP) as the self-supervised task, use a command like the following:

# Paths: adjust to your setup
export DATA_DIR=nlp_datasets/chemprot
export TASK=chemprot
export OUTPUT_PATH=results/chemprot
export BERT_ALL_DIR=/path/to/roberta_base   # local directory with the pretrained roberta-base files

python sslreg.py \
    --task_name $TASK \
    --do_train --do_val --do_test --do_val_history \
    --do_save \
    --do_lower_case \
    --bert_model roberta-base \
    --data_dir $DATA_DIR \
    --bert_all_dir $BERT_ALL_DIR \
    --max_seq_length 512 \
    --train_batch_size 16 \
    --gradient_accumulation_steps 8 \
    --train_lm_loss_weight 0.1 \
    --adam_beta1 0.9 \
    --adam_beta2 0.98 \
    --clip_grad_norm \
    --learning_rate 2e-5 \
    --output_dir $OUTPUT_PATH \
    --num_train_epochs 10 \
    --has_test_label \
    --save_best_model \
    --force-overwrite

Note that num_train_epochs and train_lm_loss_weight may need to be set to different values for different datasets.
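SSL-Reg (MTP) adds a masked token prediction loss on the training text to the classification loss, and train_lm_loss_weight presumably scales that auxiliary term. The sketch below illustrates both pieces under stated assumptions: it uses the common BERT-style masking recipe (select about 15% of tokens, then replace 80% of them with the mask token, 10% with a random token, and keep 10% unchanged) and a plain weighted sum of the two losses. MASK_ID, VOCAB_SIZE, mask_tokens, and sslreg_loss are illustrative names and values, not code from sslreg.py.

```python
import random
from typing import List, Tuple

# Assumptions for illustration: roberta-base <mask> id and vocabulary size.
MASK_ID = 50264
VOCAB_SIZE = 50265
IGNORE_INDEX = -100  # label value that a cross-entropy loss would be told to skip


def mask_tokens(token_ids: List[int], rng: random.Random,
                mask_prob: float = 0.15) -> Tuple[List[int], List[int]]:
    """BERT-style masking; returns (corrupted inputs, MTP labels).

    For brevity, special tokens such as <s> and </s> are not excluded here.
    """
    inputs = list(token_ids)
    labels = [IGNORE_INDEX] * len(token_ids)
    for i, tok in enumerate(token_ids):
        if rng.random() >= mask_prob:
            continue  # position not selected for prediction
        labels[i] = tok  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = MASK_ID  # 80%: replace with the mask token
        elif roll < 0.9:
            inputs[i] = rng.randrange(VOCAB_SIZE)  # 10%: random token
        # remaining 10%: keep the original token
    return inputs, labels


def sslreg_loss(task_loss: float, mtp_loss: float, lm_loss_weight: float = 0.1) -> float:
    """Assumed SSL-Reg objective: task loss plus the weighted MTP loss."""
    return task_loss + lm_loss_weight * mtp_loss


if __name__ == "__main__":
    rng = random.Random(0)
    ids = [0, 713, 16, 10, 1296, 3645, 4, 2]
    masked, labels = mask_tokens(ids, rng)
    print(masked)
    print(labels)
    print(sslreg_loss(task_loss=0.7, mtp_loss=2.0, lm_loss_weight=0.1))
```

With lm_loss_weight set to 0.1, as in the command above, the MTP term acts as a light regularizer rather than competing equally with the classification loss.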

4. SSL-Reg (SATP)

The General Language Understanding Evaluation (GLUE) benchmark is a collection of nine sentence- or sentence-pair language understanding tasks for evaluating and analyzing natural language understanding systems.

Before running any of these GLUE tasks, download the GLUE data by running this script and unpack it into a directory $GLUE_DIR (the commands below assume ./glue_data/).

We use EDA (Easy Data Augmentation) as the sentence augmentation method.
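EDA perturbs a sentence with four simple operations: synonym replacement, random insertion, random swap, and random deletion. The sketch below is a simplified illustration of those operations, not the repository's aug.py. The original EDA draws synonyms from WordNet; here a tiny hard-coded SYNONYMS table (a hypothetical stand-in) keeps the example self-contained.

```python
import random
from typing import Dict, List

# Hypothetical stand-in for the WordNet synonyms used by the original EDA.
SYNONYMS: Dict[str, List[str]] = {
    "movie": ["film", "picture"],
    "good": ["fine", "decent"],
    "great": ["excellent", "superb"],
}


def synonym_replacement(words: List[str], rng: random.Random, n: int = 1) -> List[str]:
    """Replace up to n words that have a known synonym."""
    out = list(words)
    candidates = [i for i, w in enumerate(out) if w in SYNONYMS]
    rng.shuffle(candidates)
    for i in candidates[:n]:
        out[i] = rng.choice(SYNONYMS[out[i]])
    return out


def random_insertion(words: List[str], rng: random.Random, n: int = 1) -> List[str]:
    """n times: insert a synonym of a random word at a random position."""
    out = list(words)
    for _ in range(n):
        candidates = [w for w in out if w in SYNONYMS]
        if not candidates:
            break
        synonym = rng.choice(SYNONYMS[rng.choice(candidates)])
        out.insert(rng.randint(0, len(out)), synonym)
    return out


def random_swap(words: List[str], rng: random.Random, n: int = 1) -> List[str]:
    """n times: swap the words at two randomly chosen positions."""
    out = list(words)
    if len(out) < 2:
        return out
    for _ in range(n):
        i, j = rng.sample(range(len(out)), 2)
        out[i], out[j] = out[j], out[i]
    return out


def random_deletion(words: List[str], rng: random.Random, p: float = 0.1) -> List[str]:
    """Delete each word with probability p, always keeping at least one word."""
    if len(words) <= 1:
        return list(words)
    kept = [w for w in words if rng.random() > p]
    return kept if kept else [rng.choice(words)]


if __name__ == "__main__":
    rng = random.Random(42)
    sentence = "this movie is good but the ending is great".split()
    for op in (synonym_replacement, random_insertion, random_swap, random_deletion):
        print(f"{op.__name__}: {' '.join(op(sentence, rng))}")
```

Each function returns a new list and leaves its input untouched, so several augmented variants can be derived from the same original sentence.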

To turn a GLUE task into a sentence augmentation type prediction (SATP) task, run aug.py with a command like the following:

python aug.py \
  --task_name CoLA \
  --dataroot './glue_data/' \
  --aug_dataroot './aug_data/' 

or, to augment all nine GLUE tasks with two, three, and four augmentation types:

for i in 'SST-2' 'CoLA' 'MRPC' 'QNLI' 'RTE' 'STS-B' 'WNLI' 'QQP' 'MNLI'
do
    echo Augment $i with two types ...
    python aug.py --num_type 2 --task_name $i --dataroot './glue_data/' --aug_dataroot './aug_data/type-2/' 
    echo Augment $i with three types ...
    python aug.py --num_type 3 --task_name $i --dataroot './glue_data/' --aug_dataroot './aug_data/type-3/' 
    echo Augment $i with four types ...
    python aug.py --num_type 4 --task_name $i --dataroot './glue_data/' --aug_dataroot './aug_data/type-4/' 
done
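In SATP, the auxiliary task is to predict which augmentation was applied to a sentence, so --num_type presumably sets the number of SATP classes, and main.py's --auxiliary_labels should match it (4 for the type-4 data in the fine-tuning command). How aug.py assigns operations to labels is not documented here; the sketch below assumes every sentence is augmented once by each of the first num_type operations and labeled with that operation's index. build_satp_examples and the two toy operations are hypothetical stand-ins, not the repository's code.

```python
import random
from typing import Callable, List, Sequence, Tuple

# An augmentation operation maps a tokenized sentence to a perturbed copy.
Op = Callable[[List[str], random.Random], List[str]]


def build_satp_examples(sentences: Sequence[str], ops: Sequence[Op], num_type: int,
                        seed: int = 0) -> List[Tuple[str, int]]:
    """Return (augmented sentence, augmentation-type label) pairs.

    Hypothetical scheme: each sentence is augmented once by each of the
    first `num_type` operations; the label is that operation's index.
    """
    if not 1 <= num_type <= len(ops):
        raise ValueError(f"num_type must be in [1, {len(ops)}], got {num_type}")
    rng = random.Random(seed)
    examples = []
    for sentence in sentences:
        words = sentence.split()
        for label, op in enumerate(ops[:num_type]):
            examples.append((" ".join(op(words, rng)), label))
    return examples


# Toy stand-ins for EDA operations, for demonstration only.
def swap_first_two(words: List[str], rng: random.Random) -> List[str]:
    return words[1:2] + words[:1] + words[2:]


def drop_random_word(words: List[str], rng: random.Random) -> List[str]:
    if len(words) <= 1:
        return list(words)
    i = rng.randrange(len(words))
    return words[:i] + words[i + 1:]


if __name__ == "__main__":
    data = build_satp_examples(["the film was great", "a dull plot"],
                               [swap_first_two, drop_random_word], num_type=2)
    for text, label in data:
        print(label, text)
```

Under this scheme the SATP label depends only on the augmentation, never on the GLUE label, which is what lets the auxiliary task be trained without extra annotation.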

To fine-tune on a GLUE task with SSL-Reg (SATP), run main.py with a command like the following:

python main.py \
  --lr 3e-5 \
  --epochs 6 \
  --auxiliary_weight 0.4 \
  --max_seq_length 128 \
  --batch_size 8 \
  --do_eval_ssl_task \
  --gradient_accumulation_steps 4 \
  --dataroot ./glue_data/CoLA/ \
  --aug_dataroot ./glue_data/CoLA/aug_data/type-4/ \
  --auxiliary_labels 4 \
  --task_name cola --print_freq 10 \
  --force-overwrite \
  --outf results/results_CoLA/bert1
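In the command above, --auxiliary_weight presumably scales the SATP loss against the GLUE task loss, and --gradient_accumulation_steps accumulates gradients over several mini-batches before each optimizer step, which under the usual semantics gives an effective batch of 8 × 4 = 32 examples per update. The pure-Python sketch below illustrates that interplay with a one-parameter linear model standing in for BERT; the model, the squared-error losses, and the names grad_joint and train are hypothetical, not taken from main.py.

```python
from typing import List, Tuple

# Each example: (input x, task target y, auxiliary target z).
Batch = List[Tuple[float, float, float]]


def grad_joint(w: float, batch: Batch, aux_weight: float) -> float:
    """Gradient in w of mean[(w*x - y)^2 + aux_weight * (w*x - z)^2] over the batch."""
    total = 0.0
    for x, y, z in batch:
        total += 2 * x * (w * x - y) + aux_weight * 2 * x * (w * x - z)
    return total / len(batch)


def train(batches: List[Batch], lr: float, aux_weight: float,
          accumulation_steps: int, w: float = 0.0) -> float:
    """SGD with gradient accumulation: one update every `accumulation_steps` batches.

    A trailing group smaller than `accumulation_steps` is discarded without an update.
    """
    accumulated = 0.0
    for step, batch in enumerate(batches, start=1):
        # Dividing by accumulation_steps makes the update an average over the group.
        accumulated += grad_joint(w, batch, aux_weight) / accumulation_steps
        if step % accumulation_steps == 0:
            w -= lr * accumulated
            accumulated = 0.0
    return w


if __name__ == "__main__":
    micro_batches = [
        [(1.0, 2.0, 0.0), (2.0, 3.0, 1.0)],
        [(0.5, 1.0, 1.0), (1.5, 2.5, 0.0)],
        [(1.0, 1.0, 2.0), (3.0, 4.0, 1.0)],
        [(2.0, 2.0, 0.0), (0.5, 0.5, 1.0)],
    ]
    merged = [example for batch in micro_batches for example in batch]
    w_accumulated = train(micro_batches, lr=0.01, aux_weight=0.4, accumulation_steps=4)
    w_single = train([merged], lr=0.01, aux_weight=0.4, accumulation_steps=1)
    print(w_accumulated, w_single)  # equal up to floating-point rounding
```

Because the micro-batches here have equal size, four accumulated steps reproduce a single step on the merged batch, which is why accumulation is a common way to reach a larger effective batch on limited GPU memory.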
