Skip to content

zh3389/text_class

Repository files navigation

环境搭建

如果本项目对你学习构建一个 文本分类模型 + 部署模型 有帮助, 欢迎 Start...

在当前环境安装本项目使用的环境

pip install -r ./requirements.txt

我尝试将项目克隆下来之后安装了requirements里的包,发现始终缺少依赖.

所以我将重要的包版本罗列出来, 建议使用conda安装以下列表的依赖和对应的版本...

jieba==0.39
numpy==1.17.3
requests==2.22.0
keras==2.3.0
pandas==0.25.1
tqdm==4.31.0
tensorflow==1.14.0

快速开始测试

下载wiki.zh.vec至项目文件夹下 ./data/ 下载地址

找到或者直接点击Chinese: bin+text, text下载

python train.py  # 运行train.py文件进行训练demo数据

训练自定义数据集

1. 准备你的数据集csv格式 由 , 分隔如下:

一列为class用于存储每个类别的标签, 一列为data用于存储每条文本数据

data_example

class data
phone 苹果
phone 华为
phone 小米
phone 传音
bank 中国建设 银行
bank 中国 银行
bank 中国工商银行
bank 中国农业银行
country 中国
country 美国
country 俄罗斯
country 加拿大

2. 修改config.py文件

  1. train_data_path 为自定义数据的文件路径,也可覆盖demo数据.默认为: "./data/train_data.csv"
  2. embedded_matrix_size 为嵌入矩阵大小, 根据词频保留的词数,用于构建嵌入矩阵.默认为: 10240
  3. validation_ratio 为划分测试数据集占总数据集比例. 默认为: 0.2
  4. epochs 为整个数据集迭代次数. 默认为: 512
  5. batch_size 为优化模型每个批次的数据条数. 默认为: 2 注意:当前2为特殊情况(因为测试数据集较小)一定记得修改
  6. learning_rate 为优化模型的学习速率. 默认为: 0.01
  7. learning_rate_decay 为学习速率每个epochs进行衰减的比率. 默认为: 0.95

3. 运行 train.py 文件对数据进行训练

  1. 运行过程中会在./save_model/save/下生成model.h5模型文件,运行结束会生成final_model.h5
  2. 运行过程中会在./save_model/logs/下生成并不断更新一个日志文件,在项目根目录执行 tensorboard --logdir=save_model/logs即可监控模型训练过程
  3. 运行成功后会在./save_model/deploy/下生成可用于服务器部署的 pb 格式文件:
.
└── 0
    ├── saved_model.pb
    └── variables
        ├── variables.data-00000-of-00001
        └── variables.index

4. 部署成功后使用 client.py 进行模型的使用

记得修改class_dict = {0: "phone", 1: "bank", 2: "country"}模型输出对应的值,即可得到对应的类别名称

About

通用文本分类 -> keras+tensorflow -> 部署文件pb tf-serving docker-serving -> 客户端 client.py

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published