diff --git a/README.md b/README.md index e01c8a2..20808bd 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,8 @@ Our Web version has been released to [OpenXLab](https://openxlab.org.cn/apps/det The Web version's API for Android also supports other devices. See [Python sample code](./tests/test_openxlab_android_api.py). -- \[2024/09\] [code retrieval](./huixiangdou/service/parallel_pipeline.py) +- \[2024/09\] [Inverted indexer](https://github.com/InternLM/HuixiangDou/pull/387) makes LLM prefer knowledge base🎯 +- \[2024/09\] [Code retrieval](./huixiangdou/service/parallel_pipeline.py) - \[2024/08\] [chat_with_readthedocs](https://huixiangdou.readthedocs.io/en/latest/), see [how to integrate](./docs/zh/doc_add_readthedocs.md) 👍 - \[2024/07\] Image and text retrieval & Removal of `langchain` 👍 - \[2024/07\] [Hybrid Knowledge Graph and Dense Retrieval](./docs/en/doc_knowledge_graph.md) improve 1.7% F1 score 🎯 diff --git a/README_zh.md b/README_zh.md index 110e389..1b72815 100644 --- a/README_zh.md +++ b/README_zh.md @@ -52,6 +52,7 @@ Web 版视频教程见 [BiliBili](https://www.bilibili.com/video/BV1S2421N7mn) Web 版给 android 的接口,也支持非 android 调用,见[python 样例代码](./tests/test_openxlab_android_api.py)。 +- \[2024/09\] [倒排索引](https://github.com/InternLM/HuixiangDou/pull/387)让 LLM 更偏向使用领域知识 🎯 - \[2024/09\] 稀疏方法实现[代码检索](./huixiangdou/service/parallel_pipeline.py) - \[2024/08\] ["chat_with readthedocs"](https://huixiangdou.readthedocs.io/zh-cn/latest/) ,见[集成说明](./docs/zh/doc_add_readthedocs.md) - \[2024/07\] 图文检索 & 移除 `langchain` 👍 @@ -366,7 +367,11 @@ python3 tests/test_query_gradio.py # 🛠️ FAQ -1. 机器人太高冷/太嘴碎怎么办? +1. 对于通用问题(如 “番茄是什么” ),我希望 LLM 优先用领域知识(如 “普罗旺斯番茄”)怎么办? + + 参照 [PR](https://github.com/InternLM/HuixiangDou/pull/387),准备实体列表,构建特征库时传入列表,`ParallelPipeline`检索会基于倒排索引增大召回 + +2. 机器人太高冷/太嘴碎怎么办? - 把真实场景中,应该回答的问题填入`resource/good_questions.json`,应该拒绝的填入`resource/bad_questions.json` - 调整 `repodir` 中的文档,确保不包含场景无关内容 @@ -375,30 +380,29 @@ python3 tests/test_query_gradio.py ⚠️ 如果你足够自信,也可以直接修改 config.ini 的 `reject_throttle` 数值,一般来说 0.5 是很高的值;0.2 过低。 -2. 启动正常,但运行期间显存 OOM 怎么办? +3. 启动正常,但运行期间显存 OOM 怎么办? 基于 transformers 结构的 LLM 长文本需要更多显存,此时需要对模型做 kv cache 量化,如 [lmdeploy 量化说明](https://github.com/InternLM/lmdeploy/blob/main/docs/zh_cn/quantization)。然后使用 docker 独立部署 Hybrid LLM Service。 -3. 如何接入其他 local LLM / 接入后效果不理想怎么办? +4. 如何接入其他 local LLM / 接入后效果不理想怎么办? - 打开 [hybrid llm service](./huixiangdou/service/llm_server_hybrid.py),增加新的 LLM 推理实现 - 参照 [test_intention_prompt 和测试数据](./tests/test_intention_prompt.py),针对新模型调整 prompt 和阈值,更新到 [prompt.py](./huixiangdou/service/prompt.py) -4. 响应太慢/网络请求总是失败怎么办? +5. 响应太慢/网络请求总是失败怎么办? - 参考 [hybrid llm service](./huixiangdou/service/llm_server_hybrid.py) 增加指数退避重传 - local LLM 替换为 [lmdeploy](https://github.com/internlm/lmdeploy) 等推理框架,而非原生的 huggingface/transformers -5. 机器配置低,GPU 显存不足怎么办? +6. 机器配置低,GPU 显存不足怎么办? 此时无法运行 local LLM,只能用 remote LLM 配合 text2vec 执行 pipeline。请确保 `config.ini` 只使用 remote LLM,关闭 local LLM -6. 报错 `(500, 'Internal Server Error')`,意为 standalone 模式启动的 LLM 服务没访问到。按如下方式定位 +7. 
报错 `(500, 'Internal Server Error')`,意为 standalone 模式启动的 LLM 服务没访问到。按如下方式定位 - 执行 `python3 -m huixiangdou.service.llm_server_hybrid` 确定 LLM 服务无报错,监听的端口和配置一致。检查结束后按 ctrl-c 关掉。 - 检查 `config.ini` 中各种 TOKEN 书写正确。 - # 🍀 致谢 - [KIMI](https://kimi.moonshot.cn/): 长文本 LLM,支持直接上传文件 diff --git a/evaluation/end2end/main.py b/evaluation/end2end/main.py index edd1ed2..a16f0f5 100644 --- a/evaluation/end2end/main.py +++ b/evaluation/end2end/main.py @@ -2,12 +2,15 @@ from huixiangdou.primitive import Query import json import asyncio +import jieba import pdb +import os from typing import List from rouge import Rouge from loguru import logger -assistant = ParallelPipeline(work_dir='/home/khj/hxd-ci/workdir', config_path='/home/khj/hxd-ci/config.ini') +config_path = '/home/data/khj/workspace/huixiangdou/config.ini' +assistant = ParallelPipeline(work_dir='/home/data/khj/workspace/huixiangdou/workdir', config_path=config_path) def format_refs(refs: List[str]): refs_filter = list(set(refs)) @@ -31,49 +34,56 @@ async def run(query_text: str): refs = sess.references return sentence, refs -gts = [] -dts = [] -output_filepath = 'out.jsonl' - -finished_query = [] -with open(output_filepath) as fin: - json_str = "" - for line in fin: - json_str += line +if __name__ == "__main__": + gts = [] + dts = [] - if '}\n' == line: - print(json_str) - json_obj = json.loads(json_str) - finished_query.append(json_obj['query'].strip()) + # hybrid llm serve + print('evaluate ParallelPipeline precision, first `python3 -m huixiangdou.service.llm_server_hybrid`, then prepare your qa pair in `qa.json`.') + output_filepath = 'out.jsonl' + + finished_query = [] + if os.path.exists(output_filepath): + with open(output_filepath) as fin: json_str = "" + for line in fin: + json_str += line + + if '}\n' == line: + print(json_str) + json_obj = json.loads(json_str) + finished_query.append(json_obj['query'].strip()) + json_str = "" -with open('evaluation/end2end/qa.jsonl') as fin: - for json_str in fin: - json_obj = json.loads(json_str) - query = json_obj['query'].strip() - if query in finished_query: - continue - - gt = json_obj['resp'] - gts.append(gt) + with open('evaluation/end2end/qa.jsonl') as fin: + for json_str in fin: + json_obj = json.loads(json_str) + query = json_obj['query'].strip() + if query in finished_query: + continue + + gt = json_obj['resp'] + gts.append(gt) - loop = asyncio.get_event_loop() - dt, refs = loop.run_until_complete(run(query_text=query)) - dts.append(dt) + loop = asyncio.get_event_loop() + dt, refs = loop.run_until_complete(run(query_text=query)) + dts.append(dt) - distance = assistant.retriever.embedder.distance(text1=gt, text2=dt).tolist() + distance = assistant.retriever.embedder.distance(text1=gt, text2=dt).tolist() - rouge = Rouge() - scores = rouge.get_scores(gt, dt) - json_obj['distance'] = distance - json_obj['rouge_scores'] = scores - json_obj['dt'] = dt - json_obj['dt_refs'] = refs + rouge = Rouge() + dt_tokenized = ' '.join(jieba.cut(dt)) + gt_tokenized = ' '.join(jieba.cut(gt)) + scores = rouge.get_scores(dt_tokenized, gt_tokenized) + json_obj['distance'] = distance + json_obj['rouge_scores'] = scores + json_obj['dt'] = dt + json_obj['dt_refs'] = refs - out_json_str = json.dumps(json_obj, ensure_ascii=False, indent=2) - logger.info(out_json_str) + out_json_str = json.dumps(json_obj, ensure_ascii=False, indent=2) + logger.info(out_json_str) - with open(output_filepath, 'a') as fout: - fout.write(out_json_str) - fout.write('\n') + with open(output_filepath, 'a') as fout: + fout.write(out_json_str) + 
fout.write('\n') diff --git a/huixiangdou/main.py b/huixiangdou/main.py index f7dddf9..9a6cc6c 100755 --- a/huixiangdou/main.py +++ b/huixiangdou/main.py @@ -13,7 +13,6 @@ from .service import ErrorCode, SerialPipeline, build_reply_text, start_llm_server - def parse_args(): """Parse args.""" parser = argparse.ArgumentParser(description='SerialPipeline.') @@ -60,7 +59,6 @@ def check_env(args): def show(assistant, fe_config: dict): - queries = ['请问如何安装 mmpose ?', '请问明天天气如何?'] print(colored('Running some examples..', 'yellow')) for query in queries: @@ -142,7 +140,7 @@ def lark_group_recv_and_send(assistant, fe_config: dict): code, reply, refs = str(sess.code), sess.response, sess.references if code == ErrorCode.SUCCESS: json_obj['reply'] = build_reply_text(reply=reply, - references=references) + references=refs) error, msg_id = send_to_lark_group( json_obj=json_obj, app_id=lark_group_config['app_id'], @@ -169,7 +167,7 @@ async def api(request): for sess in assistant.generate(query=query, history=[], groupname=''): pass code, reply, refs = str(sess.code), sess.response, sess.references - reply_text = build_reply_text(reply=reply, references=references) + reply_text = build_reply_text(reply=reply, references=refs) return web.json_response({'code': int(code), 'reply': reply_text}) diff --git a/huixiangdou/primitive/__init__.py b/huixiangdou/primitive/__init__.py index 173fec0..d626c4e 100644 --- a/huixiangdou/primitive/__init__.py +++ b/huixiangdou/primitive/__init__.py @@ -15,3 +15,4 @@ nested_split_markdown, split_python_code) from .limitter import RPM, TPM from .bm250kapi import BM25Okapi +from .entity import NamedEntity2Chunk diff --git a/huixiangdou/primitive/entity.py b/huixiangdou/primitive/entity.py new file mode 100644 index 0000000..95a5117 --- /dev/null +++ b/huixiangdou/primitive/entity.py @@ -0,0 +1,96 @@ +import sqlite3 +import os +import json +from typing import List, Union, Set + +class NamedEntity2Chunk: + """Save the relationship between Named Entity and Chunk to sqlite""" + def __init__(self, file_dir:str, ignore_case=True): + self.file_dir = file_dir + # case sensitive + self.ignore_case = ignore_case + if not os.path.exists(file_dir): + os.makedirs(file_dir) + self.conn = sqlite3.connect(os.path.join(file_dir, 'entity2chunk.sql')) + self.cursor = self.conn.cursor() + self.cursor.execute(''' + CREATE TABLE IF NOT EXISTS entities ( + eid INTEGER PRIMARY KEY, + chunk_ids TEXT + ) + ''') + self.conn.commit() + self.entities = [] + self.entity_path = os.path.join(self.file_dir, 'entities.json') + if os.path.exists(self.entity_path): + with open(self.entity_path) as f: + self.entities = json.load(f) + if self.ignore_case: + for id, value in enumerate(self.entities): + self.entities[id] = value.lower() + + def clean(self): + self.cursor.execute('''DROP TABLE entities;''') + self.cursor.execute(''' + CREATE TABLE IF NOT EXISTS entities ( + eid INTEGER PRIMARY KEY, + chunk_ids TEXT + ) + ''') + self.conn.commit() + + def insert_relation(self, eid: int, chunk_ids: List[int]): + """Insert the relationship between keywords id and List of chunk_id""" + chunk_ids_str = ','.join(map(str, chunk_ids)) + self.cursor.execute('INSERT INTO entities (eid, chunk_ids) VALUES (?, ?)', (eid, chunk_ids_str)) + self.conn.commit() + + def parse(self, text:str) -> List[int]: + if self.ignore_case: + text = text.lower() + + if len(self.entities) < 1: + raise ValueError('entity list empty, please check feature_store init') + ret = [] + for index, entity in enumerate(self.entities): + if entity in text: + 
ret.append(index) + return ret + + def set_entity(self, entities: List[str]): + json_str = json.dumps(entities, ensure_ascii=False) + with open(self.entity_path, 'w') as f: + f.write(json_str) + + self.entities = entities + if self.ignore_case: + for id, value in enumerate(self.entities): + self.entities[id] = value.lower() + + def get_chunk_ids(self, entity_ids: Union[List, int]) -> Set: + """Query by keywords ids""" + if type(entity_ids) is int: + entity_ids = [entity_ids] + + counter = dict() + for eid in entity_ids: + self.cursor.execute('SELECT chunk_ids FROM entities WHERE eid = ?', (eid,)) + result = self.cursor.fetchone() + if result: + chunk_ids = result[0].split(',') + for chunk_id_str in chunk_ids: + chunk_id = int(chunk_id_str) + if chunk_id not in counter: + counter[chunk_id] = 1 + else: + counter[chunk_id] += 1 + + counter_list = [] + for k,v in counter.items(): + counter_list.append((k,v)) + counter_list.sort(key=lambda item: item[1], reverse=True) + return counter_list + + def __del__(self): + self.cursor.close() + self.conn.close() diff --git a/huixiangdou/primitive/faiss.py b/huixiangdou/primitive/faiss.py index 1bd0fea..bc00899 100644 --- a/huixiangdou/primitive/faiss.py +++ b/huixiangdou/primitive/faiss.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from __future__ import annotations +import time import logging import os import pdb @@ -16,25 +17,13 @@ from .embedder import Embedder from .query import Query, DistanceStrategy from .chunk import Chunk - - -# heavily modified from langchain -def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any: - """Import faiss if available, otherwise raise error. - - Args: - no_avx2: Load FAISS strictly with no AVX2 optimization - so that the vectorstore is portable and compatible with other devices. - """ - try: - import faiss - except ImportError: - raise ImportError( - 'Could not import faiss python package. ' - 'Please install it with `pip install faiss-gpu` (for CUDA supported GPU) ' - 'or `pip install faiss-cpu` (depending on Python version).') - return faiss - +try: + import faiss +except ImportError: + raise ImportError( + 'Could not import faiss python package. ' + 'Please install it with `pip install faiss-gpu` (for CUDA supported GPU) ' + 'or `pip install faiss-cpu` (depending on Python version).') class Faiss(): @@ -57,8 +46,6 @@ def similarity_search(self, List of chunks most similar to the query text and L2 distance in float for each. High score represents more similarity. 
""" - faiss = dependable_faiss_import() - embedding = embedding.astype(np.float32) scores, indices = self.index.search(embedding, self.k) pairs = [] @@ -133,6 +120,23 @@ def split_by_batchsize(self, chunks: List[Chunk] = [], batchsize:int = 4): block_image.append(images[i:i+batchsize]) return block_text, block_image + @classmethod + def build_index(self, np_feature: np.ndarray, distance_strategy: DistanceStrategy): + dimension = np_feature.shape[-1] + M = 16 + # max neighours for each node + # see https://github.com/facebookresearch/faiss/wiki/Indexing-1M-vectors + if distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: + # index = faiss.IndexFlatL2(dimension) + index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_L2) + elif distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: + # index = faiss.IndexFlatIP(dimension) + index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_IP) + else: + raise ValueError('Unknown distance {}'.format(distance_strategy)) + index.hnsw.efSearch = 128 + return index + @classmethod def save_local(self, folder_path: str, chunks: List[Chunk], embedder: Embedder) -> None: @@ -144,9 +148,9 @@ def save_local(self, folder_path: str, chunks: List[Chunk], embedder: embedding function. """ - faiss = dependable_faiss_import() index = None batchsize = 1 + # max neighbours for each node try: batchsize_str = os.getenv('HUIXIANGDOU_BATCHSIZE') @@ -176,25 +180,16 @@ def save_local(self, folder_path: str, chunks: List[Chunk], continue if index is None: - dimension = np_feature.shape[-1] - - if embedder.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: - index = faiss.IndexFlatL2(dimension) - elif embedder.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: - index = faiss.IndexFlatIP(dimension) + index = self.build_index(np_feature=np_feature, distance_strategy=embedder.distance_strategy) index.add(np_feature) else: # batching block_text, block_image = self.split_by_batchsize(chunks=chunks, batchsize=batchsize) for subchunks in tqdm(block_text, 'build_text'): np_features = embedder.embed_query_batch_text(chunks=subchunks) + if index is None: - dimension = np_features[0].shape[-1] - - if embedder.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: - index = faiss.IndexFlatL2(dimension) - elif embedder.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: - index = faiss.IndexFlatIP(dimension) + index = self.build_index(np_feature=np_features, distance_strategy=embedder.distance_strategy) index.add(np_features) for subchunks in tqdm(block_image, 'build_image'): @@ -205,12 +200,7 @@ def save_local(self, folder_path: str, chunks: List[Chunk], continue if index is None: - dimension = np_feature.shape[-1] - - if embedder.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: - index = faiss.IndexFlatL2(dimension) - elif embedder.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: - index = faiss.IndexFlatIP(dimension) + index = self.build_index(np_feature=np_feature, distance_strategy=embedder.distance_strategy) index.add(np_feature) path = Path(folder_path) @@ -237,9 +227,11 @@ def load_local(cls, folder_path: str) -> FAISS: """ path = Path(folder_path) # load index separately since it is not picklable - faiss = dependable_faiss_import() + + t1 = time.time() index = faiss.read_index(str(path / f'embedding.faiss')) strategy = DistanceStrategy.UNKNOWN + t2 = time.time() # load docstore with open(path / f'chunks_and_strategy.pkl', 'rb') as f: @@ -254,4 +246,6 @@ def load_local(cls, folder_path: str) -> FAISS: else: raise ValueError('Unknown 
strategy type {}'.format(strategy_str)) + t3 = time.time() + logger.info('Timecost for load dense, load faiss {} seconds, load chunk {} seconds'.format(int(t2-t1), int(t3-t2))) return cls(index, chunks, strategy) diff --git a/huixiangdou/primitive/splitter.py b/huixiangdou/primitive/splitter.py index 1820dd8..bb895f3 100644 --- a/huixiangdou/primitive/splitter.py +++ b/huixiangdou/primitive/splitter.py @@ -618,11 +618,12 @@ def nested_split_markdown(filepath: str, modal='image') image_chunks.append(c) else: - logger.error( - f'image cannot access. file: {filepath}, image path: {image_path}' - ) + pass + # logger.error( + # f'image cannot access. file: {filepath}, image path: {image_path}' + # ) - logger.info('{} text_chunks, {} image_chunks'.format(len(text_chunks), len(image_chunks))) + # logger.info('{} text_chunks, {} image_chunks'.format(len(text_chunks), len(image_chunks))) return text_chunks + image_chunks def split_python_code(filepath: str, text: str, metadata: dict = {}): diff --git a/huixiangdou/server.py b/huixiangdou/server.py index 0d6b68f..24c145b 100644 --- a/huixiangdou/server.py +++ b/huixiangdou/server.py @@ -1,19 +1,9 @@ import argparse -import os -import time -import pytoml -import requests -from aiohttp import web -from loguru import logger -from termcolor import colored - -from .service import ErrorCode, SerialPipeline, ParallelPipeline, start_llm_server +from .service import SerialPipeline, ParallelPipeline, start_llm_server from .primitive import Query -import asyncio -from fastapi import FastAPI, APIRouter +from fastapi import FastAPI from fastapi.responses import StreamingResponse -from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import uvicorn import json @@ -43,7 +33,6 @@ async def huixiangdou_inference(talk: Talk): query = Query(talk.text, talk.image) pipeline = {'step': []} - debug = dict() if type(assistant) is SerialPipeline: for sess in assistant.generate(query=query): status = { @@ -73,7 +62,6 @@ async def huixiangdou_stream(talk: Talk): query = Query(talk.text, talk.image) pipeline = {'step': []} - debug = dict() def event_stream(): for sess in assistant.generate(query=query): @@ -104,7 +92,7 @@ async def event_stream_async(): def parse_args(): """Parse args.""" - parser = argparse.ArgumentParser(description='SerialPipeline.') + parser = argparse.ArgumentParser(description='Serial or Parallel Pipeline.') parser.add_argument('--work_dir', type=str, default='workdir', diff --git a/huixiangdou/service/feature_store.py b/huixiangdou/service/feature_store.py index 1b75e88..47b0677 100644 --- a/huixiangdou/service/feature_store.py +++ b/huixiangdou/service/feature_store.py @@ -6,19 +6,21 @@ import pdb import re import shutil +import time from multiprocessing import Pool from typing import Any, Dict, List, Optional - +import random import pytoml from loguru import logger from torch.cuda import empty_cache from tqdm import tqdm + from ..primitive import (ChineseRecursiveTextSplitter, Chunk, Embedder, Faiss, FileName, FileOperation, RecursiveCharacterTextSplitter, nested_split_markdown, split_python_code, - BM25Okapi) + BM25Okapi, NamedEntity2Chunk) from .helper import histogram from .llm_server_hybrid import start_llm_server from .retriever import CacheRetriever, Retriever @@ -110,7 +112,42 @@ def parse_markdown(self, file: FileName, metadata: Dict): for c in chunks: length += len(c.content_or_path) return chunks, length - + + def build_inverted_index(self, chunks: List[Chunk], ner_file: str, work_dir: str): + """Build inverted 
index based on named entity for knowledge base.""" + if ner_file is None: + return + # 倒排索引 retrieve 建库 + index_dir = os.path.join(work_dir, 'db_reverted_index') + if not os.path.exists(index_dir): + os.makedirs(index_dir) + entities = [] + with open(ner_file) as f: + entities = json.load(f) + + time0 = time.time() + map_entity2chunks = dict() + indexer = NamedEntity2Chunk(file_dir=index_dir) + indexer.clean() + indexer.set_entity(entities=entities) + + # build inverted index + for chunk_id, chunk in enumerate(chunks): + if chunk.modal != 'text': + continue + entity_ids = indexer.parse(text=chunk.content_or_path) + for entity_id in entity_ids: + if entity_id not in map_entity2chunks: + map_entity2chunks[entity_id] = [chunk_id] + else: + map_entity2chunks[entity_id].append(chunk_id) + + for entity_id, chunk_indexes in map_entity2chunks.items(): + indexer.insert_relation(eid = entity_id, chunk_ids=chunk_indexes) + del indexer + time1 = time.time() + logger.info('Timecost for build_inverted_index {}s'.format(time1-time0)) + def build_sparse(self, files: List[FileName], work_dir: str): """Use BM25 for building code feature""" # split by function, class and annotation, remove blank @@ -168,13 +205,17 @@ def build_dense(self, files: List[FileName], work_dir: str, markdown_as_txt: boo else: filtered_chunks = chunks if len(chunks) < 1: - return + return chunks self.analyze(filtered_chunks) Faiss.save_local(folder_path=feature_dir, chunks=filtered_chunks, embedder=self.embedder) + return chunks def analyze(self, chunks: List[Chunk]): """Output documents length mean, median and histogram.""" + MAX_COUNT = 10000 + if len(chunks) > MAX_COUNT: + chunks = random.sample(chunks, MAX_COUNT) text_lens = [] token_lens = [] @@ -273,7 +314,7 @@ def preprocess(self, files: List, work_dir: str): file.state = False file.reason = 'read error' - def initialize(self, files: list, work_dir: str): + def initialize(self, files: list, ner_file:str, work_dir: str): """Initializes response and reject feature store. Only needs to be called once. Also calculates the optimal threshold @@ -286,10 +327,11 @@ def initialize(self, files: list, work_dir: str): self.preprocess(files=files, work_dir=work_dir) # build dense retrieval refusal-to-answer and response database documents = list(filter(lambda x: x._type != 'code', files)) - self.build_dense(files=documents, work_dir=work_dir) + chunks = self.build_dense(files=documents, work_dir=work_dir) codes = list(filter(lambda x: x._type == 'code', files)) self.build_sparse(files=codes, work_dir=work_dir) + self.build_inverted_index(chunks=chunks, ner_file=ner_file, work_dir=work_dir) def parse_args(): """Parse command-line arguments.""" @@ -320,6 +362,11 @@ def parse_args(): help= # noqa E251 'Negative examples json path. Default value is resource/bad_questions.json' # noqa E501 ) + parser.add_argument( + '--ner-file', + default=None, + help='The path of NER file, which is a dumped json list. HuixiangDou would build relationship between entities and chunks for retrieve.' 
+ ) parser.add_argument( '--sample', help='Input an json file, save reject and search output.') parser.add_argument( @@ -408,7 +455,8 @@ def test_query(retriever: Retriever, sample: str = None): file_opr = FileOperation() files = file_opr.scan_dir(repo_dir=args.repo_dir) - fs_init.initialize(files=files, work_dir=args.work_dir) + + fs_init.initialize(files=files, ner_file=args.ner_file, work_dir=args.work_dir) file_opr.summarize(files) del fs_init diff --git a/huixiangdou/service/parallel_pipeline.py b/huixiangdou/service/parallel_pipeline.py index 75ebb9f..9ebeb7c 100644 --- a/huixiangdou/service/parallel_pipeline.py +++ b/huixiangdou/service/parallel_pipeline.py @@ -101,34 +101,25 @@ def process(self, sess: Session) -> Generator[Session, None, None]: class Text2vecRetrieval: """Text2vecNode is for retrieve from knowledge base.""" + def __init__(self, retriever: Retriever): + self.retriever = retriever - def __init__(self, config: dict, llm: ChatClient, retriever: Retriever, - language: str): - self.llm = llm + async def process_coroutine(self, sess: Session) -> Session: + """Try get reply with text2vec & rerank model.""" + # retrieve from knowledge base + sess.parallel_chunks = await asyncio.to_thread(self.retriever.text2vec_retrieve, sess.query) + return sess + +class InvertedIndexRetrieval: + """Text2vecNode is for retrieve from knowledge base.""" + def __init__(self, retriever: Retriever): self.retriever = retriever - llm_config = config['llm'] - self.context_max_length = llm_config['server'][ - 'local_llm_max_text_length'] - if llm_config['enable_remote']: - self.context_max_length = llm_config['server'][ - 'remote_llm_max_text_length'] - if language == 'zh': - self.TOPIC_TEMPLATE = TOPIC_TEMPLATE_CN - self.SCORING_RELAVANCE_TEMPLATE = SCORING_RELAVANCE_TEMPLATE_CN - self.GENERATE_TEMPLATE = GENERATE_TEMPLATE_CN - else: - self.TOPIC_TEMPLATE = TOPIC_TEMPLATE_EN - self.SCORING_RELAVANCE_TEMPLATE = SCORING_RELAVANCE_TEMPLATE_EN - self.GENERATE_TEMPLATE = GENERATE_TEMPLATE_EN - self.max_length = self.context_max_length - 2 * len( - self.GENERATE_TEMPLATE) async def process_coroutine(self, sess: Session) -> Session: """Try get reply with text2vec & rerank model.""" # retrieve from knowledge base - sess.parallel_chunks = await asyncio.to_thread(self.retriever.text2vec_retrieve, sess.query) - # sess.parallel_chunks = self.retriever.text2vec_retrieve(query=sess.query.text) + sess.parallel_chunks = await asyncio.to_thread(self.retriever.inverted_index_retrieve, sess.query) return sess class CodeRetrieval: @@ -312,7 +303,9 @@ async def generate(self, # build pipeline preproc = PreprocNode(self.config, self.llm, language) - text2vec = Text2vecRetrieval(self.config, self.llm, self.retriever, language) + text2vec = Text2vecRetrieval(self.retriever) + inverted_index = InvertedIndexRetrieval(self.retriever) + coderetrieval = CodeRetrieval(self.retriever) websearch = WebSearchRetrieval(self.config, self.config_path, self.llm, language) reduce = ReduceGenerate(self.config, self.llm, self.retriever, language) @@ -330,7 +323,8 @@ async def generate(self, return # parallel run text2vec, websearch and codesearch - tasks = [text2vec.process_coroutine(copy.deepcopy(sess))] + tasks = [text2vec.process_coroutine(copy.deepcopy(sess)), inverted_index.process_coroutine(copy.deepcopy(sess))] + if enable_web_search: tasks.append(websearch.process_coroutine(copy.deepcopy(sess))) diff --git a/huixiangdou/service/retriever.py b/huixiangdou/service/retriever.py index fe1ae6f..17e5727 100644 --- 
a/huixiangdou/service/retriever.py +++ b/huixiangdou/service/retriever.py @@ -11,11 +11,10 @@ from sklearn.metrics import precision_recall_curve from typing import Any, Union, Tuple, List -from huixiangdou.primitive import Embedder, Faiss, LLMReranker, Query, Chunk, BM25Okapi, FileOperation +from huixiangdou.primitive import Embedder, Faiss, LLMReranker, Query, Chunk, BM25Okapi, FileOperation, NamedEntity2Chunk from .helper import QueryTracker from .kg import KnowledgeGraph - class Retriever: """Tokenize and extract features from the project's chunks, for use in the reject pipeline and response pipeline.""" @@ -29,6 +28,7 @@ def __init__(self, config_path: str, embedder: Any, reranker: Any, self.embedder = embedder self.reranker = reranker self.faiss = None + self.work_dir = work_dir if not os.path.exists(work_dir): logger.warning('!!!warning, workdir not exist.!!!') @@ -89,8 +89,32 @@ def update_throttle(self, logger.info( f'The optimal threshold is: {optimal_threshold}, saved it to {config_path}' # noqa E501 ) - - def text2vec_retrieve(self, query: Union[Query, str]): + + def inverted_index_retrieve(self, query: Union[Query, str], topk=100) -> List[Chunk]: + """Retrieve chunks by named entity.""" + # reverted index retrieval + + reverted_index_dir = os.path.join(self.work_dir, 'db_reverted_index') + if not os.path.exists(reverted_index_dir): + return [] + + # In async executor, `reverted_indexer` must lazy build and destroy + reverted_indexer = NamedEntity2Chunk(reverted_index_dir) + if type(query) is str: + query = Query(text=query) + + entity_ids = reverted_indexer.parse(query.text) + # chunk_id match counter + chunk_id_score_list = reverted_indexer.get_chunk_ids(entity_ids=entity_ids) + chunk_id_score_list = chunk_id_score_list[0:topk] + del reverted_indexer + + chunks = [] + for chunk_id, ref_count in chunk_id_score_list: + chunks.append(self.faiss.chunks[chunk_id]) + return chunks + + def text2vec_retrieve(self, query: Union[Query, str]) -> List[Chunk]: """Retrieve chunks by text2vec model or knowledge graph. 
Args: @@ -112,8 +136,11 @@ def text2vec_retrieve(self, query: Union[Query, str]): logger.info('KG folder exists, but search failed, skip.') threshold = self.reject_throttle - graph_delta + t1 = time.time() pairs = self.faiss.similarity_search_with_query(self.embedder, query=query, threshold=threshold) + t2 = time.time() + logger.info('Timecost for text2vec_retrieve {} seconds'.format(float(t2-t1))) # 280ms chunks = [pair[0] for pair in pairs] return chunks @@ -245,7 +272,6 @@ def is_relative(self, logger.info('KG folder exists, but search failed, skip.') threshold = self.reject_throttle - graph_delta - if enable_threshold: pairs = self.faiss.similarity_search_with_query(self.embedder, query=query, threshold=threshold) else: diff --git a/requirements.txt b/requirements.txt index 11a57f8..dc91f7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,7 @@ readability-lxml redis requests scikit-learn +sqlite3 # See https://github.com/deanmalmgren/textract/issues/461 # textract @ git+https://github.com/tpoisonooo/textract@master textract diff --git a/unittest/primitive/test_entity.py b/unittest/primitive/test_entity.py new file mode 100644 index 0000000..16486fe --- /dev/null +++ b/unittest/primitive/test_entity.py @@ -0,0 +1,43 @@ +import os +import pdb + +from huixiangdou.primitive import NamedEntity2Chunk, Chunk + + +def test_entity_build_and_query(): + entities = ['HuixiangDou', 'WeChat'] + + indexer = NamedEntity2Chunk('/tmp') + indexer.clean() + indexer.set_entity(entities=entities) + + c0 = Chunk(content_or_path='How to deploy HuixiangDou on wechaty ?') + c1 = Chunk(content_or_path='do you know what huixiangdou means ?') + chunks = [c0, c1] + map_entity2chunks = dict() + # build inverted index + for chunk_id, chunk in enumerate(chunks): + if chunk.modal != 'text': + continue + entity_ids = indexer.parse(text=chunk.content_or_path) + for entity_id in entity_ids: + if entity_id not in map_entity2chunks: + map_entity2chunks[entity_id] = [chunk_id] + else: + map_entity2chunks[entity_id].append(chunk_id) + + for entity_id, chunk_indexes in map_entity2chunks.items(): + indexer.insert_relation(eid = entity_id, chunk_ids=chunk_indexes) + del indexer + + query_text = 'how to install wechat ?' + retriver = NamedEntity2Chunk('/tmp') + entity_ids = retriver.parse(query_text) + # chunk_id match counter + chunk_id_list = retriver.get_chunk_ids(entity_ids=entity_ids) + print(chunk_id_list) + assert chunk_id_list[0][0] == 0 + + +if __name__ == '__main__': + test_entity_build_and_query()
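
The change to `evaluation/end2end/main.py` segments both the model answer and the ground truth with `jieba` before scoring, because the `rouge` package splits on whitespace and would otherwise treat an unsegmented Chinese sentence as a single token. A minimal, self-contained sketch of that scoring step (the example sentences are made up for illustration):

```python
# Sketch of the jieba + ROUGE scoring step from evaluation/end2end/main.py.
# The two sentences below are illustrative only.
import jieba
from rouge import Rouge

gt = '豆哥支持微信和飞书群聊'   # ground-truth answer
dt = '豆哥可以接入微信群聊'     # model answer

# ROUGE splits on whitespace, so Chinese text must be segmented first.
gt_tokenized = ' '.join(jieba.cut(gt))
dt_tokenized = ' '.join(jieba.cut(dt))

rouge = Rouge()
scores = rouge.get_scores(dt_tokenized, gt_tokenized)  # hypothesis first, reference second
print(scores[0]['rouge-l'])
```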
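
`Faiss.build_index` in `huixiangdou/primitive/faiss.py` now builds an HNSW graph (`IndexHNSWFlat` with `M=16` and `efSearch=128`) instead of a flat index, trading exact search for much faster lookup on large chunk sets. Below is a standalone sketch of the same construction on random vectors; the dimension and data are assumptions for illustration, and note that faiss exposes the inner-product metric as `faiss.METRIC_INNER_PRODUCT`.

```python
# Sketch of the HNSW index construction used in Faiss.build_index.
# Dimension, M and efSearch mirror the patch; the vectors are random placeholders.
import faiss
import numpy as np

dimension, M = 768, 16                     # M: max neighbours per HNSW node
index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_L2)
# For inner-product retrieval the metric constant is faiss.METRIC_INNER_PRODUCT.
index.hnsw.efSearch = 128                  # wider search beam at query time

features = np.random.rand(1000, dimension).astype(np.float32)
index.add(features)

query = np.random.rand(1, dimension).astype(np.float32)
scores, indices = index.search(query, 10)  # top-10 approximate neighbours
print(indices[0])
```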
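
The new `--ner-file` option of `huixiangdou.service.feature_store` takes a JSON list of entity names; `build_inverted_index` then records which chunks mention which entity, and `Retriever.inverted_index_retrieve` uses that mapping to boost domain knowledge at query time, as described in the FAQ entry and PR #387. A sketch of preparing such a list; the entity names, file path, and the exact build command are examples, not shipped resources:

```python
# Sketch: prepare an entity list for the inverted index.
# Entity names and the output path are examples only.
import json

entities = ['普罗旺斯番茄', 'HuixiangDou', 'mmpose']  # domain terms to prioritise
with open('entities.json', 'w') as f:
    json.dump(entities, f, ensure_ascii=False)

# Then pass the list when building the feature store, e.g.:
#   python3 -m huixiangdou.service.feature_store --ner-file entities.json
# which calls FeatureStore.initialize(files=..., ner_file='entities.json', work_dir=...)
# and writes workdir/db_reverted_index, later read by Retriever.inverted_index_retrieve.
```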
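
In `parallel_pipeline.py` the dense and inverted-index retrievals (plus web search when enabled) are wrapped in `asyncio.to_thread` and awaited together, so the blocking retriever calls overlap instead of running back to back. A minimal sketch of that pattern with stand-in functions; the sleep-based retrievers are placeholders, not the real `Retriever` API:

```python
# Sketch of the concurrent retrieval pattern used in ParallelPipeline.generate.
# The two functions below stand in for Retriever.text2vec_retrieve and
# Retriever.inverted_index_retrieve.
import asyncio
import time

def text2vec_retrieve(query: str):
    time.sleep(0.2)            # pretend embedding + faiss search
    return ['dense chunk']

def inverted_index_retrieve(query: str):
    time.sleep(0.1)            # pretend sqlite entity lookup
    return ['entity chunk']

async def generate(query: str):
    # Blocking calls run in worker threads and overlap (requires Python 3.9+).
    tasks = [
        asyncio.to_thread(text2vec_retrieve, query),
        asyncio.to_thread(inverted_index_retrieve, query),
    ]
    dense, inverted = await asyncio.gather(*tasks)
    return dense + inverted

print(asyncio.run(generate('how to install mmpose ?')))
```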