基于GPT2-AI人工智能文章标题自动生成项目付费

文案狗 4096 2021-10-06 GPT2

资源介绍

  项目描述

  本项目是一个带有超级详细中文注释的基于GPT2模型的新闻标题生成项目。

  本项目参考了多个GPT2开源项目,并根据自己的理解,将代码进行重构,添加详细注释,希望可以帮助到有需要的人。

  本项目使用HuggingFace的transformers实现GPT2模型代码编写、训练及测试。

  本项目通过Flask框架搭建了一个Web服务,将新闻摘要生成模型进行工程化,可以通过页面可视化地体验新闻标题生成效果。

  本项目的代码详细讲解,可以自行阅读代码,也可查看代码注释介绍。

  本项目提供的新闻标题模型是一个6层的小模型(其实是穷人没人卡,只能训练小模型),并且在训练该模型过程中,没有加载预训练的GPT2模型而是随机初始化的参数,并且训练轮数较少(5轮,还没收敛完),因此效果一般。如果想要更好效果的模型,可以按照个人需求训练一个模型。

  本项目的目的是带领大家走一遍GPT2生成模型的训练、测试及部署全部流程。

  文件结构

  •   config

    •   config.json 模型的配置信息,包含n_ctx、n_embd、n_head、n_layer等。

  •   vocab

    •   vocab.txt 字典文件,该字典为大小为13317,删除了将原始字典中的“##中文”,并且增加了“[Content]”、“[Title]”、“[Space]”等标记。

  •   data_dir 存放数据的文件夹

  •   templates 存放html页面的文件夹

  •   data_helper.py 数据预处理文件,将数据进行简单的清洗

  •   data_set.py 数据类文件,定义模型所需的数据类,方便模型训练使用

  •   model.py GPT2模型文件,主要对transformers包中GPT2LMHeadModel的重写,修改计算loss部分,只计算预测title部分的loss

  •   train.py 通过新闻正文生成新闻标题的GPT2模型的训练文件

  •   generate_title.py 根据训练好的模型,进行新闻标题生成,预测文件

  •   http_server.py 构建web服务文件

  运行环境

  gevent == 1.3a1

  flask == 0.12.2

  transformers == 3.0.2

  详细见requirements.txt文件

  模型训练

  python3 train.py
  或
  python3 train.py --output_dir output_dir/(自定义保存模型路径)

  训练参数可自行添加,包含参数具体如下:

参数 类型 默认值 描述
device str "0" 设置训练或测试时使用的显卡
config_path str "config/config.json" 模型参数配置信息
vocab_path str "vocab/vocab.txt" 词表,该词表为小词表,并增加了一些新的标记
train_file_path str "data_dir/train_data.json" 新闻标题生成的训练数据
test_file_path str "data_dir/test_data.json" 新闻标题生成的测试数据
pretrained_model_path str None 预训练的GPT2模型的路径
data_dir str "data_dir" 生成缓存数据的存放路径
num_train_epochs int 5 模型训练的轮数
train_batch_size int 16 训练时每个batch的大小
test_batch_size int 8 测试时每个batch的大小
learning_rate float 1e-4 模型训练时的学习率
warmup_proportion float 0.1 warm up概率,即训练总步长的百分之多少,进行warm up操作
adam_epsilon float 1e-8 Adam优化器的epsilon值
logging_steps int 20 保存训练日志的步数
eval_steps int 4000 训练时,多少步进行一次测试
gradient_accumulation_steps int 1 梯度积累
max_grad_norm float 1.0
output_dir str "output_dir/" 模型输出路径
seed int 2020 随机种子
max_len int 512 输入模型的最大长度,要比config中n_ctx小

  或者修改train.py文件中的set_args函数内容,可修改默认值。

  模型测试

  python3 generate_title.py
  或
  python3 generate_title.py --top_k 3 --top_p 0.9999 --generate_max_len 32

  参数可自行添加,包含参数具体如下:

参数 类型 默认值 描述
device str "0" 设置训练或测试时使用的显卡
model_path str "output_dir/checkpoint-139805" 模型文件路径
vocab_path str "vocab/vocab.txt" 词表,该词表为小词表,并增加了一些新的标记
batch_size int 3 生成标题的个数
generate_max_len int 32 生成标题的最大长度
repetition_penalty float 1.2 重复处罚率
top_k int 5 解码时保留概率最高的多少个标记
top_p float 0.95 解码时保留概率累加大于多少的标记
max_len int 512 输入模型的最大长度,要比config中n_ctx小

  测试结果如下:

  从测试集中抽一篇

  content:

  今日,中国三条重要高铁干线——兰新高铁、贵广铁路和南广铁路将开通运营。其中兰新高铁是中国首条高原高铁,全长1776公里,最高票价658元。贵广铁路最贵车票320元,南广铁路最贵车票206.5元,这两条线路大大缩短西南与各地的时空距离。出行更方便了!中国“高铁版图”再扩容
三条重要高铁今日开通

  title:

  生成的第1个标题为:中国“高铁版图”再扩容 三条重要高铁今日开通

  生成的第2个标题为:贵广铁路最高铁版图

  生成的第3个标题为:出行更方便了!中国“高铁版图”再扩容三条重要高铁今日开通

  从网上随便找一篇新闻

  content:

  值岁末,一年一度的中央经济工作会议牵动全球目光。今年的会议,背景特殊、节点关键、意义重大。12月16日至18日。北京,京西宾馆。站在“两个一百年”奋斗目标的历史交汇点上,2020年中央经济工作会议谋划着中国经济发展大计。习近平总书记在会上发表了重要讲话,深刻分析国内外经济形势,提出2021年经济工作总体要求和政策取向,部署重点任务,为开局“十四五”、开启全面建设社会主义现代化国家新征程定向领航。

  title:

  生成的第1个标题为:习近平总书记在京会上发表重大计划 提出2025年经济工作总体要求和政策

  生成的第2个标题为:习近平总书记在会上发表重要讲话

  生成的第3个标题为:习近平总书记在会上发表重要讲话,深刻分析国内外经济形势

  解码采用top_k和top_p解码策略,有一定的随机性,可重复生成。

  启动Flask服务

  python3 http_server.py
  或
  python3 http_server.py --http_id "0.0.0.0" --port 5555

  本地测试直接使用"127.0.0.1:5555/news-title-generate",如果给他人访问,只需将"127.0.0.1"替换成的电脑的IP地址即可。

END
相关标签

发表评论

用户评论(2)

  • 利婷 2021年11月4日 下午11:37

    这个软件怎么用我试了一下没有我喜欢的文状

    • guangh 回复 利婷 2021年11月5日 上午9:16

      你好,你所评论的项目是文章标题自动生成,不是自动生成文章的项目,如果需要文章生成的项目,可以看下http://www.guangh.cn/164.html,http://www.guangh.cn/161.html