feat(primitive/faiss.py): support HNSW and inverted indexer (#387)
* feat(primitive/faiss.py): support HNSW

* feat(feature_store.py): simplify distribution

* feat(primitive/entity.py): add inverted index retrieval
tpoisonooo authored Sep 23, 2024
1 parent fbd4ecb commit 7e1be3f
Showing 14 changed files with 350 additions and 145 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -50,7 +50,8 @@ Our Web version has been released to [OpenXLab](https://openxlab.org.cn/apps/det

The Web version's API for Android also supports other devices. See [Python sample code](./tests/test_openxlab_android_api.py).

- \[2024/09\] [code retrieval](./huixiangdou/service/parallel_pipeline.py)
- \[2024/09\] [Inverted indexer](https://github.com/InternLM/HuixiangDou/pull/387) makes LLM prefer knowledge base🎯
- \[2024/09\] [Code retrieval](./huixiangdou/service/parallel_pipeline.py)
- \[2024/08\] [chat_with_readthedocs](https://huixiangdou.readthedocs.io/en/latest/), see [how to integrate](./docs/zh/doc_add_readthedocs.md) 👍
- \[2024/07\] Image and text retrieval & Removal of `langchain` 👍
- \[2024/07\] [Hybrid Knowledge Graph and Dense Retrieval](./docs/en/doc_knowledge_graph.md) improve 1.7% F1 score 🎯
18 changes: 11 additions & 7 deletions README_zh.md
@@ -52,6 +52,7 @@ For the Web version video tutorial, see [BiliBili](https://www.bilibili.com/video/BV1S2421N7mn)

The Web version's Android API also supports non-Android callers; see the [Python sample code](./tests/test_openxlab_android_api.py).

- \[2024/09\] [Inverted index](https://github.com/InternLM/HuixiangDou/pull/387) makes the LLM prefer domain knowledge 🎯
- \[2024/09\] [Code retrieval](./huixiangdou/service/parallel_pipeline.py) implemented with a sparse method
- \[2024/08\] ["chat_with_readthedocs"](https://huixiangdou.readthedocs.io/zh-cn/latest/), see the [integration guide](./docs/zh/doc_add_readthedocs.md)
- \[2024/07\] Image-and-text retrieval & removal of `langchain` 👍
@@ -366,7 +367,11 @@ python3 tests/test_query_gradio.py

# 🛠️ FAQ

1. What if the bot is too aloof or too chatty?
1. For general questions (e.g. "what is a tomato"), how do I make the LLM prefer domain knowledge (e.g. "Provence tomato")?

   Following this [PR](https://github.com/InternLM/HuixiangDou/pull/387), prepare an entity list and pass it in when building the feature store; `ParallelPipeline` retrieval will then use the inverted index to boost recall.

2. What if the bot is too aloof or too chatty?

   - Fill questions that should be answered in real scenarios into `resource/good_questions.json`, and questions that should be rejected into `resource/bad_questions.json`
   - Adjust the documents in `repodir` to make sure they contain no scenario-irrelevant content
@@ -375,30 +380,29 @@

⚠️ If you are confident enough, you can also directly edit the `reject_throttle` value in config.ini; generally speaking, 0.5 is a very high value and 0.2 is too low.

2. Startup is normal, but GPU memory OOMs at runtime. What should I do?
3. Startup is normal, but GPU memory OOMs at runtime. What should I do?

   Long-text inference with transformers-based LLMs needs more GPU memory; in that case, quantize the model's kv cache, e.g. following the [lmdeploy quantization guide](https://github.com/InternLM/lmdeploy/blob/main/docs/zh_cn/quantization), then deploy the Hybrid LLM Service separately with docker.

3. How do I integrate another local LLM, and what if the results are poor after integration?
4. How do I integrate another local LLM, and what if the results are poor after integration?

   - Open [hybrid llm service](./huixiangdou/service/llm_server_hybrid.py) and add a new LLM inference implementation
   - Following [test_intention_prompt and its test data](./tests/test_intention_prompt.py), adjust the prompts and thresholds for the new model, and update them in [prompt.py](./huixiangdou/service/prompt.py)

4. What if responses are too slow or network requests keep failing?
5. What if responses are too slow or network requests keep failing?

   - Refer to [hybrid llm service](./huixiangdou/service/llm_server_hybrid.py) to add exponential-backoff retries
   - Replace the local LLM with an inference framework such as [lmdeploy](https://github.com/internlm/lmdeploy) instead of native huggingface/transformers

5. What if the machine is low-spec and GPU memory is insufficient?
6. What if the machine is low-spec and GPU memory is insufficient?

   In that case a local LLM cannot run; the pipeline can only use a remote LLM together with text2vec. Make sure `config.ini` uses only the remote LLM and that the local LLM is turned off.

6. The error `(500, 'Internal Server Error')` means the LLM service started in standalone mode was not reached. Debug it as follows:
7. The error `(500, 'Internal Server Error')` means the LLM service started in standalone mode was not reached. Debug it as follows:

   - Run `python3 -m huixiangdou.service.llm_server_hybrid` to confirm the LLM service starts without errors and listens on the port given in the config; after checking, press ctrl-c to shut it down.
   - Check that every TOKEN in `config.ini` is written correctly.
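
The exponential-backoff suggestion in the FAQ can be sketched as follows. This is a minimal illustration, not code from `llm_server_hybrid.py`; the helper name and parameters are made up:

```python
import random
import time

def with_backoff(fn, max_retries=5, base_delay=1.0, max_delay=30.0):
    """On failure, wait base_delay * 2**attempt (capped, plus jitter) and retry."""
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception:
            if attempt == max_retries - 1:
                raise  # out of retries: surface the last error
            delay = min(base_delay * (2 ** attempt), max_delay)
            time.sleep(delay + random.uniform(0, 0.1 * delay))

# Simulate a request that fails twice, then succeeds.
calls = {"n": 0}
def flaky_request():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("simulated network failure")
    return "ok"

print(with_backoff(flaky_request, base_delay=0.01))  # → ok
```

The jitter term spreads retries out so that many clients do not hammer the server in lockstep.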


# 🍀 Acknowledgements

- [KIMI](https://kimi.moonshot.cn/): long-text LLM, supports direct file upload
86 changes: 48 additions & 38 deletions evaluation/end2end/main.py
@@ -2,12 +2,15 @@
from huixiangdou.primitive import Query
import json
import asyncio
import jieba
import pdb
import os
from typing import List
from rouge import Rouge
from loguru import logger

assistant = ParallelPipeline(work_dir='/home/khj/hxd-ci/workdir', config_path='/home/khj/hxd-ci/config.ini')
config_path = '/home/data/khj/workspace/huixiangdou/config.ini'
assistant = ParallelPipeline(work_dir='/home/data/khj/workspace/huixiangdou/workdir', config_path=config_path)

def format_refs(refs: List[str]):
refs_filter = list(set(refs))
@@ -31,49 +34,56 @@ async def run(query_text: str):
refs = sess.references
return sentence, refs

gts = []
dts = []

output_filepath = 'out.jsonl'

finished_query = []
with open(output_filepath) as fin:
json_str = ""
for line in fin:
json_str += line
if __name__ == "__main__":
gts = []
dts = []

if '}\n' == line:
print(json_str)
json_obj = json.loads(json_str)
finished_query.append(json_obj['query'].strip())
# hybrid llm serve
print('evaluate ParallelPipeline precision, first `python3 -m huixiangdou.service.llm_server_hybrid`, then prepare your qa pair in `qa.json`.')
output_filepath = 'out.jsonl'

finished_query = []
if os.path.exists(output_filepath):
with open(output_filepath) as fin:
json_str = ""
for line in fin:
json_str += line

if '}\n' == line:
print(json_str)
json_obj = json.loads(json_str)
finished_query.append(json_obj['query'].strip())
json_str = ""

with open('evaluation/end2end/qa.jsonl') as fin:
for json_str in fin:
json_obj = json.loads(json_str)
query = json_obj['query'].strip()
if query in finished_query:
continue

gt = json_obj['resp']
gts.append(gt)
with open('evaluation/end2end/qa.jsonl') as fin:
for json_str in fin:
json_obj = json.loads(json_str)
query = json_obj['query'].strip()
if query in finished_query:
continue
gt = json_obj['resp']
gts.append(gt)

loop = asyncio.get_event_loop()
dt, refs = loop.run_until_complete(run(query_text=query))
dts.append(dt)
loop = asyncio.get_event_loop()
dt, refs = loop.run_until_complete(run(query_text=query))
dts.append(dt)

distance = assistant.retriever.embedder.distance(text1=gt, text2=dt).tolist()
distance = assistant.retriever.embedder.distance(text1=gt, text2=dt).tolist()

rouge = Rouge()
scores = rouge.get_scores(gt, dt)
json_obj['distance'] = distance
json_obj['rouge_scores'] = scores
json_obj['dt'] = dt
json_obj['dt_refs'] = refs
rouge = Rouge()
dt_tokenized = ' '.join(jieba.cut(dt))
gt_tokenized = ' '.join(jieba.cut(gt))
scores = rouge.get_scores(dt_tokenized, gt_tokenized)
json_obj['distance'] = distance
json_obj['rouge_scores'] = scores
json_obj['dt'] = dt
json_obj['dt_refs'] = refs

out_json_str = json.dumps(json_obj, ensure_ascii=False, indent=2)
logger.info(out_json_str)
out_json_str = json.dumps(json_obj, ensure_ascii=False, indent=2)
logger.info(out_json_str)

with open(output_filepath, 'a') as fout:
fout.write(out_json_str)
fout.write('\n')
with open(output_filepath, 'a') as fout:
fout.write(out_json_str)
fout.write('\n')
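
One detail worth noting in the diff above: ROUGE implementations split on whitespace, so unsegmented Chinese text scores as a single token; that is why the `jieba.cut` output is space-joined before `get_scores`. A self-contained sketch of the rouge-1 F-measure over pre-segmented tokens (hand-rolled here for illustration, rather than calling the `rouge` package):

```python
from collections import Counter

def rouge1_f(hyp_tokens, ref_tokens):
    """rouge-1 F: harmonic mean of unigram precision and recall."""
    overlap = sum((Counter(hyp_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(hyp_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

# Segmented (as jieba.cut would produce), partial matches are credited:
print(rouge1_f(["番茄", "是", "一种", "蔬菜"], ["番茄", "是", "水果"]))  # ≈ 0.571
# Unsegmented, each sentence is one token and the score collapses:
print(rouge1_f(["番茄是一种蔬菜"], ["番茄是水果"]))  # → 0.0
```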
6 changes: 2 additions & 4 deletions huixiangdou/main.py
@@ -13,7 +13,6 @@

from .service import ErrorCode, SerialPipeline, build_reply_text, start_llm_server


def parse_args():
"""Parse args."""
parser = argparse.ArgumentParser(description='SerialPipeline.')
@@ -60,7 +59,6 @@ def check_env(args):


def show(assistant, fe_config: dict):

queries = ['请问如何安装 mmpose ?', '请问明天天气如何?']
print(colored('Running some examples..', 'yellow'))
for query in queries:
@@ -142,7 +140,7 @@ def lark_group_recv_and_send(assistant, fe_config: dict):
code, reply, refs = str(sess.code), sess.response, sess.references
if code == ErrorCode.SUCCESS:
json_obj['reply'] = build_reply_text(reply=reply,
references=references)
references=refs)
error, msg_id = send_to_lark_group(
json_obj=json_obj,
app_id=lark_group_config['app_id'],
@@ -169,7 +167,7 @@ async def api(request):
for sess in assistant.generate(query=query, history=[], groupname=''):
pass
code, reply, refs = str(sess.code), sess.response, sess.references
reply_text = build_reply_text(reply=reply, references=references)
reply_text = build_reply_text(reply=reply, references=refs)

return web.json_response({'code': int(code), 'reply': reply_text})

1 change: 1 addition & 0 deletions huixiangdou/primitive/__init__.py
@@ -15,3 +15,4 @@
nested_split_markdown, split_python_code)
from .limitter import RPM, TPM
from .bm250kapi import BM25Okapi
from .entity import NamedEntity2Chunk
96 changes: 96 additions & 0 deletions huixiangdou/primitive/entity.py
@@ -0,0 +1,96 @@
import sqlite3
import os
import json
from typing import List, Union, Set

class NamedEntity2Chunk:
"""Save the relationship between Named Entity and Chunk to sqlite"""
def __init__(self, file_dir:str, ignore_case=True):
self.file_dir = file_dir
        # whether to lowercase entities and query text before matching
self.ignore_case = ignore_case
if not os.path.exists(file_dir):
os.makedirs(file_dir)
self.conn = sqlite3.connect(os.path.join(file_dir, 'entity2chunk.sql'))
self.cursor = self.conn.cursor()
self.cursor.execute('''
CREATE TABLE IF NOT EXISTS entities (
eid INTEGER PRIMARY KEY,
chunk_ids TEXT
)
''')
self.conn.commit()
self.entities = []
self.entity_path = os.path.join(self.file_dir, 'entities.json')
if os.path.exists(self.entity_path):
with open(self.entity_path) as f:
self.entities = json.load(f)
if self.ignore_case:
for id, value in enumerate(self.entities):
self.entities[id] = value.lower()

def clean(self):
self.cursor.execute('''DROP TABLE entities;''')
self.cursor.execute('''
CREATE TABLE IF NOT EXISTS entities (
eid INTEGER PRIMARY KEY,
chunk_ids TEXT
)
''')
self.conn.commit()

def insert_relation(self, eid: int, chunk_ids: List[int]):
"""Insert the relationship between keywords id and List of chunk_id"""
chunk_ids_str = ','.join(map(str, chunk_ids))
self.cursor.execute('INSERT INTO entities (eid, chunk_ids) VALUES (?, ?)', (eid, chunk_ids_str))
self.conn.commit()

def parse(self, text:str) -> List[int]:
if self.ignore_case:
text = text.lower()

if len(self.entities) < 1:
raise ValueError('entity list empty, please check feature_store init')
ret = []
for index, entity in enumerate(self.entities):
if entity in text:
ret.append(index)
return ret

def set_entity(self, entities: List[str]):
json_str = json.dumps(entities, ensure_ascii=False)
with open(self.entity_path, 'w') as f:
f.write(json_str)

self.entities = entities
if self.ignore_case:
for id, value in enumerate(self.entities):
self.entities[id] = value.lower()

    def get_chunk_ids(self, entity_ids: Union[List, int]) -> List:
        """Query by entity ids; return (chunk_id, hit_count) pairs sorted by hit count, descending."""
        if isinstance(entity_ids, int):
            entity_ids = [entity_ids]

counter = dict()
for eid in entity_ids:
self.cursor.execute('SELECT chunk_ids FROM entities WHERE eid = ?', (eid,))
result = self.cursor.fetchone()
if result:
chunk_ids = result[0].split(',')
for chunk_id_str in chunk_ids:
chunk_id = int(chunk_id_str)
if chunk_id not in counter:
counter[chunk_id] = 1
else:
counter[chunk_id] += 1

counter_list = []
for k,v in counter.items():
counter_list.append((k,v))
counter_list.sort(key=lambda item: item[1], reverse=True)
return counter_list

def __del__(self):
self.cursor.close()
self.conn.close()
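
Read together, the schema above is a single table mapping an entity id to a comma-joined list of chunk ids. A self-contained sketch of the same store-and-rank logic (in-memory database and toy ids for illustration; it mirrors, but does not call, the class's `insert_relation`/`get_chunk_ids`):

```python
import sqlite3
from collections import Counter

# Same schema as NamedEntity2Chunk: one row per entity, chunk ids comma-joined.
conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE entities (eid INTEGER PRIMARY KEY, chunk_ids TEXT)')
conn.execute("INSERT INTO entities VALUES (0, '3,5')")  # entity 0 occurs in chunks 3 and 5
conn.execute("INSERT INTO entities VALUES (1, '5')")    # entity 1 occurs in chunk 5

def get_chunk_ids(entity_ids):
    """Rank chunks by how many of the query's entities point at them."""
    counter = Counter()
    for eid in entity_ids:
        row = conn.execute('SELECT chunk_ids FROM entities WHERE eid = ?', (eid,)).fetchone()
        if row:
            counter.update(int(c) for c in row[0].split(','))
    return sorted(counter.items(), key=lambda kv: kv[1], reverse=True)

# Chunk 5 is hit by both entities, chunk 3 by one:
print(get_chunk_ids([0, 1]))  # → [(5, 2), (3, 1)]
```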