Skip to content

Latest commit

 

History

History
25 lines (19 loc) · 956 Bytes

模型蒸馏.md

File metadata and controls

25 lines (19 loc) · 956 Bytes

模型蒸馏

这里使用的是通过大模型的前项输出作为soft label和生成的label同时监督训练,这里目前只对DBnet做了尝试,其余算法应该类似, 其中soft label监督只对binary和thresh_binary,忽略了thresh,测试下来这种方式效果最好。

如何训练(对压缩后模型进行蒸馏)

python3 tools/det_train.py 
--config ./config/det_DB_mobilev3.yaml 
--log_str total_prune_20201015_distil3 
--pruned_model_dict_path ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200_mobile_slim_all/pruned/pruned_dict.dict 
--prune_model_path ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200_mobile_slim_all/pruned/pruned_dict.pth
--prune_type total 
--n_epoch 200 
--start_val 30 
--base_lr 0.0008 
--gpu_id 2 
--t_ratio 0.1 
--t_model_path ./checkpoint/ag_DB_bb_resnet50_he_DB_Head_bs_8_ep_1201/DB_best.pth.tar 
--t_config ./config/det_DB_resnet50_3_3.yaml