From f97f2bbc04139b9e0e222215018a4241ed78c969 Mon Sep 17 00:00:00 2001 From: Zhiling Zhang <1840962220@qq.com> Date: Thu, 8 Oct 2020 16:54:48 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=85=B3=E9=94=AE=E8=AF=8D?= =?UTF-8?q?=E6=8A=BD=E5=8F=96=E5=8A=9F=E8=83=BD=EF=BC=8C=E5=B9=B6=E6=8F=90?= =?UTF-8?q?=E4=BE=9Bbenchmark=20https://github.com/blmoistawinde/HarvestTe?= =?UTF-8?q?xt/issues/23=20=E5=BC=95=E5=85=A5=E5=8F=AF=E4=B8=8B=E8=BD=BD?= =?UTF-8?q?=E7=9A=84=E5=A4=96=E9=83=A8=E8=AF=8D=E5=85=B8=EF=BC=8C=E8=BE=85?= =?UTF-8?q?=E5=8A=A9=E6=96=B0=E8=AF=8D=E5=8F=91=E7=8E=B0=E6=8E=92=E9=99=A4?= =?UTF-8?q?=E6=97=A7=E8=AF=8D=20https://github.com/blmoistawinde/HarvestTe?= =?UTF-8?q?xt/issues/24?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 58 ++- examples/basics.py | 100 ++++- examples/kwd_benchmark/CSL.ipynb | 523 ++++++++++++++++++++++ harvesttext/algorithms/keyword.py | 36 ++ harvesttext/algorithms/word_discoverer.py | 4 +- harvesttext/download_utils.py | 135 ++++++ harvesttext/harvesttext.py | 3 +- harvesttext/parsing.py | 2 +- harvesttext/resources.py | 30 +- harvesttext/word_discover.py | 71 ++- 10 files changed, 928 insertions(+), 34 deletions(-) create mode 100644 examples/kwd_benchmark/CSL.ipynb create mode 100644 harvesttext/algorithms/keyword.py create mode 100644 harvesttext/download_utils.py diff --git a/README.md b/README.md index 7602b6d..2ed63df 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Sow with little data seed, harvest much from a text field. ![GitHub stars](https://img.shields.io/github/stars/blmoistawinde/harvesttext?style=social) ![PyPI - Python Version](https://img.shields.io/badge/python-3.6+-blue.svg) ![GitHub](https://img.shields.io/github/license/mashape/apistatus.svg) -![Version](https://img.shields.io/badge/version-V0.7-red.svg) +![Version](https://img.shields.io/badge/version-V0.8-red.svg) ## 用途 HarvestText是一个专注无(弱)监督方法,能够整合领域知识(如类型,别名)对特定领域文本进行简单高效地处理和分析的库。适用于许多文本预处理和初步探索性分析任务,在小说分析,网络文本,专业文献等领域都有潜在应用价值。 @@ -478,6 +478,37 @@ Text summarization(避免重复) 武磊和郜林,谁是中国最好的前锋? ``` + + +### 关键词抽取 + +目前提供包括`textrank`和HarvestText封装jieba并配置好参数和停用词的`jieba_tfidf`(默认)两种算法。 + +示例(完整见[example](./examples/basics.py)): + +```python3 +# text为林俊杰《关键词》歌词 +print("《关键词》里的关键词") +kwds = ht.extract_keywords(text, 5, method="jieba_tfidf") +print("jieba_tfidf", kwds) +kwds = ht.extract_keywords(text, 5, method="textrank") +print("textrank", kwds) +``` + +``` +《关键词》里的关键词 +jieba_tfidf ['自私', '慷慨', '落叶', '消逝', '故事'] +textrank ['自私', '落叶', '慷慨', '故事', '位置'] +``` + +[CSL.ipynb](./examples/kwd_benchmark/CSL.ipynb)提供了不同算法,以及本库的实现与[textrank4zh](https://github.com/letiantian/TextRank4ZH)的在[CSL数据集](https://github.com/CLUEbenchmark/CLUE#6-csl-%E8%AE%BA%E6%96%87%E5%85%B3%E9%94%AE%E8%AF%8D%E8%AF%86%E5%88%AB-keyword-recognition)上的比较。由于仅有一个数据集且数据集对于以上算法都很不友好,表现仅供参考。 + +| 算法 | P@5 | R@5 | F@5 | +| --- | --- | --- | --- | +| textrank4zh | 0.0836 | 0.1174 | 0.0977 | +| ht_textrank | 0.0955 | 0.1342 | 0.1116 | +| ht_jieba_tfidf | **0.1035** | **0.1453** | **0.1209** | + @@ -486,9 +517,11 @@ Text summarization(避免重复) 现在本库内集成了一些资源,方便使用和建立demo。 资源包括: -- 褒贬义词典 清华大学 李军 整理自http://nlp.csai.tsinghua.edu.cn/site2/index.php/13-sms -- 百度停用词词典 来自网络:https://wenku.baidu.com/view/98c46383e53a580216fcfed9.html -- 领域词典 来自清华THUNLP: http://thuocl.thunlp.org/ 全部类型`['IT', '动物', '医药', '历史人名', '地名', '成语', '法律', '财经', '食物']` +- `get_qh_sent_dict`: 褒贬义词典 清华大学 李军 整理自http://nlp.csai.tsinghua.edu.cn/site2/index.php/13-sms +- `get_baidu_stopwords`: 百度停用词词典 来自网络:https://wenku.baidu.com/view/98c46383e53a580216fcfed9.html +- `get_qh_typed_words`: 领域词典 来自清华THUNLP: http://thuocl.thunlp.org/ 全部类型`['IT', '动物', '医药', '历史人名', '地名', '成语', '法律', '财经', '食物']` +- `get_english_senti_lexicon`: 英语情感词典 +- `get_jieba_dict`: (需要下载)jieba词频词典 此外,还提供了一个特殊资源——《三国演义》,包括: @@ -590,6 +623,21 @@ min_aggregation = np.sqrt(length) / 15
+
使用结巴词典过滤旧词(展开查看) +``` +from harvesttext.resources import get_jieba_dict +jieba_dict = get_jieba_dict(min_freq=100) +print("jiaba词典中的词频>100的词语数:", len(jieba_dict)) +text = "1979-1998-2020的喜宝们 我现在记忆不太好,大概是拍戏时摔坏了~有什么笔记都要当下写下来。前几天翻看,找着了当时记下的话.我觉得喜宝既不娱乐也不启示,但这就是生活就是人生,10/16来看喜宝吧" +new_words_info = ht.word_discover(text, + excluding_words=set(jieba_dict), # 排除词典已有词语 + exclude_number=True) # 排除数字(默认True) +new_words = new_words_info.index.tolist() +print(new_words) # ['喜宝'] +``` +
+
+ [根据反馈更新](https://github.com/blmoistawinde/HarvestText/issues/13#issue-551894838) 原本默认接受一个单独的字符串,现在也可以接受字符串列表输入,会自动进行拼接 [根据反馈更新](https://github.com/blmoistawinde/HarvestText/issues/14#issuecomment-576081430) 现在默认按照词频降序排序,也可以传入`sort_by='score'`参数,按照综合质量评分排序。 @@ -802,3 +850,5 @@ we imagine what we'll find, in another life. [EventTriplesExtraction](https://github.com/liuhuanyong/EventTriplesExtraction) +[textrank4ZH](https://github.com/letiantian/TextRank4ZH) + diff --git a/examples/basics.py b/examples/basics.py index ab538e0..cec6774 100644 --- a/examples/basics.py +++ b/examples/basics.py @@ -1,6 +1,7 @@ #coding=utf-8 import re from harvesttext import HarvestText + ht = HarvestText() def new_word_discover(): @@ -398,29 +399,80 @@ def test_english(): # for sent0 in sentences: # print(sent0, ht_eng.analyse_sent(sent0)) - +def jieba_dict_new_word(): + from harvesttext.resources import get_jieba_dict + jieba_dict = get_jieba_dict(min_freq=100) + print("jiaba词典中的词频>100的词语数:", len(jieba_dict)) + text = "1979-1998-2020的喜宝们 我现在记忆不太好,大概是拍戏时摔坏了~有什么笔记都要当下写下来。前几天翻看,找着了当时记下的话.我觉得喜宝既不娱乐也不启示,但这就是生活就是人生,10/16来看喜宝吧" + new_words_info = ht.word_discover(text, + excluding_words=set(jieba_dict), # 排除词典已有词语 + exclude_number=True) # 排除数字(默认True) + new_words = new_words_info.index.tolist() + print(new_words) # ['喜宝'] + +def extract_keywords(): + text = """ +好好爱自己 就有人会爱你 +这乐观的说词 +幸福的样子 我感觉好真实 +找不到形容词 +沉默在掩饰 快泛滥的激情 +只剩下语助词 +有一种踏实 当你口中喊我名字 +落叶的位置 谱出一首诗 +时间在消逝 我们的故事开始 +这是第一次 +让我见识爱情 可以慷慨又自私 +你是我的关键词 +我不太确定 爱最好的方式 +是动词或名词 +很想告诉你 最赤裸的感情 +却又忘词 +聚散总有时 而哭笑也有时 +我不怕潜台词 +有一种踏实 是你心中有我名字 +落叶的位置 谱出一首诗 +时间在消逝 我们的故事开始 +这是第一次 +让我见识爱情 可以慷慨又自私 +你是我的关键词 +你藏在歌词 代表的意思 +是专有名词 +落叶的位置 谱出一首诗 +我们的故事 才正要开始 +这是第一次 +爱一个人爱得 如此慷慨又自私 +你是我的关键 + """ + print("《关键词》里的关键词") + kwds = ht.extract_keywords(text, 5, method="jieba_tfidf") + print("jieba_tfidf", kwds) + kwds = ht.extract_keywords(text, 5, method="textrank") + print("textrank", kwds) if __name__ == "__main__": - test_english() - new_word_discover() - new_word_register() - entity_segmentation() - sentiment_dict() - sentiment_dict_default() - entity_search() - text_summarization() - entity_network() - save_load_clear() - load_resources() - linking_strategy() - find_with_rules() - load_resources() - using_typed_words() - build_word_ego_graph() - entity_error_check() - depend_parse() - named_entity_recognition() - el_keep_all() - filter_el_with_rule() - clean_text() - cut_paragraph() + # test_english() + # new_word_discover() + # new_word_register() + # entity_segmentation() + # sentiment_dict() + # sentiment_dict_default() + # entity_search() + # text_summarization() + # entity_network() + # save_load_clear() + # load_resources() + # linking_strategy() + # find_with_rules() + # load_resources() + # using_typed_words() + # build_word_ego_graph() + # entity_error_check() + # depend_parse() + # named_entity_recognition() + # el_keep_all() + # filter_el_with_rule() + # clean_text() + # cut_paragraph() + # jieba_dict_new_word() + extract_keywords() diff --git a/examples/kwd_benchmark/CSL.ipynb b/examples/kwd_benchmark/CSL.ipynb new file mode 100644 index 0000000..9c0bb0a --- /dev/null +++ b/examples/kwd_benchmark/CSL.ipynb @@ -0,0 +1,523 @@ +{ + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": 3 + }, + "orig_nbformat": 2, + "kernelspec": { + "name": "python_defaultSpec_1602139106579", + "display_name": "Python 3.6.9 64-bit ('py36': conda)" + } + }, + "nbformat": 4, + "nbformat_minor": 2, + "cells": [ + { + "source": [ + "# HarvestText中的关键词算法benchmark" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import pandas as pd\n", + "import numpy as np\n", + "import networkx as nx\n", + "from tqdm import tqdm\n", + "import jieba\n", + "from collections import defaultdict\n", + "from harvesttext import HarvestText" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "ht = HarvestText()" + ] + }, + { + "source": [ + "首先,选取的数据集是CLUE整理的CSL关键词预测数据集([下载地址](https://github.com/CLUEbenchmark/CLUE#6-csl-%E8%AE%BA%E6%96%87%E5%85%B3%E9%94%AE%E8%AF%8D%E8%AF%86%E5%88%AB-keyword-recognition))。需要先下载并放到本目录的`CSL关键词预测`文件夹下\n", + "\n", + "在上面先在开发集上做一些基本的分析及调参。" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "3000" + }, + "metadata": {}, + "execution_count": 9 + } + ], + "source": [ + "data_dev = []\n", + "with open('CSL关键词预测/dev.json', encoding='utf-8') as f:\n", + " for line in f:\n", + " tmp = json.loads(line)\n", + " data_dev.append((tmp['abst'], tmp['keyword']))\n", + "len(data_dev)" + ] + }, + { + "source": [ + "一些基础的数据探索性分析(EDA)\n", + "- 每个文档的关键词个数\n", + "- 关键词的长度分布\n", + "- 考察分词`seg`的情况和不分词`nseg`的情况,有多少比例的关键词被覆盖。这决定了依赖分词和不依赖分词的算法所能达到的理论recall上限。" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "all_keywords = 0\n", + "recalls = {'seg':0, 'nseg':0}\n", + "kwd_cnt = defaultdict(int)\n", + "kwd_len_cnt = defaultdict(int)\n", + "for abst, kwds in data_dev:\n", + " kwd_cnt[len(kwds)] += 1\n", + " words = set(jieba.lcut(abst))\n", + " all_keywords += len(kwds)\n", + " recalls['seg'] += len(set(kwds) & words)\n", + " recalls['nseg'] += sum(int(kwd in abst) for kwd in kwds)\n", + " for kwd in kwds:\n", + " kwd_len_cnt[len(kwd)] += 1\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "defaultdict(, {4: 1814, 3: 1128, 2: 58})\n" + } + ], + "source": [ + "print(kwd_cnt)" + ] + }, + { + "source": [ + "每篇文档的关键词数量在2-4之间" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "1 0.004277\n2 0.260134\n3 0.387970\n4 0.702864\n5 0.812756\n6 0.904239\n7 0.937151\n8 0.956489\n9 0.971551\n10 0.980104\n11 0.988100\n12 0.991633\n13 0.995258\n14 0.995816\n15 0.996281\n16 0.997583\n17 0.998791\n18 0.999256\n19 0.999442\n20 0.999907\n31 1.000000\ndtype: float64" + }, + "metadata": {}, + "execution_count": 14 + } + ], + "source": [ + "# 关键词长度的累积概率分布\n", + "pd.Series(kwd_len_cnt).sort_index().cumsum() / sum(kwd_len_cnt.values())" + ] + }, + { + "source": [ + "存在很长的关键词,以一个词而不是多词词组为单元的关键词算法无法处理这些情况,不过4个字以内也已经可以覆盖70%" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "{'seg': 0.3697471178876906, 'nseg': 0.7791000371885459}\n" + } + ], + "source": [ + "for k in recalls:\n", + " recalls[k] /= all_keywords\n", + "print(recalls)" + ] + }, + { + "source": [ + "上述情况说明,依赖jieba分词的算法在这个数据集上最多只能达到36.97%的recall,而其他从原文直接中抽取方法(新词发现,序列标注等)有可能达到77.91%。\n", + "\n", + "下面的算法,因此在数值上不会有很好的表现,不过依旧可以为比较和调参提供一些参考。" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "给出一个关键词抽取的示例,包括`textrank`和HarvestText封装jieba并配置好参数和停用词的`jieba_tfidf`。" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "随机噪声雷达通常利用时域相关完成脉冲压缩从而进行目标检测.该文根据压缩感知理论提出一种适用于噪声雷达目标检测的新算法,它用低维投影测量和信号重建取代了传统的相关操作和压缩处理,将大量运算转移到后期处理.该算法以噪声雷达所检测的目标空间分布满足稀疏性为前提;利用发射信号形成卷积矩阵,然后通过随机抽取卷积矩阵的行构建测量矩阵;并采用迭代收缩阈值算法实现目标信号重建.该文对算法作了详细的理论推导,形成完整的实现框架.仿真实验验证了算法的有效性,并分析了对处理结果影响较大的因素.该算法能够有效地重建目标,具有良好的运算效率.与时域相关法相比,大幅度减小了目标检测误差,有效抑制了输出旁瓣,并保持了信号的相位特性.\n真实关键词:['目标', '相关', '矩阵']\njieba_tfidf 关键词(前5):['算法', '矩阵', '检测', '目标', '信号']\ntextrank 关键词(前5):['算法', '信号', '目标', '压缩', '矩阵']\n" + } + ], + "source": [ + "text, kwds = data_dev[10]\n", + "print(text)\n", + "print(\"真实关键词:\", kwds)\n", + "print(\"jieba_tfidf 关键词(前5):\", ht.extract_keywords(text, 5, method=\"jieba_tfidf\"))\n", + "print(\"textrank 关键词(前5):\", ht.extract_keywords(text, 5, method=\"textrank\"))" + ] + }, + { + "source": [ + "每篇文章取前5个作为预测值,我们可以得到precision@5, recall@5, F1@5来评估算法的效果" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "all_keywords = 0\n", + "pred_keywords = 0\n", + "recall_new_word = 0" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": "100%|██████████| 3000/3000 [00:29<00:00, 100.76it/s]\njieba Precison:{prec:.4f}, Recall:{recall:.4f}, F1:{f1:.4f}\n" + } + ], + "source": [ + "topK = 5\n", + "ref_keywords, pred_keywords = 0, 0\n", + "acc_count = 0\n", + "for text, kwds in tqdm(data_dev):\n", + " ref_keywords += len(kwds)\n", + " pred_keywords += topK\n", + " preds = ht.extract_keywords(text, topK, method=\"jieba_tfidf\")\n", + " acc_count += len(set(kwds) & set(preds))\n", + "prec = acc_count / pred_keywords\n", + "recall = acc_count / ref_keywords\n", + "f1 = 2*prec*recall/(prec+recall)\n", + "print(f\"jieba Precison:{prec:.4f}, Recall:{recall:.4f}, F1:{f1:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "jieba Precison:0.1060, Recall:0.1478, F1:0.1235\n" + } + ], + "source": [ + "print(f\"jieba Precison:{prec:.4f}, Recall:{recall:.4f}, F1:{f1:.4f}\")" + ] + }, + { + "source": [ + "Textrank调参" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": "100%|██████████| 3000/3000 [00:45<00:00, 66.11it/s]\ntextrank[block: doc, window:2, weighted:False] Precison:0.0942, Recall:0.1314, F1:0.1097\n100%|██████████| 3000/3000 [00:46<00:00, 64.20it/s]\ntextrank[block: doc, window:2, weighted:True] Precison:0.0955, Recall:0.1332, F1:0.1113\n100%|██████████| 3000/3000 [00:41<00:00, 71.53it/s]\ntextrank[block: doc, window:3, weighted:False] Precison:0.0948, Recall:0.1322, F1:0.1104\n100%|██████████| 3000/3000 [00:41<00:00, 65.70it/s]\ntextrank[block: doc, window:3, weighted:True] Precison:0.0945, Recall:0.1318, F1:0.1101\n100%|██████████| 3000/3000 [00:41<00:00, 72.11it/s]\ntextrank[block: doc, window:4, weighted:False] Precison:0.0944, Recall:0.1316, F1:0.1100\n100%|██████████| 3000/3000 [00:41<00:00, 71.65it/s]\ntextrank[block: doc, window:4, weighted:True] Precison:0.0939, Recall:0.1309, F1:0.1093\n100%|██████████| 3000/3000 [00:45<00:00, 66.37it/s]\ntextrank[block: sent, window:2, weighted:False] Precison:0.0931, Recall:0.1299, F1:0.1085\n100%|██████████| 3000/3000 [00:45<00:00, 65.93it/s]\ntextrank[block: sent, window:2, weighted:True] Precison:0.0945, Recall:0.1318, F1:0.1101\n100%|██████████| 3000/3000 [00:41<00:00, 53.28it/s]\ntextrank[block: sent, window:3, weighted:False] Precison:0.0936, Recall:0.1305, F1:0.1090\n100%|██████████| 3000/3000 [00:40<00:00, 73.21it/s]\ntextrank[block: sent, window:3, weighted:True] Precison:0.0929, Recall:0.1295, F1:0.1082\n100%|██████████| 3000/3000 [00:40<00:00, 73.50it/s]\ntextrank[block: sent, window:4, weighted:False] Precison:0.0931, Recall:0.1298, F1:0.1084\n100%|██████████| 3000/3000 [00:41<00:00, 72.45it/s]\ntextrank[block: sent, window:4, weighted:True] Precison:0.0925, Recall:0.1290, F1:0.1077\n" + } + ], + "source": [ + "from itertools import product\n", + "\n", + "topK = 5\n", + "block_types = [\"doc\", \"sent\"]\n", + "window_sizes = [2, 3, 4]\n", + "if_weighted = [False, True]\n", + "for block_type, window, weighted in product(block_types, window_sizes, if_weighted):\n", + " ref_keywords, pred_keywords = 0, 0\n", + " acc_count = 0\n", + " for text, kwds in tqdm(data_dev):\n", + " ref_keywords += len(kwds)\n", + " pred_keywords += topK\n", + " preds = ht.extract_keywords(text, topK, method=\"textrank\", block_type=block_type, window=window, weighted=weighted)\n", + " acc_count += len(set(kwds) & set(preds))\n", + " prec = acc_count / pred_keywords\n", + " recall = acc_count / ref_keywords\n", + " f1 = 2*prec*recall/(prec+recall)\n", + " print(f\"textrank[block: {block_type}, window:{window}, weighted:{weighted}] Precison:{prec:.4f}, Recall:{recall:.4f}, F1:{f1:.4f}\")" + ] + }, + { + "source": [ + "textrank的最佳参数是 block: doc, window:2, weighted:True\n", + "\n", + "precision和recall与jieba_tfidf还是有差距,可能是因为后者拥有从大量语料库中统计得到的idf数据能起到一定帮助" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "## 测试集benchmark\n", + "\n", + "选取各个算法的最佳参数在测试集上获得最终表现" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "3000" + }, + "metadata": {}, + "execution_count": 22 + } + ], + "source": [ + "data_test = []\n", + "with open('CSL关键词预测/test.json', encoding='utf-8') as f:\n", + " for line in f:\n", + " tmp = json.loads(line)\n", + " data_test.append((tmp['abst'], tmp['keyword']))\n", + "len(data_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": "100%|██████████| 3000/3000 [00:30<00:00, 99.11it/s]\njieba Precison:0.1035, Recall:0.1453, F1:0.1209\n" + } + ], + "source": [ + "topK = 5\n", + "ref_keywords, pred_keywords = 0, 0\n", + "acc_count = 0\n", + "for text, kwds in tqdm(data_test):\n", + " ref_keywords += len(kwds)\n", + " pred_keywords += topK\n", + " preds = ht.extract_keywords(text, topK, method=\"jieba_tfidf\")\n", + " acc_count += len(set(kwds) & set(preds))\n", + "prec = acc_count / pred_keywords\n", + "recall = acc_count / ref_keywords\n", + "f1 = 2*prec*recall/(prec+recall)\n", + "print(f\"jieba Precison:{prec:.4f}, Recall:{recall:.4f}, F1:{f1:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": "100%|██████████| 3000/3000 [00:45<00:00, 65.51it/s]\ntextrank Precison:0.0955, Recall:0.1342, F1:0.1116\n" + } + ], + "source": [ + "topK = 5\n", + "ref_keywords, pred_keywords = 0, 0\n", + "acc_count = 0\n", + "for text, kwds in tqdm(data_test):\n", + " ref_keywords += len(kwds)\n", + " pred_keywords += topK\n", + " preds = ht.extract_keywords(text, topK, method=\"textrank\", block_size=\"doc\", window=2, weighted=True)\n", + " acc_count += len(set(kwds) & set(preds))\n", + "prec = acc_count / pred_keywords\n", + "recall = acc_count / ref_keywords\n", + "f1 = 2*prec*recall/(prec+recall)\n", + "print(f\"textrank Precison:{prec:.4f}, Recall:{recall:.4f}, F1:{f1:.4f}\")" + ] + }, + { + "source": [ + "另,附上HarvestText与另一个流行的textrank的实现,[textrank4zh](https://github.com/letiantian/TextRank4ZH)的比较" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "随机噪声雷达通常利用时域相关完成脉冲压缩从而进行目标检测.该文根据压缩感知理论提出一种适用于噪声雷达目标检测的新算法,它用低维投影测量和信号重建取代了传统的相关操作和压缩处理,将大量运算转移到后期处理.该算法以噪声雷达所检测的目标空间分布满足稀疏性为前提;利用发射信号形成卷积矩阵,然后通过随机抽取卷积矩阵的行构建测量矩阵;并采用迭代收缩阈值算法实现目标信号重建.该文对算法作了详细的理论推导,形成完整的实现框架.仿真实验验证了算法的有效性,并分析了对处理结果影响较大的因素.该算法能够有效地重建目标,具有良好的运算效率.与时域相关法相比,大幅度减小了目标检测误差,有效抑制了输出旁瓣,并保持了信号的相位特性.\n真实关键词:['目标', '相关', '矩阵']\ntextrank4zh 关键词(前5):['算法', '信号', '目标', '压缩', '运算']\n" + } + ], + "source": [ + "from textrank4zh import TextRank4Keyword\n", + "\n", + "def textrank4zh(text, topK, window=2):\n", + " # same as used in ht\n", + " allowPOS = {'n', 'ns', 'nr', 'nt', 'nz', 'vn', 'v', 'an', 'a', 'i'}\n", + " tr4w = TextRank4Keyword(allow_speech_tags=allowPOS)\n", + " tr4w.analyze(text=text, lower=True, window=window)\n", + " return [item.word for item in tr4w.get_keywords(topK)]\n", + "\n", + "text, kwds = data_dev[10]\n", + "print(text)\n", + "print(\"真实关键词:\", kwds)\n", + "print(\"textrank4zh 关键词(前5):\", textrank4zh(text, 5))" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": "100%|██████████| 3000/3000 [02:12<00:00, 24.17it/s]\ntextrank4zh Precison:0.0836, Recall:0.1174, F1:0.0977\n" + } + ], + "source": [ + "topK = 5\n", + "ref_keywords, pred_keywords = 0, 0\n", + "acc_count = 0\n", + "for text, kwds in tqdm(data_test):\n", + " ref_keywords += len(kwds)\n", + " pred_keywords += topK\n", + " preds = textrank4zh(text, topK)\n", + " acc_count += len(set(kwds) & set(preds))\n", + "prec = acc_count / pred_keywords\n", + "recall = acc_count / ref_keywords\n", + "f1 = 2*prec*recall/(prec+recall)\n", + "print(f\"textrank4zh Precison:{prec:.4f}, Recall:{recall:.4f}, F1:{f1:.4f}\")" + ] + }, + { + "source": [ + "HarvestText的textrank的实现在精度和速度上都有一定的优势。" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "总结各个算法在CSL数据及上的结果:\n", + "\n", + "| 算法 | P@5 | R@5 | F@5 |\n", + "| --- | --- | --- | --- |\n", + "| textrank4zh | 0.0836 | 0.1174 | 0.0977 |\n", + "| ht_textrank | 0.0955 | 0.1342 | 0.1116 |\n", + "| ht_jieba_tfidf | **0.1035** | **0.1453** | **0.1209** |\n", + "\n", + "综上,HarvestText的关键词抽取功能\n", + "- 把配置好参数的jieba_tfidf作为默认方法\n", + "- 使用自己的textrank实现而不是用流行的textrank4zh。" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ] +} \ No newline at end of file diff --git a/harvesttext/algorithms/keyword.py b/harvesttext/algorithms/keyword.py new file mode 100644 index 0000000..0db95c3 --- /dev/null +++ b/harvesttext/algorithms/keyword.py @@ -0,0 +1,36 @@ +import numpy as np +import networkx as nx + +def combine(word_list, window = 2): + """构造在window下的单词组合,用来构造单词之间的边。 + + :params word_list: list of str, 由单词组成的列表。 + :params window: int, 窗口大小。 + """ + if window < 2: window = 2 + for x in range(1, window): + if x >= len(word_list): + break + word_list2 = word_list[x:] + res = zip(word_list, word_list2) + for r in res: + yield r + +def textrank(block_words, topK, with_score=False, window=2, weighted=False): + G = nx.Graph() + for word_list in block_words: + for u, v in combine(word_list, window): + if not weighted: + G.add_edge(u, v) + else: + if G.has_edge(u, v): + G[u][v]['weight'] += 1 + else: + G.add_edge(u, v, weight=1) + + pr = nx.pagerank_scipy(G) + pr_sorted = sorted(pr.items(), key=lambda x: x[1], reverse=True) + if with_score: + return pr_sorted[:topK] + else: + return [w for (w, imp) in pr_sorted[:topK]] \ No newline at end of file diff --git a/harvesttext/algorithms/word_discoverer.py b/harvesttext/algorithms/word_discoverer.py index 7283856..b0376cc 100644 --- a/harvesttext/algorithms/word_discoverer.py +++ b/harvesttext/algorithms/word_discoverer.py @@ -202,11 +202,13 @@ def genWords2(self, doc): v.left = entropyOfList(v.left) v.right = entropyOfList(v.right) return values - def get_df_info(self, ex_mentions): + def get_df_info(self, ex_mentions, exclude_number=True): info = {"text":[],"freq":[],"left_ent":[],"right_ent":[],"agg":[]} for w in self.word_infos: if w.text in ex_mentions: continue + if exclude_number and w.text.isdigit(): + continue info["text"].append(w.text) info["freq"].append(w.freq) info["left_ent"].append(w.left) diff --git a/harvesttext/download_utils.py b/harvesttext/download_utils.py new file mode 100644 index 0000000..e0360ac --- /dev/null +++ b/harvesttext/download_utils.py @@ -0,0 +1,135 @@ +import os +import shutil +import requests +import hashlib +from tqdm import tqdm +from collections import namedtuple +from os import environ, listdir, makedirs +from os.path import dirname, exists, expanduser, isdir, join, splitext + +RemoteFileMetadata = namedtuple('RemoteFileMetadata', + ['filename', 'url', 'checksum']) + +# config according to computer, this should be default setting of shadowsocks +DEFAULT_PROXIES = { + 'http': 'socks5h://127.0.0.1:1080', + 'https': 'socks5h://127.0.0.1:1080' +} + +def get_data_home(data_home=None): + """Return the path of the scikit-learn data dir. + This folder is used by some large dataset loaders to avoid downloading the + data several times. + By default the data dir is set to a folder named 'scikit_learn_data' in the + user home folder. + Alternatively, it can be set by the 'SCIKIT_LEARN_DATA' environment + variable or programmatically by giving an explicit folder path. The '~' + symbol is expanded to the user home folder. + If the folder does not already exist, it is automatically created. + Parameters + ---------- + data_home : str | None + The path to data dir. + """ + if data_home is None: + data_home = environ.get('HARVESTTEXT_DATA', + join('~', '.harvesttext')) + data_home = expanduser(data_home) + if not exists(data_home): + makedirs(data_home) + return data_home + +def clear_data_home(data_home=None): + """Delete all the content of the data home cache. + Parameters + ---------- + data_home : str | None + The path to data dir. + """ + data_home = get_data_home(data_home) + shutil.rmtree(data_home) + +def _sha256(path): + """Calculate the sha256 hash of the file at path.""" + sha256hash = hashlib.sha256() + chunk_size = 8192 + with open(path, "rb") as f: + while True: + buffer = f.read(chunk_size) + if not buffer: + break + sha256hash.update(buffer) + return sha256hash.hexdigest() + +def _download_with_bar(url, file_path, proxies=DEFAULT_PROXIES): + # Streaming, so we can iterate over the response. + response = requests.get(url, stream=True, proxies=proxies) + total_size_in_bytes= int(response.headers.get('content-length', 0)) + block_size = 1024 # 1 KB + progress_bar = tqdm(total=total_size_in_bytes, unit='B', unit_scale=True) + with open(file_path, 'wb') as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: + raise Exception("ERROR, something went wrong with the downloading") + return file_path + +def _fetch_remote(remote, dirname=None, use_proxy=False, proxies=DEFAULT_PROXIES): + """Helper function to download a remote dataset into path + Fetch a dataset pointed by remote's url, save into path using remote's + filename and ensure its integrity based on the SHA256 Checksum of the + downloaded file. + Parameters + ---------- + remote : RemoteFileMetadata + Named tuple containing remote dataset meta information: url, filename + and checksum + dirname : string + Directory to save the file to. + Returns + ------- + file_path: string + Full path of the created file. + """ + + file_path = (remote.filename if dirname is None + else join(dirname, remote.filename)) + proxies = None if not use_proxy else proxies + file_path = _download_with_bar(remote.url, file_path, proxies) + checksum = _sha256(file_path) + if remote.checksum != checksum: + raise IOError("{} has an SHA256 checksum ({}) " + "differing from expected ({}), " + "file may be corrupted.".format(file_path, checksum, + remote.checksum)) + return file_path + + +def download(remote, file_path=None, use_proxy=False, proxies=DEFAULT_PROXIES): + data_home = get_data_home() + file_path = _fetch_remote(remote, data_home, use_proxy, proxies) + return file_path + +def check_download_resource(remote, use_proxy=False, proxies=None): + proxies = DEFAULT_PROXIES if use_proxy and proxies is None else proxies + data_home = get_data_home() + file_path = os.path.join(data_home, remote.filename) + if not os.path.exists(file_path): + # currently don't capture error at this level, assume download success + file_path = download(remote, data_home) + return file_path + +if __name__ == "__main__": + ARCHIVE = RemoteFileMetadata( + filename='harvesttext-0.7.2-py3-none-any.whl', + url='https://github.com/blmoistawinde/HarvestText/releases/download/V0.7.2/harvesttext-0.7.2-py3-none-any.whl', + checksum='004c8b0b1858f69025f721bc84cff33127d53c6ab526beed7a7a801a9c21f30b') + print("Download") + file_path = download(ARCHIVE) + print(file_path) + # if proxy is available + # print("Download using proxy") + # file_path = download(ARCHIVE, use_proxy=True) + # print(file_path) \ No newline at end of file diff --git a/harvesttext/harvesttext.py b/harvesttext/harvesttext.py index 5698414..2579969 100644 --- a/harvesttext/harvesttext.py +++ b/harvesttext/harvesttext.py @@ -53,6 +53,7 @@ def __init__(self, standard_name=False, language='zh_CN'): self.pinyin_adjlist = json.load(f) self.language = language if language == "en": + import nltk try: nltk.data.find('taggers/averaged_perceptron_tagger') except: @@ -774,7 +775,7 @@ def clean_text(self, text, remove_url=True, email=True, weibo_at=True, stop_term if t2s: cc = OpenCC('t2s') text = cc.convert(text) - assert hasattr(stop_terms, "__init__"), Exception("去除的词语必须是一个可迭代对象") + assert hasattr(stop_terms, "__iter__"), Exception("去除的词语必须是一个可迭代对象") if type(stop_terms) == str: text = text.replace(stop_terms, "") else: diff --git a/harvesttext/parsing.py b/harvesttext/parsing.py index 63c7fce..171be03 100644 --- a/harvesttext/parsing.py +++ b/harvesttext/parsing.py @@ -139,7 +139,7 @@ def cut_paragraphs(self, text, num_paras=None, block_sents=3, std_weight=0.5, if num_paras is not None: assert num_paras > 0, "Should give a positive number of num_paras" assert stopwords == 'baidu' or (hasattr(stopwords, '__iter__') and type(stopwords) != str) - stopwords = get_baidu_stopwords() if stopwords == 'baidu' else stopwords + stopwords = get_baidu_stopwords() if stopwords == 'baidu' else set(stopwords) if seq_chars < 1: cut_seqs = lambda x: self.cut_sentences(x, **kwargs) else: diff --git a/harvesttext/resources.py b/harvesttext/resources.py index 623329b..84ff3de 100644 --- a/harvesttext/resources.py +++ b/harvesttext/resources.py @@ -10,6 +10,7 @@ # 李军 中文评论的褒贬义分类实验研究 硕士论文 清华大学 2008 import os import json +from collections import defaultdict def get_qh_sent_dict(): """ @@ -123,4 +124,31 @@ def get_english_senti_lexicon(type="LH"): senti_lexicon = json.load(f) return senti_lexicon - +def get_jieba_dict(min_freq=0, max_freq=float('inf'), with_pos=False, use_proxy=False, proxies=None): + """ + 获得jieba自带的中文词语词频词典 + + :params min_freq: 选取词语需要的最小词频 + :params max_freq: 选取词语允许的最大词频 + :params with_pos: 返回结果是否包括词性信息 + :return if not with_pos, dict of {wd: freq}, else, dict of {(wd, pos): freq} + """ + from .download_utils import RemoteFileMetadata, check_download_resource + remote = RemoteFileMetadata( + filename='jieba_dict.txt', + url='https://github.com/blmoistawinde/HarvestText/releases/download/V0.8/jieba_dict.txt', + checksum='7197c3211ddd98962b036cdf40324d1ea2bfaa12bd028e68faa70111a88e12a8') + file_path = check_download_resource(remote, use_proxy, proxies) + ret = defaultdict(int) + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + if len(line.strip().split()) == 3: + wd, freq, pos = line.strip().split() + freq = int(freq) + if freq > min_freq and freq < max_freq: + if not with_pos: + ret[wd] = freq + else: + ret[(wd, pos)] = freq + return ret + \ No newline at end of file diff --git a/harvesttext/word_discover.py b/harvesttext/word_discover.py index b302ec4..087e57d 100644 --- a/harvesttext/word_discover.py +++ b/harvesttext/word_discover.py @@ -1,4 +1,7 @@ +import jieba +import jieba.analyse import logging +import networkx as nx import numpy as np import pandas as pd from collections import defaultdict @@ -6,6 +9,7 @@ from .resources import get_baidu_stopwords from .algorithms.word_discoverer import WordDiscoverer from .algorithms.entity_discoverer import NFLEntityDiscoverer, NERPEntityDiscover +from .algorithms.keyword import textrank class WordDiscoverMixin: """ @@ -18,7 +22,7 @@ class WordDiscoverMixin: def word_discover(self, doc, threshold_seeds=[], auto_param=True, excluding_types=[], excluding_words='baidu_stopwords', # 可以排除已经登录的某些种类的实体,或者某些指定词 max_word_len=5, min_freq=0.00005, min_entropy=1.4, min_aggregation=50, - ent_threshold="both", mem_saving=None, sort_by='freq'): + ent_threshold="both", mem_saving=None, sort_by='freq', exclude_number=True): '''新词发现,基于 http://www.matrix67.com/blog/archives/5044 实现及微调 :param doc: (string or list) 待进行新词发现的语料,如果是列表的话,就会自动用换行符拼接 @@ -33,6 +37,7 @@ def word_discover(self, doc, threshold_seeds=[], auto_param=True, :param ent_threshold: "both": (默认)在使用左右交叉熵进行筛选时,两侧都必须超过阈值; "avg": 两侧的平均值达到阈值即可 :param mem_saving: bool or None, 采用一些过滤手段来减少内存使用,但可能影响速度。如果不指定,对长文本自动打开,而对短文本不使用 :param sort_by: 以下string之一: {'freq': 词频, 'score': 综合分数, 'agg':凝聚度} 按照特定指标对得到的词语信息排序,默认使用词频 + :param exclude_number: (默认True)过滤发现的纯数字新词 :return: info: 包含新词作为index, 以及对应各项指标的DataFrame ''' if type(doc) != str: @@ -72,7 +77,7 @@ def word_discover(self, doc, threshold_seeds=[], auto_param=True, else: ex_mentions |= set(excluding_words) - info = ws.get_df_info(ex_mentions) + info = ws.get_df_info(ex_mentions, exclude_number) # 利用种子词来确定筛选优质新词的标准,种子词中最低质量的词语将被保留(如果一开始就被找到的话) if len(threshold_seeds) > 0: @@ -234,4 +239,66 @@ def entity_discover(self, text, return_count=False, method="NFL", min_count=5, p return entity_mention_dict, entity_type_dict, mention_count else: return entity_mention_dict, entity_type_dict + + def extract_keywords(self, text, topK, with_score=False, min_word_len=2, stopwords="baidu", allowPOS="default", method="jieba_tfidf", **kwargs): + """用各种算法抽取关键词(目前均为无监督),结合了ht的实体分词来提高准确率 + 目前支持的算法类型(及额外参数): + + - jieba_tfidf: (默认)jieba自带的基于tfidf的关键词抽取算法,idf统计信息来自于其语料库 + - textrank: 基于textrank的关键词抽取算法 + - block_type: 默认"doc"。 支持三种级别,"sent", "para", "doc",每个block之间的临近词语不建立连边 + - window: 默认2, 邻接的几个词语之内建立连边 + - weighted: 默认False, 时候使用加权图计算textrank + - 构建词图时会过滤不符合min_word_len, stopwords, allowPOS要求的词语 + + :params text: 从中挖掘关键词的文档 + :params topK: int, 从每个文档中抽取的关键词(最大)数量 + :params with_score: bool, 默认False, 是否同时返回算法提供的分数(如果有的话) + :params min_word_len: 默认2, 被纳入关键词的词语不低于此长度 + :param stopwords: 字符串列表/元组/集合,或者'baidu'为默认百度停用词,在算法中引入的停用词,一般能够提升准确度 + :params allowPOS: iterable of str,关键词应当属于的词性,默认为"default" {'n', 'ns', 'nr', 'nt', 'nz', 'vn', 'v', 'an', 'a', 'i'}以及已登录的实体词类型 + :params method: 选择用于抽取的算法,目前支持"jieba_tfidf", "tfidf", "textrank" + :params kwargs: 其他算法专属参数 + + + """ + assert method in {"jieba_tfidf", "textrank"}, print("目前不支持的算法") + if allowPOS == 'default': + # ref: 结巴分词标注兼容_ICTCLAS2008汉语词性标注集 https://www.cnblogs.com/hpuCode/p/4416186.html + allowPOS = {'n', 'ns', 'nr', 'nt', 'nz', 'vn', 'v', 'an', 'a', 'i'} + else: + assert hasattr(allowPOS, "__iter__") + # for HT, we consider registered entity types specifically + allowPOS |= set(self.type_entity_mention_dict) + + assert stopwords == 'baidu' or (hasattr(stopwords, '__iter__') and type(stopwords) != str) + stopwords = get_baidu_stopwords() if stopwords == 'baidu' else set(stopwords) + + if method == "jieba_tfidf": + kwds = jieba.analyse.extract_tags(text, topK=int(2*topK), allowPOS=allowPOS, withWeight=with_score) + if with_score: + kwds = [(kwd, score) for (kwd, score) in kwds if kwd not in stopwords][:topK] + else: + kwds = kwds[:topK] + elif method == "textrank": + block_type = kwargs.get("block_type", "doc") + assert block_type in {"sent", "para", "doc"} + window = kwargs.get("window", 2) + weighted = kwargs.get("weighted", True) + if block_type == "doc": + blocks = [text] + elif block_type == "para": + blocks = [para.strip() for para in text.split("\n") if para.strip() != ""] + elif block_type == "sent": + blocks = self.cut_sentences(text) + block_pos = (self.posseg(block.strip(), stopwords=stopwords) for block in blocks) + block_words = [[wd for wd, pos in x + if pos in allowPOS and len(wd) >= min_word_len] + for x in block_pos] + kwds = textrank(block_words, topK, with_score, window, weighted) + + return kwds + + + \ No newline at end of file