BERT pruning


This project is based on BERT; see the original BERT project for more information.

Currently, the bash file "" runs the BERT pruning. It takes two arguments: the mode and the GPU device number. For example, the command to run prune mode on GPU 0 is "bash prune 0". Note that the bash file uses the pretrained BERT model in the parent folder (../Model/), laid out as follows:

/BERT: the folder of this project
      | The main file for pruning.
   |           # run the pruning.
   |     # run the pruning on SQuAD.
   | # run the multi-stage pruning.
      |--/uncased_L-12_H-768_A_12: BERT-Base uncased model pretrained by Google.
      |--/MNLI_pretrained_model: uncased_L-12_H-768_A_12 used as the pretrained model and fine-tuned on MNLI for a few epochs.


Create a virtual environment

conda env create -f environment.yml
conda activate pruning

Download the dataset

python ./other_function/
bash ./

Download the model

mkdir -p Model
wget -P ./Model
unzip ./Model/  -d ./Model/

Pretrain the model

# For GLUE 
bash ./ pretrain 0
# For SQuAD
bash ./ pretrain 0

Prune the model. Set pruning_type to choose the pruning method.

type 0: prune whole heads. For example, prune from 12 heads to 6 heads.

type 1: prune the size of each head. For example, prune from a size of 64 per head to 32 per head.

type 2: granularity pruning on dim K. The granularity must divide the number of hidden units (768 here), e.g. 1, 128, 256, or 768.
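As a rough illustration of what granularity pruning on dim K could look like, the sketch below groups consecutive rows of a K x N weight matrix and zeroes the lowest-scoring groups. The function name `prune_rows_by_granularity`, the `sparsity` parameter, and the sum-of-squares score are assumptions made for this example; the project's scripts may rank and remove groups differently.

```python
def prune_rows_by_granularity(weight, granularity, sparsity):
    """Zero out the lowest-scoring groups of `granularity` consecutive rows.

    weight: list of K rows, each a list of N floats.
    The scoring rule and the `sparsity` fraction are assumptions of this sketch.
    """
    k = len(weight)
    if k % granularity != 0:
        raise ValueError("granularity must divide the K dimension")
    num_groups = k // granularity

    # Score each group by the sum of squared weights in its rows.
    scores = []
    for g in range(num_groups):
        rows = weight[g * granularity:(g + 1) * granularity]
        scores.append(sum(w * w for row in rows for w in row))

    # Select the groups with the smallest scores for pruning.
    num_pruned = int(num_groups * sparsity)
    order = sorted(range(num_groups), key=lambda g: scores[g])
    pruned = set(order[:num_pruned])

    return [
        [0.0] * len(row) if (i // granularity) in pruned else list(row)
        for i, row in enumerate(weight)
    ]


# Toy example: K = 4 rows, granularity 2, prune half of the groups.
toy = [[0.1, 0.1], [0.2, 0.1], [1.0, 2.0], [3.0, 1.0]]
pruned_toy = prune_rows_by_granularity(toy, granularity=2, sparsity=0.5)
```

The divisibility check mirrors the constraint above: with 768 hidden units, granularities such as 1, 128, 256, and 768 split the K dimension into whole groups.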

type 3: granularity pruning on [Q K V] together. The granularity is 3x768.

type 4: VW. Prune along the N dimension. The block size is 1x16, and the top 4 values in each block are chosen.
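A minimal sketch of the 1x16, top-4 pattern, assuming that "top 4" means the four largest-magnitude values in each block are kept and the rest are zeroed. The helper `vector_wise_prune` is hypothetical and works on a single row along the N dimension; it is not taken from the project's code.

```python
def vector_wise_prune(row, block=16, keep=4):
    """Keep the `keep` largest-magnitude entries in each `block`-wide chunk.

    Assumption: "top" is judged by absolute value; other entries become 0.0.
    """
    if len(row) % block != 0:
        raise ValueError("row length must be a multiple of the block size")
    out = []
    for start in range(0, len(row), block):
        chunk = row[start:start + block]
        # Indices of the `keep` largest absolute values within this chunk.
        order = sorted(range(block), key=lambda i: abs(chunk[i]), reverse=True)
        top = set(order[:keep])
        out.extend(v if i in top else 0.0 for i, v in enumerate(chunk))
    return out


# Toy row of 32 values, i.e. two 1x16 blocks.
row = [float(i) for i in range(16)] + [-float(i) for i in range(16)]
pruned_row = vector_wise_prune(row)
nonzero_per_block = [
    sum(1 for v in pruned_row[s:s + 16] if v != 0.0) for s in (0, 16)
]
```

Under this assumption every block ends up with exactly 4 nonzero values out of 16, which corresponds to a fixed 75% sparsity.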

type 5: BW. Prune along the K and N dimensions. The block size is 32x32.

type 7: TW pruning on dim K, considering the whole network rather than layer by layer.

type 8: TW pruning on dim K and N, considering the whole network rather than layer by layer.
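The difference between pruning "layer by layer" and considering "the whole network" can be sketched as follows. The per-unit importance scores, the `sparsity` fraction, and both function names are assumptions for illustration only; the source does not state how the project scores units.

```python
def layerwise_prune(scores_per_layer, sparsity):
    """Prune the same fraction of units in every layer."""
    pruned = []
    for scores in scores_per_layer:
        n = int(len(scores) * sparsity)
        order = sorted(range(len(scores)), key=lambda i: scores[i])
        pruned.append(sorted(order[:n]))
    return pruned


def global_prune(scores_per_layer, sparsity):
    """Pool all units, then prune the lowest-scoring fraction overall."""
    pooled = [
        (score, layer, idx)
        for layer, scores in enumerate(scores_per_layer)
        for idx, score in enumerate(scores)
    ]
    n = int(len(pooled) * sparsity)
    pruned = [[] for _ in scores_per_layer]
    for _, layer, idx in sorted(pooled)[:n]:
        pruned[layer].append(idx)
    return [sorted(p) for p in pruned]


# Two toy layers with four units each; higher score = more important.
scores = [[0.9, 0.8, 0.7, 0.6], [0.1, 0.2, 0.3, 0.95]]
layerwise_sel = layerwise_prune(scores, 0.5)
global_sel = global_prune(scores, 0.5)
```

In this toy case the layer-wise variant removes two units from each layer, while the global variant removes one unit from the first layer and three from the second, so per-layer sparsity can differ when the whole network is ranked at once.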

# For GLUE 
bash ./ prune 0
# For SQuAD
bash ./ prune 0

The name of the output file starts with the task name.



