Skip to content

keras复现人群数量估计网络"CNN-based Cascaded Multi-task Learning of High-level Prior and Density Estimation for Crowd Counting",欢迎试用、关注并反馈问题...

License

Notifications You must be signed in to change notification settings

embracesource-cv-com/keras-crowdcounting-cmtl

Repository files navigation

keras-crowdcounting-cmtl

keras复现人群数量估计网络"CNN-based Cascaded Multi-task Learning of High-level Prior and Density Estimation for Crowd Counting"。 本工程的实现主要参考crowdcount-cascaded-mtlkeras-mcnn 在ShanghaiTech数据集上训练和测试效果如下:

|        |  MAE   |  MSE   |
----------------------------
| Part_A |  115.57 |  179.82 |
----------------------------
| Part_B |  26.30  |  48.78  |

安装

  1. Clone

    git clone https://github.com/embracesource-cv-com/keras-crowdcounting-cmtl
  2. 安装依赖库

    cd keras-crowdcounting-cmtl
    pip install -r requirements.txt

数据配置

  1. 下载ShanghaiTech数据集:
    Dropbox or 百度云盘

  2. 创建数据存放目录$ORIGIN_DATA_PATH

    mkdir /opt/dataset/crowd_counting/shanghaitech/original
  3. part_A_finalpart_B_final存放到$ORIGIN_DATA_PATH目录下

  4. 生成测试集的ground truth文件

    python create_gt_test_set_shtech.py [A or B]  # Part_A or Part_B

    生成好的ground-truth文件将会保存在$TEST_GT_PATH/test_data/ground_truth_csv目录下

  5. 生成训练集和验证集

    python create_training_set_shtech.py [A or B]

    生成好的数据保存将会在$TRAIN_PATH、$TRAIN_GT_PATH、$VAL_PATH、$VAL_GT_PATH目录下

测试

a)下载训练模型

cmtl-A.235.h5 提取码:prxi、cmtl-B.210.h5 提取码:7if7

b) 如下命令分别测试A和B

python test.py --dataset A --weight_path /tmp/cmtl-A.235.h5 --output_dir /tmp/ctml_A
python test.py --dataset B --weight_path /tmp/cmtl-B.210.h5 --output_dir /tmp/ctml_B

训练

如果你想自己训练模型,很简单:

python train.py [A or B]

About

keras复现人群数量估计网络"CNN-based Cascaded Multi-task Learning of High-level Prior and Density Estimation for Crowd Counting",欢迎试用、关注并反馈问题...

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages