diff --git a/.github/workflows/docker-build.yaml b/.github/workflows/docker-build.yaml index 70836f2d4..17245ba39 100644 --- a/.github/workflows/docker-build.yaml +++ b/.github/workflows/docker-build.yaml @@ -1,154 +1,154 @@ -name: docker-build -on: - push: - branches: - - master - paths-ignore: - - 'README.md' - - 'README_en.md' -env: - TZ: Asia/Shanghai -jobs: - docker-build: - runs-on: ubuntu-latest - # if: github.event.pull_request.merged == true - steps: - - name: Optimize Disk Space - uses: hugoalh/disk-space-optimizer-ghaction@v0.8.0 - with: - operate_sudo: "True" - general_include: ".+" - general_exclude: |- - ^GCC$ - ^G\+\+$ - Clang - LLVM - docker_include: ".+" - docker_prune: "True" - docker_clean: "True" - apt_prune: "True" - apt_clean: "True" - homebrew_prune: "True" - homebrew_clean: "True" - npm_prune: "True" - npm_clean: "True" - os_swap: "True" - - name: Remove Unnecessary Tools And Files - env: - DEBIAN_FRONTEND: noninteractive - run: | - sudo apt-get remove -y '^dotnet-.*' '^llvm-.*' 'php.*' azure-cli google-chrome-stable firefox powershell mono-devel - sudo apt-get autoremove --purge -y - sudo find /var/log -name "*.gz" -type f -delete - sudo rm -rf /var/cache/apt/archives - sudo rm -rf /tmp/* - sudo rm -rf /etc/apt/sources.list.d/* /usr/share/dotnet /usr/local/lib/android /opt/ghc /etc/mysql /etc/php - sudo -E apt-get -y purge azure-cli* docker* ghc* zulu* hhvm* llvm* firefox* google* dotnet* aspnetcore* powershell* openjdk* adoptopenjdk* mysql* php* mongodb* moby* snap* || true - sudo rm -rf /etc/apt/sources.list.d/* /usr/local/lib/android /opt/ghc /usr/share/dotnet /usr/local/graalvm /usr/local/.ghcup \ - /usr/local/share/powershell /usr/local/share/chromium /usr/local/lib/node_modules - sudo rm -rf /etc/apt/sources.list.d/* /usr/share/dotnet /usr/local/lib/android /opt/ghc /etc/mysql /etc/php - sudo -E apt-get -y purge azure-cli* docker* ghc* zulu* hhvm* llvm* firefox* google* dotnet* aspnetcore* powershell* openjdk* adoptopenjdk* mysql* php* mongodb* moby* snap* || true - sudo -E apt-get -qq update - sudo -E apt-get -qq install libfuse-dev $(curl -fsSL git.io/depends-ubuntu-2204) - sudo -E apt-get -qq autoremove --purge - sudo -E apt-get -qq clean - sudo apt-get clean - rm -rf /opt/hostedtoolcache - sudo timedatectl set-timezone "$TZ" - - name: Free Up Disk Space - uses: easimon/maximize-build-space@master - with: - root-reserve-mb: 62464 # 给 / 预留 61GiB 空间( docker 预留) - swap-size-mb: 1 - remove-dotnet: 'true' - remove-android: 'true' - remove-haskell: 'true' - remove-codeql: 'true' - remove-docker-images: 'true' - - name: Checkout Repository - uses: actions/checkout@v4 - - name: Get Latest Release - id: get_version - run: | - VERSION=$(curl --silent "https://api.github.com/repos/${{ github.repository }}/releases/latest" | jq -r .tag_name) - echo "RELEASE_VERSION=${VERSION}" >> $GITHUB_ENV - - name: Set Image Tag - id: imageTag - run: echo "::set-output name=image_tag::$RELEASE_VERSION-$(date +%Y%m%d)-$(git rev-parse --short HEAD)" - - name: Set Up QEMU - uses: docker/setup-qemu-action@v2 - - name: Set Up Docker Buildx - uses: docker/setup-buildx-action@v2 - - name: Clone Model - run: | - sudo mkdir -p $GITHUB_WORKSPACE/bge-large-zh-v1.5 - cd $GITHUB_WORKSPACE/bge-large-zh-v1.5 - sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/.gitattributes &> /dev/null - sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/config.json &> /dev/null - sudo wget 
https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/config_sentence_transformers.json &> /dev/null - sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/modules.json &> /dev/null - sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/pytorch_model.bin &> /dev/null - sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/sentence_bert_config.json &> /dev/null - sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/special_tokens_map.json &> /dev/null - sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/tokenizer.json &> /dev/null - sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/tokenizer_config.json &> /dev/null - sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/vocab.txt &> /dev/null - sudo mkdir -p $GITHUB_WORKSPACE/bge-large-zh-v1.5/1_Pooling - cd $GITHUB_WORKSPACE/bge-large-zh-v1.5/1_Pooling - sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/1_Pooling/config.json &> /dev/null - sudo mkdir -p $GITHUB_WORKSPACE/chatglm3-6b - cd $GITHUB_WORKSPACE/chatglm3-6b - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/config.json &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/configuration_chatglm.py &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model-00001-of-00007.safetensors &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model-00002-of-00007.safetensors &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model-00003-of-00007.safetensors &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model-00004-of-00007.safetensors &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model-00005-of-00007.safetensors &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model-00006-of-00007.safetensors &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model-00007-of-00007.safetensors &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model.safetensors.index.json &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/modeling_chatglm.py &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/pytorch_model.bin.index.json &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/quantization.py &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/special_tokens_map.json &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/tokenization_chatglm.py &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/tokenizer.model &> /dev/null - sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/tokenizer_config.json &> /dev/null - du -sh $GITHUB_WORKSPACE - du -sh $GITHUB_WORKSPACE/* - du -sh $GITHUB_WORKSPACE/bge-large-zh-v1.5/* - du -sh $GITHUB_WORKSPACE/chatglm3-6b/* - - name: Show Runner Disk - run: df -hT - - name: Docker Build - run: | - docker build -t uswccr.ccs.tencentyun.com/chatchat/chatchat:${{ steps.imageTag.outputs.image_tag }} -f Dockerfile . 
- - name: Show Images Size - run: docker images - - name: Login To Tencent CCR - uses: docker/login-action@v2 - with: - registry: uswccr.ccs.tencentyun.com - username: ${{ secrets.CCR_REGISTRY_USERNAME }} - password: ${{ secrets.CCR_REGISTRY_PASSWORD }} - - name: Docker Push - run: docker push uswccr.ccs.tencentyun.com/chatchat/chatchat:${{ steps.imageTag.outputs.image_tag }} -# - name: Login to Docker Hub +#name: docker-build +#on: +# push: +# branches: +# - master +# paths-ignore: +# - 'README.md' +# - 'README_en.md' +#env: +# TZ: Asia/Shanghai +#jobs: +# docker-build: +# runs-on: ubuntu-latest +# # if: github.event.pull_request.merged == true +# steps: +# - name: Optimize Disk Space +# uses: hugoalh/disk-space-optimizer-ghaction@v0.8.0 +# with: +# operate_sudo: "True" +# general_include: ".+" +# general_exclude: |- +# ^GCC$ +# ^G\+\+$ +# Clang +# LLVM +# docker_include: ".+" +# docker_prune: "True" +# docker_clean: "True" +# apt_prune: "True" +# apt_clean: "True" +# homebrew_prune: "True" +# homebrew_clean: "True" +# npm_prune: "True" +# npm_clean: "True" +# os_swap: "True" +# - name: Remove Unnecessary Tools And Files +# env: +# DEBIAN_FRONTEND: noninteractive +# run: | +# sudo apt-get remove -y '^dotnet-.*' '^llvm-.*' 'php.*' azure-cli google-chrome-stable firefox powershell mono-devel +# sudo apt-get autoremove --purge -y +# sudo find /var/log -name "*.gz" -type f -delete +# sudo rm -rf /var/cache/apt/archives +# sudo rm -rf /tmp/* +# sudo rm -rf /etc/apt/sources.list.d/* /usr/share/dotnet /usr/local/lib/android /opt/ghc /etc/mysql /etc/php +# sudo -E apt-get -y purge azure-cli* docker* ghc* zulu* hhvm* llvm* firefox* google* dotnet* aspnetcore* powershell* openjdk* adoptopenjdk* mysql* php* mongodb* moby* snap* || true +# sudo rm -rf /etc/apt/sources.list.d/* /usr/local/lib/android /opt/ghc /usr/share/dotnet /usr/local/graalvm /usr/local/.ghcup \ +# /usr/local/share/powershell /usr/local/share/chromium /usr/local/lib/node_modules +# sudo rm -rf /etc/apt/sources.list.d/* /usr/share/dotnet /usr/local/lib/android /opt/ghc /etc/mysql /etc/php +# sudo -E apt-get -y purge azure-cli* docker* ghc* zulu* hhvm* llvm* firefox* google* dotnet* aspnetcore* powershell* openjdk* adoptopenjdk* mysql* php* mongodb* moby* snap* || true +# sudo -E apt-get -qq update +# sudo -E apt-get -qq install libfuse-dev $(curl -fsSL git.io/depends-ubuntu-2204) +# sudo -E apt-get -qq autoremove --purge +# sudo -E apt-get -qq clean +# sudo apt-get clean +# rm -rf /opt/hostedtoolcache +# sudo timedatectl set-timezone "$TZ" +# - name: Free Up Disk Space +# uses: easimon/maximize-build-space@master +# with: +# root-reserve-mb: 62464 # 给 / 预留 61GiB 空间( docker 预留) +# swap-size-mb: 1 +# remove-dotnet: 'true' +# remove-android: 'true' +# remove-haskell: 'true' +# remove-codeql: 'true' +# remove-docker-images: 'true' +# - name: Checkout Repository +# uses: actions/checkout@v4 +# - name: Get Latest Release +# id: get_version +# run: | +# VERSION=$(curl --silent "https://api.github.com/repos/${{ github.repository }}/releases/latest" | jq -r .tag_name) +# echo "RELEASE_VERSION=${VERSION}" >> $GITHUB_ENV +# - name: Set Image Tag +# id: imageTag +# run: echo "::set-output name=image_tag::$RELEASE_VERSION-$(date +%Y%m%d)-$(git rev-parse --short HEAD)" +# - name: Set Up QEMU +# uses: docker/setup-qemu-action@v2 +# - name: Set Up Docker Buildx +# uses: docker/setup-buildx-action@v2 +# - name: Clone Model +# run: | +# sudo mkdir -p $GITHUB_WORKSPACE/bge-large-zh-v1.5 +# cd $GITHUB_WORKSPACE/bge-large-zh-v1.5 +# sudo wget 
https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/.gitattributes &> /dev/null +# sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/config.json &> /dev/null +# sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/config_sentence_transformers.json &> /dev/null +# sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/modules.json &> /dev/null +# sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/pytorch_model.bin &> /dev/null +# sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/sentence_bert_config.json &> /dev/null +# sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/special_tokens_map.json &> /dev/null +# sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/tokenizer.json &> /dev/null +# sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/tokenizer_config.json &> /dev/null +# sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/vocab.txt &> /dev/null +# sudo mkdir -p $GITHUB_WORKSPACE/bge-large-zh-v1.5/1_Pooling +# cd $GITHUB_WORKSPACE/bge-large-zh-v1.5/1_Pooling +# sudo wget https://huggingface.co/BAAI/bge-large-zh-v1.5/resolve/main/1_Pooling/config.json &> /dev/null +# sudo mkdir -p $GITHUB_WORKSPACE/chatglm3-6b +# cd $GITHUB_WORKSPACE/chatglm3-6b +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/config.json &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/configuration_chatglm.py &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model-00001-of-00007.safetensors &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model-00002-of-00007.safetensors &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model-00003-of-00007.safetensors &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model-00004-of-00007.safetensors &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model-00005-of-00007.safetensors &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model-00006-of-00007.safetensors &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model-00007-of-00007.safetensors &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/model.safetensors.index.json &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/modeling_chatglm.py &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/pytorch_model.bin.index.json &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/quantization.py &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/special_tokens_map.json &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/tokenization_chatglm.py &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/tokenizer.model &> /dev/null +# sudo wget https://huggingface.co/THUDM/chatglm3-6b/resolve/main/tokenizer_config.json &> /dev/null +# du -sh $GITHUB_WORKSPACE +# du -sh $GITHUB_WORKSPACE/* +# du -sh $GITHUB_WORKSPACE/bge-large-zh-v1.5/* +# du -sh $GITHUB_WORKSPACE/chatglm3-6b/* +# - name: Show Runner Disk +# run: df -hT +# - name: Docker Build +# run: | +# docker build -t uswccr.ccs.tencentyun.com/chatchat/chatchat:${{ steps.imageTag.outputs.image_tag }} -f Dockerfile . 
+# - name: Show Images Size +# run: docker images +# - name: Login To Tencent CCR # uses: docker/login-action@v2 # with: -# username: ${{ secrets.DOCKERHUB_USERNAME }} -# password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Update README_en.md - run: | - sed -i "s|uswccr.ccs.tencentyun.com/chatchat/chatchat:[^ ]*|uswccr.ccs.tencentyun.com/chatchat/chatchat:${{ steps.imageTag.outputs.image_tag }}|g" README.md - sed -i "s|uswccr.ccs.tencentyun.com/chatchat/chatchat:[^ ]*|uswccr.ccs.tencentyun.com/chatchat/chatchat:${{ steps.imageTag.outputs.image_tag }}|g" README_en.md - sed -i "s|uswccr.ccs.tencentyun.com/chatchat/chatchat:[^ ]*|uswccr.ccs.tencentyun.com/chatchat/chatchat:${{ steps.imageTag.outputs.image_tag }}|g" README_ja.md - git config --local user.email "action@github.com" - git config --local user.name "GitHub Action" - git commit -am "feat:update docker image:tag" - - name: Push README_en.md - uses: ad-m/github-push-action@master - with: - github_token: ${{ secrets.GH_PAT }} - branch: ${{ github.ref }} \ No newline at end of file +# registry: uswccr.ccs.tencentyun.com +# username: ${{ secrets.CCR_REGISTRY_USERNAME }} +# password: ${{ secrets.CCR_REGISTRY_PASSWORD }} +# - name: Docker Push +# run: docker push uswccr.ccs.tencentyun.com/chatchat/chatchat:${{ steps.imageTag.outputs.image_tag }} +## - name: Login to Docker Hub +## uses: docker/login-action@v2 +## with: +## username: ${{ secrets.DOCKERHUB_USERNAME }} +## password: ${{ secrets.DOCKERHUB_TOKEN }} +# - name: Update README_en.md +# run: | +# sed -i "s|uswccr.ccs.tencentyun.com/chatchat/chatchat:[^ ]*|uswccr.ccs.tencentyun.com/chatchat/chatchat:${{ steps.imageTag.outputs.image_tag }}|g" README.md +# sed -i "s|uswccr.ccs.tencentyun.com/chatchat/chatchat:[^ ]*|uswccr.ccs.tencentyun.com/chatchat/chatchat:${{ steps.imageTag.outputs.image_tag }}|g" README_en.md +# sed -i "s|uswccr.ccs.tencentyun.com/chatchat/chatchat:[^ ]*|uswccr.ccs.tencentyun.com/chatchat/chatchat:${{ steps.imageTag.outputs.image_tag }}|g" README_ja.md +# git config --local user.email "action@github.com" +# git config --local user.name "GitHub Action" +# git commit -am "feat:update docker image:tag" +# - name: Push README_en.md +# uses: ad-m/github-push-action@master +# with: +# github_token: ${{ secrets.GH_PAT }} +# branch: ${{ github.ref }} \ No newline at end of file diff --git a/README.md b/README.md index 29208ca5c..8ae32ebb6 100644 --- a/README.md +++ b/README.md @@ -311,7 +311,7 @@ chatchat-config model --default_llm_model qwen2-instruct ```shell # 这里应为 3.2 中 "CHATCHAT_ROOT" 变量指向目录 cd /root/anaconda3/envs/chatchat/lib/python3.11/site-packages/chatchat -vim model_providers.yaml +vim configs/model_providers.yaml ``` 配置介绍请参考 [model-providers/README.md](libs/model-providers/README.md) @@ -427,6 +427,14 @@ chatchat -a > ``` > +### Docker 部署 +```shell +docker pull chatimage/chatchat:0.3.0-0623-3 +``` +> [!important] +> 强烈建议: 使用 docker-compose 部署, 具体参考 [README_docker](docs/install/README_docker.md) + + ### 旧版本迁移 * 0.3.x 结构改变很大,强烈建议您按照文档重新部署. 以下指南不保证100%兼容和成功. 记得提前备份重要数据! 
diff --git a/docker/Dockerfile b/docker/Dockerfile index b0dfd33c7..9044c04e9 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,19 +1,29 @@ +# Base Image FROM python:3.11 - -RUN apt-get update -RUN apt-get install -y libgl1-mesa-glx - -RUN mkdir /Langchain-Chatchat -COPY requirements.txt /Langchain-Chatchat -COPY requirements_api.txt /Langchain-Chatchat -COPY requirements_webui.txt /Langchain-Chatchat - -WORKDIR /Langchain-Chatchat -RUN pip install --upgrade pip -RUN pip install -r requirements.txt -RUN pip install -r requirements_api.txt -RUN pip install -r requirements_webui.txt - -EXPOSE 8501 -EXPOSE 7861 -EXPOSE 20000 +# Labels +LABEL maintainer=chatchat +# Environment Variables +ENV HOME=/usr/local/lib/python3.11/site-packages/chatchat +# Init Environment +RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ + echo "Asia/Shanghai" > /etc/timezone +# Install Dependencies +RUN apt-get update -y && \ + apt-get install -y --no-install-recommends libgl1 libglib2.0-0 && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* +RUN pip install openpyxl networkx faiss-cpu jq unstructured[pdf] \ + opencv-python rapidocr-onnxruntime PyMuPDF rank_bm25 youtube_search python-docx +# Install Chatchat +RUN pip install --index-url https://pypi.python.org/simple/ langchain-chatchat -U +# Install ModelProvider +RUN pip install xinference-client +# Make Custom Settings +RUN chatchat-config server --default_bind_host=0.0.0.0 && \ + chatchat-config model --default_llm_model qwen2-instruct +# Copy Data +COPY /libs/chatchat-server/chatchat/configs/model_providers.yaml $HOME/configs/model_providers.yaml +ADD /docker/data.tar.gz $HOME/ +WORKDIR $HOME +EXPOSE 7861 8501 +ENTRYPOINT ["chatchat", "-a"] \ No newline at end of file diff --git a/docker/data.tar.gz b/docker/data.tar.gz new file mode 100644 index 000000000..c6a7fe4f2 Binary files /dev/null and b/docker/data.tar.gz differ diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml new file mode 100644 index 000000000..e46b18a43 --- /dev/null +++ b/docker/docker-compose.yaml @@ -0,0 +1,38 @@ +version: '3.9' +services: + xinference: + image: xprobe/xinference:v0.12.1 + restart: always + command: xinference-local -H 0.0.0.0 + # ports: # 不使用 host network 时可打开. + # - "9997:9997" + network_mode: "host" + # 将本地路径(~/xinference)挂载到容器路径(/root/.xinference)中, + # 详情见: https://inference.readthedocs.io/zh-cn/latest/getting_started/using_docker_image.html + volumes: + - ~/xinference:/root/.xinference + # - ~/xinference/cache/huggingface:/root/.cache/huggingface + # - ~/xinference/cache/modelscope:/root/.cache/modelscope + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + runtime: nvidia + # 模型源更改为 ModelScope, 默认为 HuggingFace + # environment: + # - XINFERENCE_MODEL_SRC=modelscope + chatchat: + image: chatimage/chatchat:0.3.0-0623-3 + restart: always + # ports: # 不使用 host network 时可打开. 
+ # - "7861:7861" + # - "8501:8501" + network_mode: "host" + # 将本地路径(~/chatchat/data)挂载到容器默认数据路径(/usr/local/lib/python3.11/site-packages/chatchat/data)中 + # 将本地模型接入配置文件(~/chatchat/model_providers.yaml)挂载到容器默认模型接入配置文件路径(/usr/local/lib/python3.11/site-packages/chatchat/configs/)中 + # volumes: + # - ~/chatchat/data:/usr/local/lib/python3.11/site-packages/chatchat/data + # - ~/chatchat/model_providers.yaml:/usr/local/lib/python3.11/site-packages/chatchat/configs/model_providers.yaml \ No newline at end of file diff --git a/docs/install/README_docker.md b/docs/install/README_docker.md new file mode 100644 index 000000000..b30a4acda --- /dev/null +++ b/docs/install/README_docker.md @@ -0,0 +1,183 @@ +### chatchat 容器化部署指引 + +> 提示: 此指引为在 Linux 环境下编写完成, 其他环境下暂未测试, 理论上可行. +> +> Langchain-Chatchat docker 镜像已支持多架构, 欢迎大家自行测试. + +#### 一. Langchain-Chatchat 体验部署 + +##### 1. 安装 docker-compose +寻找适合你环境的 docker-compose 版本, 请参考 [Docker-Compose](https://github.com/docker/compose). + +举例: Linux X86 环境 可下载 [docker-compose-linux-x86_64](https://github.com/docker/compose/releases/download/v2.27.3/docker-compose-linux-x86_64) 使用. +```shell +cd ~ +wget https://github.com/docker/compose/releases/download/v2.27.3/docker-compose-linux-x86_64 +mv docker-compose-linux-x86_64 /usr/bin/docker-compose +which docker-compose +``` +/usr/bin/docker-compose +```shell +docker-compose -v +``` +Docker Compose version v2.27.3 + +##### 2. 安装 NVIDIA Container Toolkit +寻找适合你环境的 NVIDIA Container Toolkit 版本, 请参考: [Installing the NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). + +安装完成后记得按照刚刚文档中`Configuring Docker`章节对 docker 进行初始化. + +##### 3. 创建 xinference 数据缓存路径 + +这一步强烈建议, 因为可以将 xinference 缓存的模型都保存到本地, 长期使用. +```shell +mkdir -p ~/xinference +``` + +##### 4. 下载 chatchat & xinference 启动配置文件(docker-compose.yaml) +```shell +cd ~ +wget https://github.com/chatchat-space/Langchain-Chatchat/blob/master/docker/docker-compose.yaml +``` + +##### 5. 启动 chatchat & xinference 服务 +```shell +docker-compose up -d +``` +出现如下日志即为成功 ( 第一次启动需要下载 docker 镜像, 时间较长, 这里已经提前下载好了 ) +```text +WARN[0000] /root/docker-compose.yaml: `version` is obsolete +[+] Running 2/2 + ✔ Container root-chatchat-1 Started 0.2s + ✔ Container root-xinference-1 Started 0.3s +``` + +##### 6.检查服务启动情况 +```shell +docker-compose up -d +``` +```text +WARN[0000] /root/docker-compose.yaml: `version` is obsolete +NAME IMAGE COMMAND SERVICE CREATED STATUS PORTS +root-chatchat-1 chatimage/chatchat:0.3.0-0622 "chatchat -a" chatchat 3 minutes ago Up 3 minutes +root-xinference-1 xprobe/xinference:v0.12.1 "/opt/nvidia/nvidia_…" xinference 3 minutes ago Up 3 minutes +``` +```shell +ss -anptl | grep -E '(8501|7861|9997)' +``` +```text +LISTEN 0 128 0.0.0.0:9997 0.0.0.0:* users:(("pt_main_thread",pid=1489804,fd=21)) +LISTEN 0 128 0.0.0.0:8501 0.0.0.0:* users:(("python",pid=1490078,fd=10)) +LISTEN 0 128 0.0.0.0:7861 0.0.0.0:* users:(("python",pid=1490014,fd=9)) +``` +如上, 服务均已正常启动, 即可体验使用. + +> 提示: 先登陆 xinference ui `http://:9997` 启动 llm 和 embedding 后, 再登陆 chatchat ui `http://:8501` 进行体验. +> +> 详细文档: +> - Langchain-chatchat 使用请参考: [LangChain-Chatchat](/README.md) +> +> - Xinference 使用请参考: [欢迎来到 Xinference!](https://inference.readthedocs.io/zh-cn/latest/index.html) + +#### 二. Langchain-Chatchat 进阶部署 + +##### 1. 按照 `Langchain-Chatchat 体验部署` 内容顺序依次完成 + +##### 2. 创建 chatchat 数据缓存路径 +```shell +cd ~ +mkdir -p ~/chatchat +``` + +##### 3. 修改 `docker-compose.yaml` 文件内容 + +原文件内容: +```yaml + (上文 ...) 
+ chatchat: + image: chatimage/chatchat:0.3.0-0622 + (省略 ...) + # 将本地路径(~/chatchat/data)挂载到容器默认数据路径(/usr/local/lib/python3.11/site-packages/chatchat/data)中 + # 将本地模型接入配置文件(~/chatchat/model_providers.yaml)挂载到容器默认模型接入配置文件路径(/usr/local/lib/python3.11/site-packages/chatchat/configs/)中 + # volumes: + # - ~/chatchat/data:/usr/local/lib/python3.11/site-packages/chatchat/data + # - ~/chatchat/model_providers.yaml:/usr/local/lib/python3.11/site-packages/chatchat/configs/model_providers.yaml + (下文 ...) +``` +将 `volumes` 字段注释打开, 并按照 `YAML` 格式对齐, 如下: +```yaml + (上文 ...) + chatchat: + image: chatimage/chatchat:0.3.0-0622 + (省略 ...) + # 将本地路径(~/chatchat/data)挂载到容器默认数据路径(/usr/local/lib/python3.11/site-packages/chatchat/data)中 + # 将本地模型接入配置文件(~/chatchat/model_providers.yaml)挂载到容器默认模型接入配置文件路径(/usr/local/lib/python3.11/site-packages/chatchat/configs/)中 + volumes: + - ~/chatchat/data:/usr/local/lib/python3.11/site-packages/chatchat/data + - ~/chatchat/model_providers.yaml:/usr/local/lib/python3.11/site-packages/chatchat/configs/model_providers.yaml + (下文 ...) +``` + +##### 4. 下载数据库初始文件 + +> 提示: 这里的 `data.tar.gz` 文件仅包含初始化后的数据库 `samples` 文件一份及相应目录结构, 用户可将原先数据和目录结构迁移此处. +> > [!WARNING] 请您先备份好您的数据再进行迁移!!! + +```shell +cd ~/chatchat +wget https://github.com/chatchat-space/Langchain-Chatchat/blob/master/docker/data.tar.gz +tar -xvf data.tar.gz +``` +```shell +cd data +pwd +``` +/root/chatchat/data +```shell +ls -l +``` +```text +total 20 +drwxr-xr-x 3 root root 4096 Jun 22 10:46 knowledge_base +drwxr-xr-x 18 root root 4096 Jun 22 10:52 logs +drwxr-xr-x 5 root root 4096 Jun 22 10:46 media +drwxr-xr-x 5 root root 4096 Jun 22 10:46 nltk_data +drwxr-xr-x 3 root root 4096 Jun 22 10:46 temp +``` + +##### 5. 下载 `model_providers.yaml` 配置文件 +> 提示: 后续可以自定义本地路径下的 `model_providers.yaml` 来实现`自定义模型接入配置` +```shell +cd ~/chatchat +wget https://github.com/chatchat-space/Langchain-Chatchat/blob/master/libs/model-providers/model_providers.yaml +``` + +##### 6. 
重启 chatchat 服务 + +这一步需要到 `docker-compose.yaml` 文件所在路径下执行, 即: +```shell +cd ~ +docker-compose down chatchat +docker-compose up -d chatchat +``` +操作及检查结果如下: +```text +[root@VM-2-15-centos ~]# docker-compose down chatchat +WARN[0000] /root/docker-compose.yaml: `version` is obsolete +[+] Running 1/1 + ✔ Container root-chatchat-1 Removed 0.5s +[root@VM-2-15-centos ~]# docker-compose up -d +WARN[0000] /root/docker-compose.yaml: `version` is obsolete +[+] Running 2/2 + ✔ Container root-xinference-1 Running 0.0s + ✔ Container root-chatchat-1 Started 0.2s +[root@VM-2-15-centos ~]# docker-compose ps +WARN[0000] /root/docker-compose.yaml: `version` is obsolete +NAME IMAGE COMMAND SERVICE CREATED STATUS PORTS +root-chatchat-1 chatimage/chatchat:0.3.0-0622 "chatchat -a" chatchat 33 seconds ago Up 32 seconds +root-xinference-1 xprobe/xinference:v0.12.1 "/opt/nvidia/nvidia_…" xinference 45 minutes ago Up 45 minutes +[root@VM-2-15-centos ~]# ss -anptl | grep -E '(8501|7861|9997)' +LISTEN 0 128 0.0.0.0:9997 0.0.0.0:* users:(("pt_main_thread",pid=1489804,fd=21)) +LISTEN 0 128 0.0.0.0:8501 0.0.0.0:* users:(("python",pid=1515944,fd=10)) +LISTEN 0 128 0.0.0.0:7861 0.0.0.0:* users:(("python",pid=1515878,fd=9)) +``` \ No newline at end of file diff --git a/libs/chatchat-server/chatchat/config_work_space.py b/libs/chatchat-server/chatchat/config_work_space.py index fce12107a..b23f88b0c 100644 --- a/libs/chatchat-server/chatchat/config_work_space.py +++ b/libs/chatchat-server/chatchat/config_work_space.py @@ -1,14 +1,15 @@ +import ast import json +# We cannot lazy-load click here because its used via decorators. +import click + from chatchat.configs import ( config_basic_workspace, + config_kb_workspace, config_model_workspace, config_server_workspace, - config_kb_workspace, ) -import ast -# We cannot lazy-load click here because its used via decorators. -import click @click.group(help="指令` chatchat-config` 工作空间配置") @@ -17,13 +18,14 @@ def main(): @main.command("basic", help="基础配置") -@click.option("--verbose", type=click.Choice(["true", "false"]), help="是否开启详细日志") +@click.option( + "--verbose", type=click.Choice(["true", "false"]), help="是否开启详细日志" +) @click.option("--data", help="初始化数据存放路径,注意:目录会清空重建") @click.option("--format", help="日志格式") @click.option("--clear", is_flag=True, help="清除配置") @click.option("--show", is_flag=True, help="显示配置") def basic(**kwargs): - if kwargs["verbose"]: if kwargs["verbose"].lower() == "true": config_basic_workspace.set_log_verbose(True) @@ -50,20 +52,31 @@ def basic(**kwargs): @click.option("--model_providers_cfg_path_config", help="模型平台配置文件路径") @click.option("--model_providers_cfg_host", help="模型平台配置服务host") @click.option("--model_providers_cfg_port", type=int, help="模型平台配置服务port") -@click.option("--set_model_platforms", type=str, help="""模型平台配置 +@click.option( + "--set_model_platforms", + type=str, + help="""模型平台配置 as a JSON string. - """) -@click.option("--set_tool_config", type=str, help=""" + """, +) +@click.option( + "--set_tool_config", + type=str, + help=""" 工具配置项 as a JSON string. 
- """) + """, +) @click.option("--clear", is_flag=True, help="清除配置") @click.option("--show", is_flag=True, help="显示配置") def model(**kwargs): - if kwargs["default_llm_model"]: - config_model_workspace.set_default_llm_model(llm_model=kwargs["default_llm_model"]) + config_model_workspace.set_default_llm_model( + llm_model=kwargs["default_llm_model"] + ) if kwargs["default_embedding_model"]: - config_model_workspace.set_default_embedding_model(embedding_model=kwargs["default_embedding_model"]) + config_model_workspace.set_default_embedding_model( + embedding_model=kwargs["default_embedding_model"] + ) if kwargs["agent_model"]: config_model_workspace.set_agent_model(agent_model=kwargs["agent_model"]) @@ -78,7 +91,9 @@ def model(**kwargs): config_model_workspace.set_temperature(temperature=kwargs["temperature"]) if kwargs["support_agent_models"]: - config_model_workspace.set_support_agent_models(support_agent_models=kwargs["support_agent_models"]) + config_model_workspace.set_support_agent_models( + support_agent_models=kwargs["support_agent_models"] + ) if kwargs["model_providers_cfg_path_config"]: config_model_workspace.set_model_providers_cfg_path_config( @@ -86,10 +101,14 @@ def model(**kwargs): ) if kwargs["model_providers_cfg_host"]: - config_model_workspace.set_model_providers_cfg_host(model_providers_cfg_host=kwargs["model_providers_cfg_host"]) + config_model_workspace.set_model_providers_cfg_host( + model_providers_cfg_host=kwargs["model_providers_cfg_host"] + ) if kwargs["model_providers_cfg_port"]: - config_model_workspace.set_model_providers_cfg_port(model_providers_cfg_port=kwargs["model_providers_cfg_port"]) + config_model_workspace.set_model_providers_cfg_port( + model_providers_cfg_port=kwargs["model_providers_cfg_port"] + ) if kwargs["set_model_platforms"]: model_platforms_dict = json.loads(kwargs["set_model_platforms"]) @@ -106,29 +125,38 @@ def model(**kwargs): @main.command("server", help="服务配置") @click.option("--httpx_default_timeout", type=int, help="httpx默认超时时间") -@click.option("--open_cross_domain", type=click.Choice(["true", "false"]), help="是否开启跨域") +@click.option( + "--open_cross_domain", type=click.Choice(["true", "false"]), help="是否开启跨域" +) @click.option("--default_bind_host", help="默认绑定host") @click.option("--webui_server_port", type=int, help="webui服务端口") @click.option("--api_server_port", type=int, help="api服务端口") @click.option("--clear", is_flag=True, help="清除配置") @click.option("--show", is_flag=True, help="显示配置") def server(**kwargs): - if kwargs["httpx_default_timeout"]: - config_server_workspace.set_httpx_default_timeout(httpx_default_timeout=kwargs["httpx_default_timeout"]) + config_server_workspace.set_httpx_default_timeout( + httpx_default_timeout=kwargs["httpx_default_timeout"] + ) if kwargs["open_cross_domain"]: if kwargs["open_cross_domain"].lower() == "true": config_server_workspace.set_open_cross_domain(True) else: config_server_workspace.set_open_cross_domain(False) if kwargs["default_bind_host"]: - config_server_workspace.set_default_bind_host(default_bind_host=kwargs["default_bind_host"]) + config_server_workspace.set_default_bind_host( + default_bind_host=kwargs["default_bind_host"] + ) if kwargs["webui_server_port"]: - config_server_workspace.set_webui_server_port(webui_server_port=kwargs["webui_server_port"]) + config_server_workspace.set_webui_server_port( + webui_server_port=kwargs["webui_server_port"] + ) if kwargs["api_server_port"]: - config_server_workspace.set_api_server_port(api_server_port=kwargs["api_server_port"]) + 
config_server_workspace.set_api_server_port( + api_server_port=kwargs["api_server_port"] + ) if kwargs["clear"]: config_server_workspace.clear() @@ -147,13 +175,21 @@ def server(**kwargs): @click.option("--set_score_threshold", type=float, help="设置score阈值") @click.option("--set_default_search_engine", help="设置默认搜索引擎") @click.option("--set_search_engine_top_k", type=int, help="设置搜索引擎top k") -@click.option("--set_zh_title_enhance", type=click.Choice(["true", "false"]), help="是否开启中文标题增强") -@click.option('--pdf-ocr-threshold', type=(float, float), help='pdf ocr threshold') -@click.option('--set_kb_info', type=str, help='''每个知识库的初始化介绍,用于在初始化知识库时显示和Agent调用, +@click.option( + "--set_zh_title_enhance", + type=click.Choice(["true", "false"]), + help="是否开启中文标题增强", +) +@click.option("--pdf-ocr-threshold", type=(float, float), help="pdf ocr threshold") +@click.option( + "--set_kb_info", + type=str, + help="""每个知识库的初始化介绍,用于在初始化知识库时显示和Agent调用, 没写则没有介绍,不会被Agent调用。 as a JSON string. Example: "{\"samples\": \"关于本项目issue的解答\"}" - ''') + """, +) @click.option("--set_kb_root_path", help="设置知识库根路径") @click.option("--set_db_root_path", help="设置db根路径") @click.option("--set_sqlalchemy_database_uri", help="设置sqlalchemy数据库uri") @@ -162,34 +198,49 @@ def server(**kwargs): @click.option("--clear", is_flag=True, help="清除配置") @click.option("--show", is_flag=True, help="显示配置") def kb(**kwargs): - if kwargs["set_default_knowledge_base"]: - config_kb_workspace.set_default_knowledge_base(default_knowledge_base=kwargs["set_default_knowledge_base"]) + config_kb_workspace.set_default_knowledge_base( + default_knowledge_base=kwargs["set_default_knowledge_base"] + ) if kwargs["set_default_vs_type"]: - config_kb_workspace.set_default_vs_type(default_vs_type=kwargs["set_default_vs_type"]) + config_kb_workspace.set_default_vs_type( + default_vs_type=kwargs["set_default_vs_type"] + ) if kwargs["set_cached_vs_num"]: config_kb_workspace.set_cached_vs_num(cached_vs_num=kwargs["set_cached_vs_num"]) if kwargs["set_cached_memo_vs_num"]: - config_kb_workspace.set_cached_memo_vs_num(cached_memo_vs_num=kwargs["set_cached_memo_vs_num"]) + config_kb_workspace.set_cached_memo_vs_num( + cached_memo_vs_num=kwargs["set_cached_memo_vs_num"] + ) if kwargs["set_chunk_size"]: config_kb_workspace.set_chunk_size(chunk_size=kwargs["set_chunk_size"]) if kwargs["set_overlap_size"]: config_kb_workspace.set_overlap_size(overlap_size=kwargs["set_overlap_size"]) if kwargs["set_vector_search_top_k"]: - config_kb_workspace.set_vector_search_top_k(vector_search_top_k=kwargs["set_vector_search_top_k"]) + config_kb_workspace.set_vector_search_top_k( + vector_search_top_k=kwargs["set_vector_search_top_k"] + ) if kwargs["set_score_threshold"]: - config_kb_workspace.set_score_threshold(score_threshold=kwargs["set_score_threshold"]) + config_kb_workspace.set_score_threshold( + score_threshold=kwargs["set_score_threshold"] + ) if kwargs["set_default_search_engine"]: - config_kb_workspace.set_default_search_engine(default_search_engine=kwargs["set_default_search_engine"]) + config_kb_workspace.set_default_search_engine( + default_search_engine=kwargs["set_default_search_engine"] + ) if kwargs["set_search_engine_top_k"]: - config_model_workspace.set_search_engine_top_k(search_engine_top_k=kwargs["set_search_engine_top_k"]) + config_model_workspace.set_search_engine_top_k( + search_engine_top_k=kwargs["set_search_engine_top_k"] + ) if kwargs["set_zh_title_enhance"]: if kwargs["set_zh_title_enhance"].lower() == "true": config_kb_workspace.set_zh_title_enhance(True) else: 
config_kb_workspace.set_zh_title_enhance(False) if kwargs["pdf_ocr_threshold"]: - config_kb_workspace.set_pdf_ocr_threshold(pdf_ocr_threshold=kwargs["pdf_ocr_threshold"]) + config_kb_workspace.set_pdf_ocr_threshold( + pdf_ocr_threshold=kwargs["pdf_ocr_threshold"] + ) if kwargs["set_kb_info"]: kb_info_dict = json.loads(kwargs["set_kb_info"]) config_kb_workspace.set_kb_info(kb_info=kb_info_dict) @@ -198,11 +249,17 @@ def kb(**kwargs): if kwargs["set_db_root_path"]: config_kb_workspace.set_db_root_path(db_root_path=kwargs["set_db_root_path"]) if kwargs["set_sqlalchemy_database_uri"]: - config_kb_workspace.set_sqlalchemy_database_uri(sqlalchemy_database_uri=kwargs["set_sqlalchemy_database_uri"]) + config_kb_workspace.set_sqlalchemy_database_uri( + sqlalchemy_database_uri=kwargs["set_sqlalchemy_database_uri"] + ) if kwargs["set_text_splitter_name"]: - config_kb_workspace.set_text_splitter_name(text_splitter_name=kwargs["set_text_splitter_name"]) + config_kb_workspace.set_text_splitter_name( + text_splitter_name=kwargs["set_text_splitter_name"] + ) if kwargs["set_embedding_keyword_file"]: - config_kb_workspace.set_embedding_keyword_file(embedding_keyword_file=kwargs["set_embedding_keyword_file"]) + config_kb_workspace.set_embedding_keyword_file( + embedding_keyword_file=kwargs["set_embedding_keyword_file"] + ) if kwargs["clear"]: config_kb_workspace.clear() diff --git a/libs/chatchat-server/chatchat/configs/__init__.py b/libs/chatchat-server/chatchat/configs/__init__.py index d7206343d..c3c977aa5 100644 --- a/libs/chatchat-server/chatchat/configs/__init__.py +++ b/libs/chatchat-server/chatchat/configs/__init__.py @@ -1,11 +1,10 @@ import importlib import importlib.util -import os -from pathlib import Path -from typing import Dict, Any - import json import logging +import os +from pathlib import Path +from typing import Any, Dict logger = logging.getLogger() @@ -18,20 +17,19 @@ def _load_mod(mod, attr): break if attr_cfg is None: - logger.warning( - f"Missing attr_cfg:{attr} in {mod}, Skip." - ) + logger.warning(f"Missing attr_cfg:{attr} in {mod}, Skip.") return attr_cfg return attr_cfg def _import_config_mod_load(import_config_mod: str) -> Dict: # 加载用户空间的配置 - user_config_path = os.path.join(os.path.expanduser("~"), ".config", "chatchat/configs") + user_config_path = os.path.join( + os.path.expanduser("~"), ".config", "chatchat/configs" + ) user_import = True # 默认加载用户配置 if os.path.exists(user_config_path): try: - file_names = os.listdir(user_config_path) if import_config_mod + ".py" not in file_names: @@ -42,10 +40,7 @@ def _import_config_mod_load(import_config_mod: str) -> Dict: if user_import: # Dynamic loading {config}.py file py_path = os.path.join(user_config_path, import_config_mod + ".py") - spec = importlib.util.spec_from_file_location( - f"*", - py_path - ) + spec = importlib.util.spec_from_file_location(f"*", py_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) @@ -65,15 +60,14 @@ def _import_config_mod_load(import_config_mod: str) -> Dict: user_import = False if user_import: - logger.error( - f"Failed to load user config from {user_config_path}, Skip." 
- ) + logger.error(f"Failed to load user config from {user_config_path}, Skip.") raise RuntimeError(f"Failed to load user config from {user_config_path}") # 当前文件路径 - py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py") + py_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py" + ) - spec = importlib.util.spec_from_file_location(f"*", - py_path) + spec = importlib.util.spec_from_file_location(f"*", py_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) @@ -87,7 +81,6 @@ def _import_config_mod_load(import_config_mod: str) -> Dict: CONFIG_IMPORTS = { - "_basic_config.py": _import_config_mod_load("_basic_config"), "_kb_config.py": _import_config_mod_load("_kb_config"), "_model_config.py": _import_config_mod_load("_model_config"), @@ -115,7 +108,9 @@ def _import_ConfigBasicFactory() -> Any: def _import_ConfigBasicWorkSpace() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - ConfigBasicWorkSpace = load_mod(basic_config_load.get("module"), "ConfigBasicWorkSpace") + ConfigBasicWorkSpace = load_mod( + basic_config_load.get("module"), "ConfigBasicWorkSpace" + ) return ConfigBasicWorkSpace @@ -123,7 +118,9 @@ def _import_ConfigBasicWorkSpace() -> Any: def _import_config_basic_workspace() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") + config_basic_workspace = load_mod( + basic_config_load.get("module"), "config_basic_workspace" + ) return config_basic_workspace @@ -131,7 +128,9 @@ def _import_log_verbose() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") + config_basic_workspace = load_mod( + basic_config_load.get("module"), "config_basic_workspace" + ) return config_basic_workspace.get_config().log_verbose @@ -139,7 +138,9 @@ def _import_chatchat_root() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") + config_basic_workspace = load_mod( + basic_config_load.get("module"), "config_basic_workspace" + ) return config_basic_workspace.get_config().CHATCHAT_ROOT @@ -148,7 +149,9 @@ def _import_data_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") + config_basic_workspace = load_mod( + basic_config_load.get("module"), "config_basic_workspace" + ) return config_basic_workspace.get_config().DATA_PATH @@ -156,7 +159,9 @@ def _import_img_dir() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") + config_basic_workspace = load_mod( + basic_config_load.get("module"), "config_basic_workspace" + ) return config_basic_workspace.get_config().IMG_DIR @@ -164,7 +169,9 @@ def _import_img_dir() -> Any: def _import_nltk_data_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_basic_workspace 
= load_mod(basic_config_load.get("module"), "config_basic_workspace") + config_basic_workspace = load_mod( + basic_config_load.get("module"), "config_basic_workspace" + ) return config_basic_workspace.get_config().NLTK_DATA_PATH @@ -173,7 +180,9 @@ def _import_log_format() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") + config_basic_workspace = load_mod( + basic_config_load.get("module"), "config_basic_workspace" + ) return config_basic_workspace.get_config().LOG_FORMAT @@ -182,7 +191,9 @@ def _import_log_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") + config_basic_workspace = load_mod( + basic_config_load.get("module"), "config_basic_workspace" + ) return config_basic_workspace.get_config().LOG_PATH @@ -191,7 +202,9 @@ def _import_media_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") + config_basic_workspace = load_mod( + basic_config_load.get("module"), "config_basic_workspace" + ) return config_basic_workspace.get_config().MEDIA_PATH @@ -199,7 +212,9 @@ def _import_media_path() -> Any: def _import_base_temp_dir() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") + config_basic_workspace = load_mod( + basic_config_load.get("module"), "config_basic_workspace" + ) return config_basic_workspace.get_config().BASE_TEMP_DIR @@ -414,7 +429,9 @@ def _import_ConfigModelFactory() -> Any: def _import_ConfigModelWorkSpace() -> Any: basic_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = basic_config_load.get("load_mod") - ConfigModelWorkSpace = load_mod(basic_config_load.get("module"), "ConfigModelWorkSpace") + ConfigModelWorkSpace = load_mod( + basic_config_load.get("module"), "ConfigModelWorkSpace" + ) return ConfigModelWorkSpace @@ -422,14 +439,18 @@ def _import_ConfigModelWorkSpace() -> Any: def _import_config_model_workspace() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + config_model_workspace = load_mod( + model_config_load.get("module"), "config_model_workspace" + ) return config_model_workspace def _import_default_llm_model() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + config_model_workspace = load_mod( + model_config_load.get("module"), "config_model_workspace" + ) return config_model_workspace.get_config().DEFAULT_LLM_MODEL @@ -438,7 +459,9 @@ def _import_default_embedding_model() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + config_model_workspace = load_mod( + model_config_load.get("module"), "config_model_workspace" + ) return 
config_model_workspace.get_config().DEFAULT_EMBEDDING_MODEL @@ -446,7 +469,9 @@ def _import_default_embedding_model() -> Any: def _import_agent_model() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + config_model_workspace = load_mod( + model_config_load.get("module"), "config_model_workspace" + ) return config_model_workspace.get_config().Agent_MODEL @@ -454,7 +479,9 @@ def _import_agent_model() -> Any: def _import_history_len() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + config_model_workspace = load_mod( + model_config_load.get("module"), "config_model_workspace" + ) return config_model_workspace.get_config().HISTORY_LEN @@ -462,7 +489,9 @@ def _import_history_len() -> Any: def _import_max_tokens() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + config_model_workspace = load_mod( + model_config_load.get("module"), "config_model_workspace" + ) return config_model_workspace.get_config().MAX_TOKENS @@ -470,7 +499,9 @@ def _import_max_tokens() -> Any: def _import_temperature() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + config_model_workspace = load_mod( + model_config_load.get("module"), "config_model_workspace" + ) return config_model_workspace.get_config().TEMPERATURE @@ -478,7 +509,9 @@ def _import_temperature() -> Any: def _import_support_agent_models() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + config_model_workspace = load_mod( + model_config_load.get("module"), "config_model_workspace" + ) return config_model_workspace.get_config().SUPPORT_AGENT_MODELS @@ -486,7 +519,9 @@ def _import_support_agent_models() -> Any: def _import_llm_model_config() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + config_model_workspace = load_mod( + model_config_load.get("module"), "config_model_workspace" + ) return config_model_workspace.get_config().LLM_MODEL_CONFIG @@ -494,7 +529,9 @@ def _import_llm_model_config() -> Any: def _import_model_platforms() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + config_model_workspace = load_mod( + model_config_load.get("module"), "config_model_workspace" + ) return config_model_workspace.get_config().MODEL_PLATFORMS @@ -502,7 +539,9 @@ def _import_model_platforms() -> Any: def _import_model_providers_cfg_path() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + config_model_workspace = 
load_mod( + model_config_load.get("module"), "config_model_workspace" + ) return config_model_workspace.get_config().MODEL_PROVIDERS_CFG_PATH_CONFIG @@ -510,7 +549,9 @@ def _import_model_providers_cfg_path() -> Any: def _import_model_providers_cfg_host() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + config_model_workspace = load_mod( + model_config_load.get("module"), "config_model_workspace" + ) return config_model_workspace.get_config().MODEL_PROVIDERS_CFG_HOST @@ -518,7 +559,9 @@ def _import_model_providers_cfg_host() -> Any: def _import_model_providers_cfg_port() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + config_model_workspace = load_mod( + model_config_load.get("module"), "config_model_workspace" + ) return config_model_workspace.get_config().MODEL_PROVIDERS_CFG_PORT @@ -526,7 +569,9 @@ def _import_model_providers_cfg_port() -> Any: def _import_tool_config() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + config_model_workspace = load_mod( + model_config_load.get("module"), "config_model_workspace" + ) return config_model_workspace.get_config().TOOL_CONFIG @@ -550,7 +595,9 @@ def _import_ConfigServer() -> Any: def _import_ConfigServerFactory() -> Any: basic_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = basic_config_load.get("load_mod") - ConfigServerFactory = load_mod(basic_config_load.get("module"), "ConfigServerFactory") + ConfigServerFactory = load_mod( + basic_config_load.get("module"), "ConfigServerFactory" + ) return ConfigServerFactory @@ -558,7 +605,9 @@ def _import_ConfigServerFactory() -> Any: def _import_ConfigServerWorkSpace() -> Any: basic_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = basic_config_load.get("load_mod") - ConfigServerWorkSpace = load_mod(basic_config_load.get("module"), "ConfigServerWorkSpace") + ConfigServerWorkSpace = load_mod( + basic_config_load.get("module"), "ConfigServerWorkSpace" + ) return ConfigServerWorkSpace @@ -566,7 +615,9 @@ def _import_ConfigServerWorkSpace() -> Any: def _import_config_server_workspace() -> Any: server_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = server_config_load.get("load_mod") - config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace") + config_server_workspace = load_mod( + server_config_load.get("module"), "config_server_workspace" + ) return config_server_workspace @@ -574,7 +625,9 @@ def _import_config_server_workspace() -> Any: def _import_httpx_default_timeout() -> Any: server_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = server_config_load.get("load_mod") - config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace") + config_server_workspace = load_mod( + server_config_load.get("module"), "config_server_workspace" + ) return config_server_workspace.get_config().HTTPX_DEFAULT_TIMEOUT @@ -582,7 +635,9 @@ def _import_httpx_default_timeout() -> Any: def _import_open_cross_domain() -> Any: server_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = 
server_config_load.get("load_mod") - config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace") + config_server_workspace = load_mod( + server_config_load.get("module"), "config_server_workspace" + ) return config_server_workspace.get_config().OPEN_CROSS_DOMAIN @@ -590,7 +645,9 @@ def _import_open_cross_domain() -> Any: def _import_default_bind_host() -> Any: server_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = server_config_load.get("load_mod") - config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace") + config_server_workspace = load_mod( + server_config_load.get("module"), "config_server_workspace" + ) return config_server_workspace.get_config().DEFAULT_BIND_HOST @@ -598,7 +655,9 @@ def _import_default_bind_host() -> Any: def _import_open_cross_domain() -> Any: server_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = server_config_load.get("load_mod") - config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace") + config_server_workspace = load_mod( + server_config_load.get("module"), "config_server_workspace" + ) return config_server_workspace.get_config().OPEN_CROSS_DOMAIN @@ -606,7 +665,9 @@ def _import_open_cross_domain() -> Any: def _import_webui_server() -> Any: server_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = server_config_load.get("load_mod") - config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace") + config_server_workspace = load_mod( + server_config_load.get("module"), "config_server_workspace" + ) return config_server_workspace.get_config().WEBUI_SERVER @@ -614,7 +675,9 @@ def _import_webui_server() -> Any: def _import_api_server() -> Any: server_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = server_config_load.get("load_mod") - config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace") + config_server_workspace = load_mod( + server_config_load.get("module"), "config_server_workspace" + ) return config_server_workspace.get_config().API_SERVER @@ -803,29 +866,20 @@ def __getattr__(name: str) -> Any: "OPEN_CROSS_DOMAIN", "WEBUI_SERVER", "API_SERVER", - "ConfigBasic", "ConfigBasicFactory", "ConfigBasicWorkSpace", - "config_basic_workspace", - "ConfigModel", "ConfigModelFactory", "ConfigModelWorkSpace", - "config_model_workspace", - "ConfigKb", "ConfigKbFactory", "ConfigKbWorkSpace", - "config_kb_workspace", - "ConfigServer", "ConfigServerFactory", "ConfigServerWorkSpace", - "config_server_workspace", - ] diff --git a/libs/chatchat-server/chatchat/configs/_basic_config.py b/libs/chatchat-server/chatchat/configs/_basic_config.py index f53c20dbc..5c8640044 100644 --- a/libs/chatchat-server/chatchat/configs/_basic_config.py +++ b/libs/chatchat-server/chatchat/configs/_basic_config.py @@ -1,12 +1,11 @@ -import os import json +import logging +import os +import sys from dataclasses import dataclass from pathlib import Path -import sys -import logging from typing import Any, Optional - sys.path.append(str(Path(__file__).parent)) import _core_config as core_config @@ -43,11 +42,13 @@ def __str__(self): @dataclass class ConfigBasicFactory(core_config.ConfigFactory[ConfigBasic]): - """Basic config for ChatChat """ + """Basic config for ChatChat""" def __init__(self): # 日志格式 - self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" + self.LOG_FORMAT = ( + "%(asctime)s - 
%(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" + ) logging.basicConfig(format=self.LOG_FORMAT) self.LOG_VERBOSE = False self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent) @@ -86,6 +87,7 @@ def _init_data_dir(self): # nltk 模型存储路径 self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data") import nltk + nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path # 日志存储路径 self.LOG_PATH = os.path.join(self.DATA_PATH, "logs") @@ -121,17 +123,19 @@ def get_config(self) -> ConfigBasic: return config -class ConfigBasicWorkSpace(core_config.ConfigWorkSpace[ConfigBasicFactory, ConfigBasic]): +class ConfigBasicWorkSpace( + core_config.ConfigWorkSpace[ConfigBasicFactory, ConfigBasic] +): """ 工作空间的配置预设,提供ConfigBasic建造方法产生实例。 """ + config_factory_cls = ConfigBasicFactory def __init__(self): super().__init__() def _build_config_factory(self, config_json: Any) -> ConfigBasicFactory: - _config_factory = self.config_factory_cls() if config_json.get("log_verbose"): diff --git a/libs/chatchat-server/chatchat/configs/_core_config.py b/libs/chatchat-server/chatchat/configs/_core_config.py index 339c3dfaf..4ae6de0f2 100644 --- a/libs/chatchat-server/chatchat/configs/_core_config.py +++ b/libs/chatchat-server/chatchat/configs/_core_config.py @@ -1,10 +1,10 @@ -import os import json -from abc import abstractmethod, ABC +import logging +import os +from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -import logging -from typing import Any, Dict, TypeVar, Generic, Optional, Type +from typing import Any, Dict, Generic, Optional, Type, TypeVar from dataclasses_json import DataClassJsonMixin from pydantic import BaseModel @@ -33,7 +33,7 @@ def to_json(self, **kwargs: Any) -> str: @dataclass class ConfigFactory(Generic[F], DataClassJsonMixin): - """config for ChatChat """ + """config for ChatChat""" @classmethod @abstractmethod @@ -52,6 +52,7 @@ class ConfigWorkSpace(Generic[CF, F], ABC): 工作空间的配置信息存储在用户的家目录下的.chatchat/workspace/workspace_config.json文件中。 注意:不存在则读取默认 """ + config_factory_cls: Type[CF] _config_factory: Optional[CF] = None @@ -111,7 +112,9 @@ def get_config_by_type(self, cfg_type) -> Dict[str, Any]: if store_cfg is None: raise RuntimeError("store_cfg is None.") - get_lambda = lambda store_cfg_type: store_cfg[self._get_store_cfg_index_by_type(store_cfg, store_cfg_type)] + get_lambda = lambda store_cfg_type: store_cfg[ + self._get_store_cfg_index_by_type(store_cfg, store_cfg_type) + ] return get_lambda(cfg_type) def store_config(self): @@ -123,8 +126,7 @@ def store_config(self): if _load_config is None: _load_config = [] config_json_index = self._get_store_cfg_index_by_type( - store_cfg=_load_config, - store_cfg_type=self.get_type() + store_cfg=_load_config, store_cfg_type=self.get_type() ) config_type_json = {"type": self.get_type(), "config": config_json} if config_json_index == -1: diff --git a/libs/chatchat-server/chatchat/configs/_kb_config.py b/libs/chatchat-server/chatchat/configs/_kb_config.py index 988bca030..591dabdee 100644 --- a/libs/chatchat-server/chatchat/configs/_kb_config.py +++ b/libs/chatchat-server/chatchat/configs/_kb_config.py @@ -1,14 +1,13 @@ -import os import json +import logging +import os +import sys from dataclasses import dataclass from pathlib import Path -import sys -import logging -from typing import Any, Optional, Dict, Tuple +from typing import Any, Dict, Optional, Tuple sys.path.append(str(Path(__file__).parent)) import _core_config as core_config - from _basic_config import config_basic_workspace @@ 
-117,7 +116,9 @@ def __init__(self): # 通常情况下不需要更改以下内容 # 知识库默认存储路径 - self.KB_ROOT_PATH = os.path.join(config_basic_workspace.get_config().DATA_PATH, "knowledge_base") + self.KB_ROOT_PATH = os.path.join( + config_basic_workspace.get_config().DATA_PATH, "knowledge_base" + ) if not os.path.exists(self.KB_ROOT_PATH): os.mkdir(self.KB_ROOT_PATH) @@ -128,8 +129,7 @@ def __init__(self): # 可选向量库类型及对应配置 self.kbs_config = { - "faiss": { - }, + "faiss": {}, "milvus": { "host": "127.0.0.1", "port": "19530", @@ -155,13 +155,16 @@ def __init__(self): "port": "9200", "index_name": "test_index", "user": "", - "password": "" + "password": "", }, "milvus_kwargs": { - "search_params": {"metric_type": "L2"}, #在此处增加search_params - "index_params": {"metric_type": "L2", "index_type": "HNSW"} # 在此处增加index_params + "search_params": {"metric_type": "L2"}, # 在此处增加search_params + "index_params": { + "metric_type": "L2", + "index_type": "HNSW", + }, # 在此处增加index_params }, - "chromadb": {} + "chromadb": {}, } # TextSplitter配置项,如果你不明白其中的含义,就不要修改。 @@ -179,13 +182,12 @@ def __init__(self): "tokenizer_name_or_path": "cl100k_base", }, "MarkdownHeaderTextSplitter": { - "headers_to_split_on": - [ - ("#", "head1"), - ("##", "head2"), - ("###", "head3"), - ("####", "head4"), - ] + "headers_to_split_on": [ + ("#", "head1"), + ("##", "head2"), + ("###", "head3"), + ("####", "head4"), + ] }, } @@ -207,6 +209,7 @@ class ConfigKbWorkSpace(core_config.ConfigWorkSpace[ConfigKbFactory, ConfigKb]): """ 工作空间的配置预设,提供ConfigKb建造方法产生实例。 """ + config_factory_cls = ConfigKbFactory def __init__(self): @@ -215,7 +218,9 @@ def __init__(self): def _build_config_factory(self, config_json: Any) -> ConfigKbFactory: _config_factory = self.config_factory_cls() if config_json.get("DEFAULT_KNOWLEDGE_BASE"): - _config_factory.DEFAULT_KNOWLEDGE_BASE = config_json.get("DEFAULT_KNOWLEDGE_BASE") + _config_factory.DEFAULT_KNOWLEDGE_BASE = config_json.get( + "DEFAULT_KNOWLEDGE_BASE" + ) if config_json.get("DEFAULT_VS_TYPE"): _config_factory.DEFAULT_VS_TYPE = config_json.get("DEFAULT_VS_TYPE") if config_json.get("CACHED_VS_NUM"): @@ -231,7 +236,9 @@ def _build_config_factory(self, config_json: Any) -> ConfigKbFactory: if config_json.get("SCORE_THRESHOLD"): _config_factory.SCORE_THRESHOLD = config_json.get("SCORE_THRESHOLD") if config_json.get("DEFAULT_SEARCH_ENGINE"): - _config_factory.DEFAULT_SEARCH_ENGINE = config_json.get("DEFAULT_SEARCH_ENGINE") + _config_factory.DEFAULT_SEARCH_ENGINE = config_json.get( + "DEFAULT_SEARCH_ENGINE" + ) if config_json.get("SEARCH_ENGINE_TOP_K"): _config_factory.SEARCH_ENGINE_TOP_K = config_json.get("SEARCH_ENGINE_TOP_K") if config_json.get("ZH_TITLE_ENHANCE"): @@ -245,13 +252,17 @@ def _build_config_factory(self, config_json: Any) -> ConfigKbFactory: if config_json.get("DB_ROOT_PATH"): _config_factory.DB_ROOT_PATH = config_json.get("DB_ROOT_PATH") if config_json.get("SQLALCHEMY_DATABASE_URI"): - _config_factory.SQLALCHEMY_DATABASE_URI = config_json.get("SQLALCHEMY_DATABASE_URI") + _config_factory.SQLALCHEMY_DATABASE_URI = config_json.get( + "SQLALCHEMY_DATABASE_URI" + ) if config_json.get("TEXT_SPLITTER_NAME"): _config_factory.TEXT_SPLITTER_NAME = config_json.get("TEXT_SPLITTER_NAME") if config_json.get("EMBEDDING_KEYWORD_FILE"): - _config_factory.EMBEDDING_KEYWORD_FILE = config_json.get("EMBEDDING_KEYWORD_FILE") + _config_factory.EMBEDDING_KEYWORD_FILE = config_json.get( + "EMBEDDING_KEYWORD_FILE" + ) return _config_factory diff --git a/libs/chatchat-server/chatchat/configs/_model_config.py 
b/libs/chatchat-server/chatchat/configs/_model_config.py index dfdfdd9c9..454bfe299 100644 --- a/libs/chatchat-server/chatchat/configs/_model_config.py +++ b/libs/chatchat-server/chatchat/configs/_model_config.py @@ -1,10 +1,9 @@ -import os import logging +import os import sys -from pathlib import Path -from typing import Any, Optional, List, Dict - from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional sys.path.append(str(Path(__file__).parent)) import _core_config as core_config @@ -80,13 +79,13 @@ def __init__(self): "qwen-turbo", ] - self.MODEL_PROVIDERS_CFG_PATH_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), - "model_providers.yaml") + self.MODEL_PROVIDERS_CFG_PATH_CONFIG = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "model_providers.yaml" + ) self.MODEL_PROVIDERS_CFG_HOST = "127.0.0.1" self.MODEL_PROVIDERS_CFG_PORT = 20000 - # 可以通过 model_providers 提供转换不同平台的接口为openai endpoint的能力,启动后下面变量会自动增加相应的平台 # ### 如果您已经有了一个openai endpoint的能力的地址,可以在这里直接配置 # - platform_name 可以任意填写,不要重复即可 @@ -158,46 +157,40 @@ def __init__(self): "top_k": 3, "score_threshold": 1.0, "conclude_prompt": { - "with_result": - '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",' - '不允许在答案中添加编造成分,答案请使用中文。 \n' - '<已知信息>{{ context }}\n' - '<问题>{{ question }}\n', - "without_result": - '请你根据我的提问回答我的问题:\n' - '{{ question }}\n' - '请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n', - } + "with_result": '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",' + "不允许在答案中添加编造成分,答案请使用中文。 \n" + "<已知信息>{{ context }}\n" + "<问题>{{ question }}\n", + "without_result": "请你根据我的提问回答我的问题:\n" + "{{ question }}\n" + "请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n", + }, }, "search_internet": { "use": False, "search_engine_name": "bing", - "search_engine_config": - { - "bing": { - "result_len": 3, - "bing_search_url": "https://api.bing.microsoft.com/v7.0/search", - "bing_key": "", - }, - "metaphor": { - "result_len": 3, - "metaphor_api_key": "", - "split_result": False, - "chunk_size": 500, - "chunk_overlap": 0, - }, - "duckduckgo": { - "result_len": 3 - } + "search_engine_config": { + "bing": { + "result_len": 3, + "bing_search_url": "https://api.bing.microsoft.com/v7.0/search", + "bing_key": "", }, + "metaphor": { + "result_len": 3, + "metaphor_api_key": "", + "split_result": False, + "chunk_size": 500, + "chunk_overlap": 0, + }, + "duckduckgo": {"result_len": 3}, + }, "top_k": 10, "verbose": "Origin", - "conclude_prompt": - "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 " - "\n<已知信息>{{ context }}\n" - "<问题>\n" - "{{ question }}\n" - "\n" + "conclude_prompt": "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 " + "\n<已知信息>{{ context }}\n" + "<问题>\n" + "{{ question }}\n" + "\n", }, "arxiv": { "use": False, @@ -223,13 +216,13 @@ def __init__(self): "use": False, "model_path": "your model path", "tokenizer_path": "your tokenizer path", - "device": "cuda:1" + "device": "cuda:1", }, "aqa_processor": { "use": False, "model_path": "your model path", "tokenizer_path": "yout tokenizer path", - "device": "cuda:2" + "device": "cuda:2", }, "text2images": { "use": False, @@ -262,7 +255,7 @@ def __init__(self): # 如果出现大模型选错表的情况,可尝试根据实际情况填写表名和说明 # "tableA":"这是一个用户表,存储了用户的基本信息", # "tanleB":"角色表", - } + }, }, } self._init_llm_work_config() @@ -278,7 +271,7 @@ def _init_llm_work_config(self): "max_tokens": 4096, "history_len": 100, "prompt_name": "default", - "callbacks": False + "callbacks": False, }, }, "llm_model": { @@ -287,7 +280,7 @@ 
def _init_llm_work_config(self): "max_tokens": 4096, "history_len": 10, "prompt_name": "default", - "callbacks": True + "callbacks": True, }, }, "action_model": { @@ -295,7 +288,7 @@ def _init_llm_work_config(self): "temperature": 0.01, "max_tokens": 4096, "prompt_name": "ChatGLM3", - "callbacks": True + "callbacks": True, }, }, "postprocess_model": { @@ -303,14 +296,14 @@ def _init_llm_work_config(self): "temperature": 0.01, "max_tokens": 4096, "prompt_name": "default", - "callbacks": True + "callbacks": True, } }, "image_model": { "sd-turbo": { "size": "256*256", } - } + }, } def default_llm_model(self, llm_model: str): @@ -368,22 +361,26 @@ def get_config(self) -> ConfigModel: return config -class ConfigModelWorkSpace(core_config.ConfigWorkSpace[ConfigModelFactory, ConfigModel]): +class ConfigModelWorkSpace( + core_config.ConfigWorkSpace[ConfigModelFactory, ConfigModel] +): """ 工作空间的配置预设, 提供ConfigModel建造方法产生实例。 """ + config_factory_cls = ConfigModelFactory def __init__(self): super().__init__() def _build_config_factory(self, config_json: Any) -> ConfigModelFactory: - _config_factory = self.config_factory_cls() if config_json.get("DEFAULT_LLM_MODEL"): _config_factory.default_llm_model(config_json.get("DEFAULT_LLM_MODEL")) if config_json.get("DEFAULT_EMBEDDING_MODEL"): - _config_factory.default_embedding_model(config_json.get("DEFAULT_EMBEDDING_MODEL")) + _config_factory.default_embedding_model( + config_json.get("DEFAULT_EMBEDDING_MODEL") + ) if config_json.get("Agent_MODEL"): _config_factory.agent_model(config_json.get("Agent_MODEL")) if config_json.get("HISTORY_LEN"): @@ -393,13 +390,21 @@ def _build_config_factory(self, config_json: Any) -> ConfigModelFactory: if config_json.get("TEMPERATURE"): _config_factory.temperature(config_json.get("TEMPERATURE")) if config_json.get("SUPPORT_AGENT_MODELS"): - _config_factory.support_agent_models(config_json.get("SUPPORT_AGENT_MODELS")) + _config_factory.support_agent_models( + config_json.get("SUPPORT_AGENT_MODELS") + ) if config_json.get("MODEL_PROVIDERS_CFG_PATH_CONFIG"): - _config_factory.model_providers_cfg_path_config(config_json.get("MODEL_PROVIDERS_CFG_PATH_CONFIG")) + _config_factory.model_providers_cfg_path_config( + config_json.get("MODEL_PROVIDERS_CFG_PATH_CONFIG") + ) if config_json.get("MODEL_PROVIDERS_CFG_HOST"): - _config_factory.model_providers_cfg_host(config_json.get("MODEL_PROVIDERS_CFG_HOST")) + _config_factory.model_providers_cfg_host( + config_json.get("MODEL_PROVIDERS_CFG_HOST") + ) if config_json.get("MODEL_PROVIDERS_CFG_PORT"): - _config_factory.model_providers_cfg_port(config_json.get("MODEL_PROVIDERS_CFG_PORT")) + _config_factory.model_providers_cfg_port( + config_json.get("MODEL_PROVIDERS_CFG_PORT") + ) if config_json.get("MODEL_PLATFORMS"): _config_factory.model_platforms(config_json.get("MODEL_PLATFORMS")) if config_json.get("TOOL_CONFIG"): @@ -443,7 +448,9 @@ def set_support_agent_models(self, support_agent_models: List[str]): self.store_config() def set_model_providers_cfg_path_config(self, model_providers_cfg_path_config: str): - self._config_factory.model_providers_cfg_path_config(model_providers_cfg_path_config) + self._config_factory.model_providers_cfg_path_config( + model_providers_cfg_path_config + ) self.store_config() def set_model_providers_cfg_host(self, model_providers_cfg_host: str): @@ -463,4 +470,4 @@ def set_tool_config(self, tool_config: Dict[str, Any]): self.store_config() -config_model_workspace: ConfigModelWorkSpace = ConfigModelWorkSpace() \ No newline at end of file +config_model_workspace: 
ConfigModelWorkSpace = ConfigModelWorkSpace() diff --git a/libs/chatchat-server/chatchat/configs/_prompt_config.py b/libs/chatchat-server/chatchat/configs/_prompt_config.py index eceeac2b9..eba0e8428 100644 --- a/libs/chatchat-server/chatchat/configs/_prompt_config.py +++ b/libs/chatchat-server/chatchat/configs/_prompt_config.py @@ -2,131 +2,121 @@ PROMPT_TEMPLATES = { "preprocess_model": { - "default": - '你只要回复0 和 1 ,代表不需要使用工具。以下几种问题不需要使用工具:' - '1. 需要联网查询的内容\n' - '2. 需要计算的内容\n' - '3. 需要查询实时性的内容\n' - '如果我的输入满足这几种情况,返回1。其他输入,请你回复0,你只要返回一个数字\n' - '这是我的问题:' + "default": "你只要回复0 和 1 ,代表不需要使用工具。以下几种问题不需要使用工具:" + "1. 需要联网查询的内容\n" + "2. 需要计算的内容\n" + "3. 需要查询实时性的内容\n" + "如果我的输入满足这几种情况,返回1。其他输入,请你回复0,你只要返回一个数字\n" + "这是我的问题:" }, "llm_model": { - "default": - '{{input}}', - "with_history": - 'The following is a friendly conversation between a human and an AI. ' - 'The AI is talkative and provides lots of specific details from its context. ' - 'If the AI does not know the answer to a question, it truthfully says it does not know.\n\n' - 'Current conversation:\n' - '{{history}}\n' - 'Human: {{input}}\n' - 'AI:', - "rag": - '【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。\n\n' - '【已知信息】{{context}}\n\n' - '【问题】{{question}}\n', - "rag_default": - '{{question}}', + "default": "{{input}}", + "with_history": "The following is a friendly conversation between a human and an AI. " + "The AI is talkative and provides lots of specific details from its context. " + "If the AI does not know the answer to a question, it truthfully says it does not know.\n\n" + "Current conversation:\n" + "{{history}}\n" + "Human: {{input}}\n" + "AI:", + "rag": "【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。\n\n" + "【已知信息】{{context}}\n\n" + "【问题】{{question}}\n", + "rag_default": "{{question}}", }, "action_model": { - "GPT-4": - 'Answer the following questions as best you can. You have access to the following tools:\n' - 'The way you use the tools is by specifying a json blob.\n' - 'Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).\n' - 'The only values that should be in the "action" field are: {tool_names}\n' - 'The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:\n' - '```\n\n' - '{{{{\n' - ' "action": $TOOL_NAME,\n' - ' "action_input": $INPUT\n' - '}}}}\n' - '```\n\n' - 'ALWAYS use the following format:\n' - 'Question: the input question you must answer\n' - 'Thought: you should always think about what to do\n' - 'Action:\n' - '```\n\n' - '$JSON_BLOB' - '```\n\n' - 'Observation: the result of the action\n' - '... (this Thought/Action/Observation can repeat N times)\n' - 'Thought: I now know the final answer\n' - 'Final Answer: the final answer to the original input question\n' - 'Begin! 
Reminder to always use the exact characters `Final Answer` when responding.\n' - 'Question:{input}\n' - 'Thought:{agent_scratchpad}\n', - - "ChatGLM3": - 'You can answer using the tools.Respond to the human as helpfully and accurately as possible.\n' - 'You have access to the following tools:\n' - '{tools}\n' - 'Use a json blob to specify a tool by providing an action key (tool name)\n' - 'and an action_input key (tool input).\n' - 'Valid "action" values: "Final Answer" or [{tool_names}]\n' - 'Provide only ONE action per $JSON_BLOB, as shown:\n\n' - '```\n' - '{{{{\n' - ' "action": $TOOL_NAME,\n' - ' "action_input": $INPUT\n' - '}}}}\n' - '```\n\n' - 'Follow this format:\n\n' - 'Question: input question to answer\n' - 'Thought: consider previous and subsequent steps\n' - 'Action:\n' - '```\n' - '$JSON_BLOB\n' - '```\n' - 'Observation: action result\n' - '... (repeat Thought/Action/Observation N times)\n' - 'Thought: I know what to respond\n' - 'Action:\n' - '```\n' - '{{{{\n' - ' "action": "Final Answer",\n' - ' "action_input": "Final response to human"\n' - '}}}}\n' - 'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary.\n' - 'Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.\n' - 'Question: {input}\n\n' - '{agent_scratchpad}\n', - "qwen": - 'Answer the following questions as best you can. You have access to the following APIs:\n\n' - '{tools}\n\n' - 'Use the following format:\n\n' - 'Question: the input question you must answer\n' - 'Thought: you should always think about what to do\n' - 'Action: the action to take, should be one of [{tool_names}]\n' - 'Action Input: the input to the action\n' - 'Observation: the result of the action\n' - '... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n' - 'Thought: I now know the final answer\n' - 'Final Answer: the final answer to the original input question\n\n' - 'Format the Action Input as a JSON object.\n\n' - 'Begin!\n\n' - 'Question: {input}\n\n' - '{agent_scratchpad}\n\n', - "structured-chat-agent": - 'Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n\n' - '{tools}\n\n' - 'Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n\n' - 'Valid "action" values: "Final Answer" or {tool_names}\n\n' - 'Provide only ONE action per $JSON_BLOB, as shown:\n\n' - '```\n{{\n "action": $TOOL_NAME,\n "action_input": $INPUT\n}}\n```\n\n' - 'Follow this format:\n\n' - 'Question: input question to answer\n' - 'Thought: consider previous and subsequent steps\n' - 'Action:\n```\n$JSON_BLOB\n```\n' - 'Observation: action result\n' - '... (repeat Thought/Action/Observation N times)\n' - 'Thought: I know what to respond\n' - 'Action:\n```\n{{\n "action": "Final Answer",\n "action_input": "Final response to human"\n}}\n\n' - 'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation\n' - '{input}\n\n' - '{agent_scratchpad}\n\n' - # '(reminder to respond in a JSON blob no matter what)' + "GPT-4": "Answer the following questions as best you can. 
You have access to the following tools:\n" + "The way you use the tools is by specifying a json blob.\n" + "Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).\n" + 'The only values that should be in the "action" field are: {tool_names}\n' + "The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:\n" + "```\n\n" + "{{{{\n" + ' "action": $TOOL_NAME,\n' + ' "action_input": $INPUT\n' + "}}}}\n" + "```\n\n" + "ALWAYS use the following format:\n" + "Question: the input question you must answer\n" + "Thought: you should always think about what to do\n" + "Action:\n" + "```\n\n" + "$JSON_BLOB" + "```\n\n" + "Observation: the result of the action\n" + "... (this Thought/Action/Observation can repeat N times)\n" + "Thought: I now know the final answer\n" + "Final Answer: the final answer to the original input question\n" + "Begin! Reminder to always use the exact characters `Final Answer` when responding.\n" + "Question:{input}\n" + "Thought:{agent_scratchpad}\n", + "ChatGLM3": "You can answer using the tools.Respond to the human as helpfully and accurately as possible.\n" + "You have access to the following tools:\n" + "{tools}\n" + "Use a json blob to specify a tool by providing an action key (tool name)\n" + "and an action_input key (tool input).\n" + 'Valid "action" values: "Final Answer" or [{tool_names}]\n' + "Provide only ONE action per $JSON_BLOB, as shown:\n\n" + "```\n" + "{{{{\n" + ' "action": $TOOL_NAME,\n' + ' "action_input": $INPUT\n' + "}}}}\n" + "```\n\n" + "Follow this format:\n\n" + "Question: input question to answer\n" + "Thought: consider previous and subsequent steps\n" + "Action:\n" + "```\n" + "$JSON_BLOB\n" + "```\n" + "Observation: action result\n" + "... (repeat Thought/Action/Observation N times)\n" + "Thought: I know what to respond\n" + "Action:\n" + "```\n" + "{{{{\n" + ' "action": "Final Answer",\n' + ' "action_input": "Final response to human"\n' + "}}}}\n" + "Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary.\n" + "Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.\n" + "Question: {input}\n\n" + "{agent_scratchpad}\n", + "qwen": "Answer the following questions as best you can. You have access to the following APIs:\n\n" + "{tools}\n\n" + "Use the following format:\n\n" + "Question: the input question you must answer\n" + "Thought: you should always think about what to do\n" + "Action: the action to take, should be one of [{tool_names}]\n" + "Action Input: the input to the action\n" + "Observation: the result of the action\n" + "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n" + "Thought: I now know the final answer\n" + "Final Answer: the final answer to the original input question\n\n" + "Format the Action Input as a JSON object.\n\n" + "Begin!\n\n" + "Question: {input}\n\n" + "{agent_scratchpad}\n\n", + "structured-chat-agent": "Respond to the human as helpfully and accurately as possible. 
You have access to the following tools:\n\n" + "{tools}\n\n" + "Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n\n" + 'Valid "action" values: "Final Answer" or {tool_names}\n\n' + "Provide only ONE action per $JSON_BLOB, as shown:\n\n" + '```\n{{\n "action": $TOOL_NAME,\n "action_input": $INPUT\n}}\n```\n\n' + "Follow this format:\n\n" + "Question: input question to answer\n" + "Thought: consider previous and subsequent steps\n" + "Action:\n```\n$JSON_BLOB\n```\n" + "Observation: action result\n" + "... (repeat Thought/Action/Observation N times)\n" + "Thought: I know what to respond\n" + 'Action:\n```\n{{\n "action": "Final Answer",\n "action_input": "Final response to human"\n}}\n\n' + "Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation\n" + "{input}\n\n" + "{agent_scratchpad}\n\n", + # '(reminder to respond in a JSON blob no matter what)' }, "postprocess_model": { "default": "{{input}}", - } + }, } diff --git a/libs/chatchat-server/chatchat/configs/_server_config.py b/libs/chatchat-server/chatchat/configs/_server_config.py index cfdc80c05..61460f3b0 100644 --- a/libs/chatchat-server/chatchat/configs/_server_config.py +++ b/libs/chatchat-server/chatchat/configs/_server_config.py @@ -1,10 +1,10 @@ -import os import json +import logging +import os +import sys from dataclasses import dataclass from pathlib import Path -import sys -import logging -from typing import Any, Optional, Dict +from typing import Any, Dict, Optional sys.path.append(str(Path(__file__).parent)) import _core_config as core_config @@ -97,10 +97,13 @@ def get_config(self) -> ConfigServer: return config -class ConfigServerWorkSpace(core_config.ConfigWorkSpace[ConfigServerFactory, ConfigServer]): +class ConfigServerWorkSpace( + core_config.ConfigWorkSpace[ConfigServerFactory, ConfigServer] +): """ 工作空间的配置预设,提供ConfigServer建造方法产生实例。 """ + config_factory_cls = ConfigServerFactory def __init__(self): @@ -149,4 +152,4 @@ def set_api_server_port(self, api_server_port: int): self.store_config() -config_server_workspace: ConfigServerWorkSpace = ConfigServerWorkSpace() \ No newline at end of file +config_server_workspace: ConfigServerWorkSpace = ConfigServerWorkSpace() diff --git a/libs/chatchat-server/chatchat/configs/model_providers.yaml b/libs/chatchat-server/chatchat/configs/model_providers.yaml index 4d7949b0c..5b65fef51 100644 --- a/libs/chatchat-server/chatchat/configs/model_providers.yaml +++ b/libs/chatchat-server/chatchat/configs/model_providers.yaml @@ -20,16 +20,16 @@ xinference: model_credential: - - model: 'glm-4' + - model: 'glm4-chat' model_type: 'llm' model_credentials: server_url: 'http://127.0.0.1:9997/' - model_uid: 'glm-4' - - model: 'qwen1.5-chat' + model_uid: 'glm4-chat' + - model: 'qwen2-instruct' model_type: 'llm' model_credentials: server_url: 'http://127.0.0.1:9997/' - model_uid: 'qwen1.5-chat' + model_uid: 'qwen2-instruct' - model: 'bge-large-zh-v1.5' model_type: 'text-embedding' model_credentials: diff --git a/libs/chatchat-server/chatchat/data/temp/openai_files/assistants/2024-03-29/webui.py b/libs/chatchat-server/chatchat/data/temp/openai_files/assistants/2024-03-29/webui.py index 7739235f2..070dd121c 100644 --- a/libs/chatchat-server/chatchat/data/temp/openai_files/assistants/2024-03-29/webui.py +++ b/libs/chatchat-server/chatchat/data/temp/openai_files/assistants/2024-03-29/webui.py @@ -1,21 +1,22 @@ 
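The model_providers.yaml hunk above switches the xinference credentials from glm-4 / qwen1.5-chat to the glm4-chat / qwen2-instruct model UIDs. A small sketch of inspecting such a file with PyYAML, assuming every top-level key is a provider block shaped like the xinference one shown (the project parses this through model_providers itself, so this is illustration only):

```python
import yaml  # PyYAML

# Each provider block nests model_credential -> [{model, model_type, model_credentials}].
with open("model_providers.yaml", "r", encoding="utf-8") as f:
    providers = yaml.safe_load(f) or {}

for provider_name, provider_cfg in providers.items():
    for cred in (provider_cfg or {}).get("model_credential", []):
        print(provider_name, cred["model_type"], cred["model"],
              cred["model_credentials"].get("server_url"))

# For the updated file this would list glm4-chat and qwen2-instruct (llm) and
# bge-large-zh-v1.5 (text-embedding) under xinference.
```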
-import streamlit as st - -# from chatchat.webui_pages.loom_view_client import update_store -# from chatchat.webui_pages.openai_plugins import openai_plugins_page -from chatchat.webui_pages.utils import * -from streamlit_option_menu import option_menu -from chatchat.webui_pages.dialogue.dialogue import dialogue_page, chat_box -from chatchat.webui_pages.knowledge_base.knowledge_base import knowledge_base_page import os import sys + +import streamlit as st +from streamlit_option_menu import option_menu + from chatchat.configs import VERSION from chatchat.server.utils import api_address +from chatchat.webui_pages.dialogue.dialogue import chat_box, dialogue_page +from chatchat.webui_pages.knowledge_base.knowledge_base import knowledge_base_page +# from chatchat.webui_pages.loom_view_client import update_store +# from chatchat.webui_pages.openai_plugins import openai_plugins_page +from chatchat.webui_pages.utils import * # def on_change(key): # if key: # update_store() -img_dir = os.path.dirname(os.path.abspath(__file__)) +img_dir = os.path.dirname(os.path.abspath(__file__)) api = ApiRequest(base_url=api_address()) @@ -27,12 +28,11 @@ os.path.join(img_dir, "img", "chatchat_icon_blue_square_v2.png"), initial_sidebar_state="expanded", menu_items={ - 'Get Help': 'https://github.com/chatchat-space/Langchain-Chatchat', - 'Report a bug': "https://github.com/chatchat-space/Langchain-Chatchat/issues", - 'About': f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}!""" + "Get Help": "https://github.com/chatchat-space/Langchain-Chatchat", + "Report a bug": "https://github.com/chatchat-space/Langchain-Chatchat/issues", + "About": f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}!""", }, - layout="wide" - + layout="wide", ) # use the following code to set the app to wide mode and the html markdown to increase the sidebar width @@ -73,8 +73,8 @@ # update_store() with st.sidebar: st.image( - os.path.join(img_dir, "img", 'logo-long-chatchat-trans-v2.png'), - use_column_width=True + os.path.join(img_dir, "img", "logo-long-chatchat-trans-v2.png"), + use_column_width=True, ) st.caption( f"""

当前版本:{VERSION}

""", diff --git a/libs/chatchat-server/chatchat/init_database.py b/libs/chatchat-server/chatchat/init_database.py index 88249a775..fef26af22 100644 --- a/libs/chatchat-server/chatchat/init_database.py +++ b/libs/chatchat-server/chatchat/init_database.py @@ -1,24 +1,33 @@ # Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作 -from datetime import datetime import multiprocessing as mp +from datetime import datetime from typing import Dict -from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db, - folder2db, prune_db_docs, prune_folder_files) from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS, logger - +from chatchat.server.knowledge_base.migrate import ( + create_tables, + folder2db, + import_from_db, + prune_db_docs, + prune_folder_files, + reset_tables, +) def run_init_model_provider( - model_platforms_shard: Dict, - started_event: mp.Event = None, - model_providers_cfg_path: str = None, - provider_host: str = None, - provider_port: int = None): + model_platforms_shard: Dict, + started_event: mp.Event = None, + model_providers_cfg_path: str = None, + provider_host: str = None, + provider_port: int = None, +): + from chatchat.configs import ( + MODEL_PROVIDERS_CFG_HOST, + MODEL_PROVIDERS_CFG_PATH_CONFIG, + MODEL_PROVIDERS_CFG_PORT, + ) from chatchat.init_server import init_server - from chatchat.configs import (MODEL_PROVIDERS_CFG_PATH_CONFIG, - MODEL_PROVIDERS_CFG_HOST, - MODEL_PROVIDERS_CFG_PORT) + if model_providers_cfg_path is None: model_providers_cfg_path = MODEL_PROVIDERS_CFG_PATH_CONFIG if provider_host is None: @@ -26,79 +35,89 @@ def run_init_model_provider( if provider_port is None: provider_port = MODEL_PROVIDERS_CFG_PORT - init_server(model_platforms_shard=model_platforms_shard, - started_event=started_event, - model_providers_cfg_path=model_providers_cfg_path, - provider_host=provider_host, - provider_port=provider_port) + init_server( + model_platforms_shard=model_platforms_shard, + started_event=started_event, + model_providers_cfg_path=model_providers_cfg_path, + provider_host=provider_host, + provider_port=provider_port, + ) def main(): import argparse - parser = argparse.ArgumentParser(description="please specify only one operate method once time.") + parser = argparse.ArgumentParser( + description="please specify only one operate method once time." + ) parser.add_argument( "-r", "--recreate-vs", action="store_true", - help=(''' + help=( + """ recreate vector store. use this option if you have copied document files to the content folder, but vector store has not been populated or DEFAUL_VS_TYPE/DEFAULT_EMBEDDING_MODEL changed. - ''' - ) + """ + ), ) parser.add_argument( "--create-tables", action="store_true", - help=("create empty tables if not existed") + help=("create empty tables if not existed"), ) parser.add_argument( "--clear-tables", action="store_true", - help=("create empty tables, or drop the database tables before recreate vector stores") + help=( + "create empty tables, or drop the database tables before recreate vector stores" + ), ) parser.add_argument( - "--import-db", - help="import tables from specified sqlite database" + "--import-db", help="import tables from specified sqlite database" ) parser.add_argument( "-u", "--update-in-db", action="store_true", - help=(''' + help=( + """ update vector store for files exist in database. use this option if you want to recreate vectors for files exist in db and skip files exist in local folder only. 
- ''' - ) + """ + ), ) parser.add_argument( "-i", "--increment", action="store_true", - help=(''' + help=( + """ update vector store for files exist in local folder and not exist in database. use this option if you want to create vectors incrementally. - ''' - ) + """ + ), ) parser.add_argument( "--prune-db", action="store_true", - help=(''' + help=( + """ delete docs in database that not existed in local folder. it is used to delete database docs after user deleted some doc files in file browser - ''' - ) + """ + ), ) parser.add_argument( "--prune-folder", action="store_true", - help=(''' + help=( + """ delete doc files in local folder that not existed in database. is is used to free local disk space by delete unused doc files. - ''' - ) + """ + ), ) parser.add_argument( "-n", @@ -106,14 +125,16 @@ def main(): type=str, nargs="+", default=[], - help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.") + help=( + "specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH." + ), ) parser.add_argument( "-e", "--embed-model", type=str, default=DEFAULT_EMBEDDING_MODEL, - help=("specify embeddings model.") + help=("specify embeddings model."), ) args = parser.parse_args() @@ -129,7 +150,10 @@ def main(): process = mp.Process( target=run_init_model_provider, name=f"Model providers Server", - kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=model_providers_started), + kwargs=dict( + model_platforms_shard=model_platforms_shard, + started_event=model_providers_started, + ), daemon=True, ) processes["model_providers"] = process @@ -139,12 +163,11 @@ def main(): p.start() p.name = f"{p.name} ({p.pid})" model_providers_started.wait() # 等待model_providers启动完成 - MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms']) + MODEL_PLATFORMS.extend(model_platforms_shard["provider_platforms"]) logger.info(f"Api MODEL_PLATFORMS: {MODEL_PLATFORMS}") - if args.create_tables: - create_tables() # confirm tables exist + create_tables() # confirm tables exist if args.clear_tables: reset_tables() @@ -153,13 +176,19 @@ def main(): if args.recreate_vs: create_tables() print("recreating all vector stores") - folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model) + folder2db( + kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model + ) elif args.import_db: import_from_db(args.import_db) elif args.update_in_db: - folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model) + folder2db( + kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model + ) elif args.increment: - folder2db(kb_names=args.kb_name, mode="increment", embed_model=args.embed_model) + folder2db( + kb_names=args.kb_name, mode="increment", embed_model=args.embed_model + ) elif args.prune_db: prune_db_docs(args.kb_name) elif args.prune_folder: @@ -171,7 +200,6 @@ def main(): logger.error(e, exc_info=True) logger.warning("Caught KeyboardInterrupt! 
Setting stop event...") finally: - for p in processes.values(): logger.warning("Sending SIGKILL to %s", p) # Queues and other inter-process communication primitives can break when diff --git a/libs/chatchat-server/chatchat/init_server.py b/libs/chatchat-server/chatchat/init_server.py index 361052ef2..eb5ca7d48 100644 --- a/libs/chatchat-server/chatchat/init_server.py +++ b/libs/chatchat-server/chatchat/init_server.py @@ -1,7 +1,12 @@ -from typing import List, Dict -from chatchat.configs import MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT, MODEL_PROVIDERS_CFG_PATH_CONFIG +import asyncio +import logging +import multiprocessing as mp +from typing import Dict, List + from model_providers import BootstrapWebBuilder -from model_providers.bootstrap_web.entities.model_provider_entities import ProviderResponse +from model_providers.bootstrap_web.entities.model_provider_entities import ( + ProviderResponse, +) from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper from model_providers.core.provider_manager import ProviderManager from model_providers.core.utils.utils import ( @@ -9,42 +14,45 @@ get_log_file, get_timestamp_ms, ) -import multiprocessing as mp -import asyncio -import logging + +from chatchat.configs import ( + MODEL_PROVIDERS_CFG_HOST, + MODEL_PROVIDERS_CFG_PATH_CONFIG, + MODEL_PROVIDERS_CFG_PORT, +) logger = logging.getLogger(__name__) -def init_server(model_platforms_shard: Dict, - started_event: mp.Event = None, - model_providers_cfg_path: str = MODEL_PROVIDERS_CFG_PATH_CONFIG, - provider_host: str = MODEL_PROVIDERS_CFG_HOST, - provider_port: int = MODEL_PROVIDERS_CFG_PORT, - log_path: str = "logs" - ) -> None: +def init_server( + model_platforms_shard: Dict, + started_event: mp.Event = None, + model_providers_cfg_path: str = MODEL_PROVIDERS_CFG_PATH_CONFIG, + provider_host: str = MODEL_PROVIDERS_CFG_HOST, + provider_port: int = MODEL_PROVIDERS_CFG_PORT, + log_path: str = "logs", +) -> None: logging_conf = get_config_dict( "INFO", get_log_file(log_path=log_path, sub_dir=f"provider_{get_timestamp_ms()}"), - - 1024*1024*1024*3, - 1024*1024*1024*3, + 1024 * 1024 * 1024 * 3, + 1024 * 1024 * 1024 * 3, ) try: boot = ( BootstrapWebBuilder() - .model_providers_cfg_path( - model_providers_cfg_path=model_providers_cfg_path - ) + .model_providers_cfg_path(model_providers_cfg_path=model_providers_cfg_path) .host(host=provider_host) .port(port=provider_port) .build() ) boot.set_app_event(started_event=started_event) - provider_platforms = init_provider_platforms(boot.provider_manager.provider_manager) - model_platforms_shard['provider_platforms'] = provider_platforms + provider_platforms = init_provider_platforms( + boot.provider_manager.provider_manager + ) + model_platforms_shard["provider_platforms"] = provider_platforms boot.logging_conf(logging_conf=logging_conf) boot.run() @@ -57,9 +65,10 @@ async def pool_join_thread(): raise -def init_provider_platforms(provider_manager: ProviderManager)-> List[Dict]: +def init_provider_platforms(provider_manager: ProviderManager) -> List[Dict]: provider_list: List[ProviderResponse] = ProvidersWrapper( - provider_manager=provider_manager).get_provider_list() + provider_manager=provider_manager + ).get_provider_list() logger.info(f"Provider list: {provider_list}") # 转换MODEL_PLATFORMS provider_platforms = [] @@ -69,7 +78,7 @@ def init_provider_platforms(provider_manager: ProviderManager)-> List[Dict]: "platform_type": provider.provider, "api_base_url": f"http://127.0.0.1:20000/{provider.provider}/v1", "api_key": "EMPTY", - 
"api_concurrencies": 5 + "api_concurrencies": 5, } provider_dict["llm_models"] = [] @@ -78,11 +87,12 @@ def init_provider_platforms(provider_manager: ProviderManager)-> List[Dict]: provider_dict["reranking_models"] = [] provider_dict["speech2text_models"] = [] provider_dict["tts_models"] = [] - supported_model_str_types = [model_type.to_origin_model_type() for model_type in - provider.supported_model_types] + supported_model_str_types = [ + model_type.to_origin_model_type() + for model_type in provider.supported_model_types + ] for model_type in supported_model_str_types: - providers_model_type = ProvidersWrapper( provider_manager=provider_manager ).get_models_by_model_type(model_type=model_type) @@ -113,4 +123,4 @@ def init_provider_platforms(provider_manager: ProviderManager)-> List[Dict]: logger.info(f"Provider platforms: {provider_platforms}") - return provider_platforms \ No newline at end of file + return provider_platforms diff --git a/libs/chatchat-server/chatchat/server/agent/agent_factory/agents_registry.py b/libs/chatchat-server/chatchat/server/agent/agent_factory/agents_registry.py index 7de96aaab..dc8e9fc0c 100644 --- a/libs/chatchat-server/chatchat/server/agent/agent_factory/agents_registry.py +++ b/libs/chatchat-server/chatchat/server/agent/agent_factory/agents_registry.py @@ -8,36 +8,45 @@ from langchain_core.prompts import ChatPromptTemplate from langchain_core.tools import BaseTool -from chatchat.server.agent.agent_factory import ( create_structured_qwen_chat_agent) -from chatchat.server.agent.agent_factory.glm3_agent import create_structured_glm3_chat_agent +from chatchat.server.agent.agent_factory import create_structured_qwen_chat_agent +from chatchat.server.agent.agent_factory.glm3_agent import ( + create_structured_glm3_chat_agent, +) def agents_registry( - llm: BaseLanguageModel, - tools: Sequence[BaseTool] = [], - callbacks: List[BaseCallbackHandler] = [], - prompt: str = None, - verbose: bool = False): + llm: BaseLanguageModel, + tools: Sequence[BaseTool] = [], + callbacks: List[BaseCallbackHandler] = [], + prompt: str = None, + verbose: bool = False, +): # llm.callbacks = callbacks - llm.streaming = False # qwen agent not support streaming + llm.streaming = False # qwen agent not support streaming # Write any optimized method here. 
if "glm3" in llm.model_name.lower(): # An optimized method of langchain Agent that uses the glm3 series model agent = create_structured_glm3_chat_agent(llm=llm, tools=tools) - agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks) + agent_executor = AgentExecutor( + agent=agent, tools=tools, verbose=verbose, callbacks=callbacks + ) return agent_executor elif "qwen" in llm.model_name.lower(): - return create_structured_qwen_chat_agent(llm=llm, tools=tools, callbacks=callbacks) + return create_structured_qwen_chat_agent( + llm=llm, tools=tools, callbacks=callbacks + ) else: if prompt is not None: prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)]) else: - prompt = hub.pull("hwchase17/structured-chat-agent") # default prompt + prompt = hub.pull("hwchase17/structured-chat-agent") # default prompt agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt) - agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks) + agent_executor = AgentExecutor( + agent=agent, tools=tools, verbose=verbose, callbacks=callbacks + ) return agent_executor diff --git a/libs/chatchat-server/chatchat/server/agent/agent_factory/glm3_agent.py b/libs/chatchat-server/chatchat/server/agent/agent_factory/glm3_agent.py index 644c2f558..612ed6cd5 100644 --- a/libs/chatchat-server/chatchat/server/agent/agent_factory/glm3_agent.py +++ b/libs/chatchat-server/chatchat/server/agent/agent_factory/glm3_agent.py @@ -4,20 +4,20 @@ import json import logging -from typing import Sequence, Optional, Union +from typing import Optional, Sequence, Union -import langchain_core.prompts import langchain_core.messages -from langchain_core.runnables import Runnable, RunnablePassthrough +import langchain_core.prompts from langchain.agents.agent import AgentOutputParser from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser -from langchain.prompts.chat import ChatPromptTemplate from langchain.output_parsers import OutputFixingParser +from langchain.prompts.chat import ChatPromptTemplate from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool -from chatchat.server.pydantic_v1 import Field, typing, model_schema +from langchain_core.runnables import Runnable, RunnablePassthrough +from chatchat.server.pydantic_v1 import Field, model_schema, typing logger = logging.getLogger(__name__) @@ -29,6 +29,7 @@ class StructuredGLM3ChatOutputParser(AgentOutputParser): """ Output parser with retries for the structured chat agent. 
""" + base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser) output_fixing_parser: Optional[OutputFixingParser] = None @@ -36,7 +37,12 @@ def parse(self, text: str) -> Union[AgentAction, AgentFinish]: print(text) special_tokens = ["Action:", "<|observation|>"] - first_index = min([text.find(token) if token in text else len(text) for token in special_tokens]) + first_index = min( + [ + text.find(token) if token in text else len(text) + for token in special_tokens + ] + ) text = text[:first_index] if "tool_call" in text: @@ -46,18 +52,16 @@ def parse(self, text: str) -> Union[AgentAction, AgentFinish]: params_str_end = text.rfind(")") params_str = text[params_str_start:params_str_end] - params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param] - params = {pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs} - - action_json = { - "action": action, - "action_input": params + params_pairs = [ + param.split("=") for param in params_str.split(",") if "=" in param + ] + params = { + pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs } + + action_json = {"action": action, "action_input": params} else: - action_json = { - "action": "Final Answer", - "action_input": text - } + action_json = {"action": "Final Answer", "action_input": text} action_str = f""" Action: ``` @@ -80,60 +84,69 @@ def _type(self) -> str: def create_structured_glm3_chat_agent( - llm: BaseLanguageModel, tools: Sequence[BaseTool] + llm: BaseLanguageModel, tools: Sequence[BaseTool] ) -> Runnable: tools_json = [] for tool in tools: tool_schema = model_schema(tool.args_schema) if tool.args_schema else {} - description = tool.description.split(" - ")[ - 1].strip() if tool.description and " - " in tool.description else tool.description - parameters = {k: {sub_k: sub_v for sub_k, sub_v in v.items() if sub_k != 'title'} for k, v in - tool_schema.get("properties", {}).items()} + description = ( + tool.description.split(" - ")[1].strip() + if tool.description and " - " in tool.description + else tool.description + ) + parameters = { + k: {sub_k: sub_v for sub_k, sub_v in v.items() if sub_k != "title"} + for k, v in tool_schema.get("properties", {}).items() + } simplified_config_langchain = { "name": tool.name, "description": description, - "parameters": parameters + "parameters": parameters, } tools_json.append(simplified_config_langchain) - tools = "\n".join([json.dumps(tool, indent=4, ensure_ascii=False) for tool in tools_json]) + tools = "\n".join( + [json.dumps(tool, indent=4, ensure_ascii=False) for tool in tools_json] + ) prompt = ChatPromptTemplate( input_variables=["input", "agent_scratchpad"], - input_types={'chat_history': typing.List[typing.Union[ - langchain_core.messages.ai.AIMessage, - langchain_core.messages.human.HumanMessage, - langchain_core.messages.chat.ChatMessage, - langchain_core.messages.system.SystemMessage, - langchain_core.messages.function.FunctionMessage, - langchain_core.messages.tool.ToolMessage]] - }, + input_types={ + "chat_history": typing.List[ + typing.Union[ + langchain_core.messages.ai.AIMessage, + langchain_core.messages.human.HumanMessage, + langchain_core.messages.chat.ChatMessage, + langchain_core.messages.system.SystemMessage, + langchain_core.messages.function.FunctionMessage, + langchain_core.messages.tool.ToolMessage, + ] + ] + }, messages=[ langchain_core.prompts.SystemMessagePromptTemplate( prompt=langchain_core.prompts.PromptTemplate( - input_variables=['tools'], - template=SYSTEM_PROMPT) + 
input_variables=["tools"], template=SYSTEM_PROMPT + ) ), langchain_core.prompts.MessagesPlaceholder( - variable_name='chat_history', - optional=True + variable_name="chat_history", optional=True ), langchain_core.prompts.HumanMessagePromptTemplate( prompt=langchain_core.prompts.PromptTemplate( - input_variables=['agent_scratchpad', 'input'], - template=HUMAN_MESSAGE + input_variables=["agent_scratchpad", "input"], + template=HUMAN_MESSAGE, ) - ) - ] - + ), + ], ).partial(tools=tools) llm_with_stop = llm.bind(stop=["<|observation|>"]) agent = ( - RunnablePassthrough.assign( - agent_scratchpad=lambda x: x["intermediate_steps"], - ) - | prompt - | llm_with_stop - | StructuredGLM3ChatOutputParser() + RunnablePassthrough.assign( + agent_scratchpad=lambda x: x["intermediate_steps"], + ) + | prompt + | llm_with_stop + | StructuredGLM3ChatOutputParser() ) return agent diff --git a/libs/chatchat-server/chatchat/server/agent/agent_factory/qwen_agent.py b/libs/chatchat-server/chatchat/server/agent/agent_factory/qwen_agent.py index 23cdb5deb..7aab34ec0 100644 --- a/libs/chatchat-server/chatchat/server/agent/agent_factory/qwen_agent.py +++ b/libs/chatchat-server/chatchat/server/agent/agent_factory/qwen_agent.py @@ -1,23 +1,29 @@ from __future__ import annotations -from functools import partial import json import logging -from operator import itemgetter import re -from typing import List, Sequence, Union, Tuple, Any +from functools import partial +from operator import itemgetter +from typing import Any, List, Sequence, Tuple, Union -from langchain_core.callbacks import Callbacks -from langchain_core.runnables import Runnable, RunnablePassthrough -from langchain.agents.agent import RunnableAgent, AgentExecutor +from langchain.agents.agent import AgentExecutor, RunnableAgent from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser from langchain.prompts.chat import BaseChatPromptTemplate -from langchain.schema import (AgentAction, AgentFinish, OutputParserException, - HumanMessage, SystemMessage, AIMessage) +from langchain.schema import ( + AgentAction, + AgentFinish, + AIMessage, + HumanMessage, + OutputParserException, + SystemMessage, +) from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool -from chatchat.server.utils import get_prompt_template +from langchain_core.callbacks import Callbacks +from langchain_core.runnables import Runnable, RunnablePassthrough +from chatchat.server.utils import get_prompt_template logger = logging.getLogger(__name__) @@ -34,6 +40,7 @@ def _plan_without_stream( inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}} return self.runnable.invoke(inputs, config={"callbacks": callbacks}) + async def _aplan_without_stream( self: RunnableAgent, intermediate_steps: List[Tuple[AgentAction, str]], @@ -60,7 +67,9 @@ def format_messages(self, **kwargs) -> str: thoughts += f"\nObservation: {observation}\nThought: " # Set the agent_scratchpad variable to that value if thoughts: - kwargs["agent_scratchpad"] = f"These were previous tasks you completed:\n{thoughts}\n\n" + kwargs[ + "agent_scratchpad" + ] = f"These were previous tasks you completed:\n{thoughts}\n\n" else: kwargs["agent_scratchpad"] = "" # Create a tools variable from the list of tools provided @@ -68,9 +77,10 @@ def format_messages(self, **kwargs) -> str: tools = [] for t in self.tools: desc = re.sub(r"\n+", " ", t.description) - text = (f"{t.name}: Call this tool to interact with the {t.name} API. 
What is the {t.name} API useful for?" - f" {desc}" - f" Parameters: {t.args}" + text = ( + f"{t.name}: Call this tool to interact with the {t.name} API. What is the {t.name} API useful for?" + f" {desc}" + f" Parameters: {t.args}" ) tools.append(text) kwargs["tools"] = "\n\n".join(tools) @@ -85,19 +95,21 @@ class QwenChatAgentOutputParserCustom(StructuredChatOutputParser): """Output parser with retries for the structured chat agent with custom qwen prompt.""" def parse(self, text: str) -> Union[AgentAction, AgentFinish]: - if s := re.findall(r"\nAction:\s*(.+)\nAction\sInput:\s*(.+)", text, flags=re.DOTALL): + if s := re.findall( + r"\nAction:\s*(.+)\nAction\sInput:\s*(.+)", text, flags=re.DOTALL + ): s = s[-1] - json_string: str=s[1] - json_input=None + json_string: str = s[1] + json_input = None try: - json_input=json.loads(json_string) + json_input = json.loads(json_string) except: # ollama部署的qwen,返回的json键值可能为单引号,可能缺少最后的引号和括号 - if not json_string.endswith("\"}"): - print("尝试修复格式不正确的json输出:"+json_string) - json_string=(json_string+"\"}").replace("'","\""); - print("修复后的json:"+json_string) - json_input=json.loads(json_string) + if not json_string.endswith('"}'): + print("尝试修复格式不正确的json输出:" + json_string) + json_string = (json_string + '"}').replace("'", '"') + print("修复后的json:" + json_string) + json_input = json.loads(json_string) return AgentAction(tool=s[0].strip(), tool_input=json_input, log=text) elif s := re.findall(r"\nFinal\sAnswer:\s*(.+)", text, flags=re.DOTALL): s = s[-1] @@ -121,7 +133,9 @@ def parse(self, text: str) -> Union[AgentAction, AgentFinish]: if tool == "Final Answer": return AgentFinish({"output": action.get("action_input", "")}, log=text) else: - return AgentAction(tool=tool, tool_input=action.get("action_input", {}), log=text) + return AgentAction( + tool=tool, tool_input=action.get("action_input", {}), log=text + ) else: raise OutputParserException(f"Could not parse LLM output: {text}") @@ -131,10 +145,10 @@ def _type(self) -> str: def create_structured_qwen_chat_agent( - llm: BaseLanguageModel, - tools: Sequence[BaseTool], - callbacks: Sequence[Callbacks], - use_custom_prompt: bool = True, + llm: BaseLanguageModel, + tools: Sequence[BaseTool], + callbacks: Sequence[Callbacks], + use_custom_prompt: bool = True, ) -> AgentExecutor: if use_custom_prompt: prompt = "qwen" @@ -145,16 +159,16 @@ def create_structured_qwen_chat_agent( tools = [t.copy(update={"callbacks": callbacks}) for t in tools] template = get_prompt_template("action_model", prompt) - prompt = QwenChatAgentPromptTemplate(input_variables=["input", "intermediate_steps"], - template=template, - tools=tools) + prompt = QwenChatAgentPromptTemplate( + input_variables=["input", "intermediate_steps"], template=template, tools=tools + ) agent = ( - RunnablePassthrough.assign( - agent_scratchpad=itemgetter("intermediate_steps") - ) + RunnablePassthrough.assign(agent_scratchpad=itemgetter("intermediate_steps")) | prompt - | llm.bind(stop=["<|endoftext|>", "<|im_start|>", "<|im_end|>", "\nObservation:"]) + | llm.bind( + stop=["<|endoftext|>", "<|im_start|>", "<|im_end|>", "\nObservation:"] + ) | output_parser ) executor = AgentExecutor(agent=agent, tools=tools, callbacks=callbacks) diff --git a/libs/chatchat-server/chatchat/server/agent/container.py b/libs/chatchat-server/chatchat/server/agent/container.py index e510d82ce..1d0f7c9ca 100644 --- a/libs/chatchat-server/chatchat/server/agent/container.py +++ b/libs/chatchat-server/chatchat/server/agent/container.py @@ -1,4 +1,5 @@ import logging + from 
chatchat.server.utils import get_tool_config logger = logging.getLogger(__name__) @@ -17,36 +18,52 @@ def __init__(self): vqa_config = get_tool_config("vqa_processor") if vqa_config["use"]: try: - from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer import torch + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + LlamaTokenizer, + ) + self.vision_tokenizer = LlamaTokenizer.from_pretrained( - vqa_config["tokenizer_path"], - trust_remote_code=True) - self.vision_model = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=vqa_config["model_path"], - torch_dtype=torch.bfloat16, - low_cpu_mem_usage=True, - trust_remote_code=True - ).to(vqa_config["device"]).eval() + vqa_config["tokenizer_path"], trust_remote_code=True + ) + self.vision_model = ( + AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=vqa_config["model_path"], + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + ) + .to(vqa_config["device"]) + .eval() + ) except Exception as e: logger.error(e, exc_info=True) aqa_config = get_tool_config("vqa_processor") if aqa_config["use"]: try: - from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer import torch + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + LlamaTokenizer, + ) + self.audio_tokenizer = AutoTokenizer.from_pretrained( - aqa_config["tokenizer_path"], - trust_remote_code=True + aqa_config["tokenizer_path"], trust_remote_code=True + ) + self.audio_model = ( + AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=aqa_config["model_path"], + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + ) + .to(aqa_config["device"]) + .eval() ) - self.audio_model = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=aqa_config["model_path"], - torch_dtype=torch.bfloat16, - low_cpu_mem_usage=True, - trust_remote_code=True).to( - aqa_config["device"] - ).eval() except Exception as e: logger.error(e, exc_info=True) diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/__init__.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/__init__.py index 2242434f1..ff8933b85 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/__init__.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/__init__.py @@ -1,13 +1,12 @@ -from .search_local_knowledgebase import search_local_knowledgebase +from .aqa_processor import aqa_processor +from .arxiv import arxiv from .calculate import calculate -from .weather_check import weather_check -from .shell import shell from .search_internet import search_internet -from .wolfram import wolfram +from .search_local_knowledgebase import search_local_knowledgebase from .search_youtube import search_youtube -from .arxiv import arxiv +from .shell import shell from .text2image import text2images - +from .text2sql import text2sql from .vqa_processor import vqa_processor -from .aqa_processor import aqa_processor -from .text2sql import text2sql \ No newline at end of file +from .weather_check import weather_check +from .wolfram import wolfram diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/aqa_processor.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/aqa_processor.py index d6a1d70a4..8338cc7ae 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/aqa_processor.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/aqa_processor.py @@ -1,14 +1,17 @@ 
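The tool modules that follow (aqa_processor, arxiv, calculate, search_internet, and so on) all use the same pattern: a function decorated with @regist_tool, Field-annotated parameters, a docstring that serves as the tool description, and a BaseToolOutput return value. A minimal sketch of a custom tool in that style; the decorator and output wrapper come from tools_registry as in this diff, while the tool itself is purely illustrative:

```python
from chatchat.server.agent.tools_factory.tools_registry import BaseToolOutput, regist_tool
from chatchat.server.pydantic_v1 import Field


@regist_tool(title="字符串反转")  # illustrative tool, not part of the project
def reverse_text(text: str = Field(description="text to reverse")):
    """Use this tool to reverse a piece of text."""
    return BaseToolOutput(text[::-1])
```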
import base64 + from chatchat.server.pydantic_v1 import Field from chatchat.server.utils import get_tool_config -from .tools_registry import regist_tool, BaseToolOutput + +from .tools_registry import BaseToolOutput, regist_tool def save_base64_audio(base64_audio, file_path): audio_data = base64.b64decode(base64_audio) - with open(file_path, 'wb') as audio_file: + with open(file_path, "wb") as audio_file: audio_file.write(audio_data) + def aqa_run(model, tokenizer, query): query = tokenizer.from_list_format([query]) response, history = model.chat(tokenizer, query=query, history=None) @@ -17,10 +20,13 @@ def aqa_run(model, tokenizer, query): @regist_tool(title="音频问答") -def aqa_processor(query: str = Field(description="The question of the audio in English")): - '''use this tool to get answer for audio question''' +def aqa_processor( + query: str = Field(description="The question of the audio in English"), +): + """use this tool to get answer for audio question""" from chatchat.server.agent.container import container + if container.metadata["audios"]: file_path = "temp_audio.mp3" save_base64_audio(container.metadata["audios"][0], file_path) @@ -28,8 +34,12 @@ def aqa_processor(query: str = Field(description="The question of the audio in E "audio": file_path, "text": query, } - ret = aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model) + ret = aqa_run( + tokenizer=container.audio_tokenizer, + query=query_input, + model=container.audio_model, + ) else: ret = "No Audio, Please Try Again" - + return BaseToolOutput(ret) diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/arxiv.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/arxiv.py index b5da1cf7a..cc83ade0e 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/arxiv.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/arxiv.py @@ -1,11 +1,12 @@ # LangChain 的 ArxivQueryRun 工具 from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool, BaseToolOutput + +from .tools_registry import BaseToolOutput, regist_tool @regist_tool(title="ARXIV论文") def arxiv(query: str = Field(description="The search query title")): - '''A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.''' + """A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.""" from langchain.tools.arxiv.tool import ArxivQueryRun tool = ArxivQueryRun() diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/calculate.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/calculate.py index 69a0ae4a3..fdf773330 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/calculate.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/calculate.py @@ -1,13 +1,14 @@ from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool, BaseToolOutput + +from .tools_registry import BaseToolOutput, regist_tool @regist_tool(title="数学计算器") def calculate(text: str = Field(description="a math expression")) -> float: - ''' + """ Useful to answer questions about simple calculations. translate user question to a math expression that can be evaluated by numexpr. 
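The calculate tool relies on numexpr to evaluate the translated expression; a standalone illustration of that evaluation step (the tool's actual body lies outside this hunk):

```python
import numexpr

# numexpr.evaluate returns a zero-dimensional NumPy array for scalar expressions,
# so float() yields a plain Python number.
result = float(numexpr.evaluate("2 * (3 + 4) ** 2"))
print(result)  # 98.0
```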
- ''' + """ import numexpr try: diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py index 344e89012..132dc3166 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py @@ -1,20 +1,23 @@ -from typing import List, Dict +from typing import Dict, List +from langchain.docstore.document import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.utilities.bing_search import BingSearchAPIWrapper from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain.docstore.document import Document from markdownify import markdownify from strsimpy.normalized_levenshtein import NormalizedLevenshtein -from chatchat.server.utils import get_tool_config from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool, BaseToolOutput +from chatchat.server.utils import get_tool_config + +from .tools_registry import BaseToolOutput, regist_tool def bing_search(text, config): - search = BingSearchAPIWrapper(bing_subscription_key=config["bing_key"], - bing_search_url=config["bing_search_url"]) + search = BingSearchAPIWrapper( + bing_subscription_key=config["bing_key"], + bing_search_url=config["bing_search_url"], + ) return search.results(text, config["result_len"]) @@ -24,65 +27,76 @@ def duckduckgo_search(text, config): def metaphor_search( - text: str, - config: dict, + text: str, + config: dict, ) -> List[Dict]: from metaphor_python import Metaphor + client = Metaphor(config["metaphor_api_key"]) search = client.search(text, num_results=config["result_len"], use_autoprompt=True) contents = search.get_contents().contents for x in contents: x.extract = markdownify(x.extract) if config["split_result"]: - docs = [Document(page_content=x.extract, - metadata={"link": x.url, "title": x.title}) - for x in contents] - text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "], - chunk_size=config["chunk_size"], - chunk_overlap=config["chunk_overlap"]) + docs = [ + Document(page_content=x.extract, metadata={"link": x.url, "title": x.title}) + for x in contents + ] + text_splitter = RecursiveCharacterTextSplitter( + ["\n\n", "\n", ".", " "], + chunk_size=config["chunk_size"], + chunk_overlap=config["chunk_overlap"], + ) splitted_docs = text_splitter.split_documents(docs) if len(splitted_docs) > config["result_len"]: normal = NormalizedLevenshtein() for x in splitted_docs: x.metadata["score"] = normal.similarity(text, x.page_content) splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True) - splitted_docs = splitted_docs[:config["result_len"]] - - docs = [{"snippet": x.page_content, - "link": x.metadata["link"], - "title": x.metadata["title"]} - for x in splitted_docs] + splitted_docs = splitted_docs[: config["result_len"]] + + docs = [ + { + "snippet": x.page_content, + "link": x.metadata["link"], + "title": x.metadata["title"], + } + for x in splitted_docs + ] else: - docs = [{"snippet": x.extract, - "link": x.url, - "title": x.title} - for x in contents] + docs = [ + {"snippet": x.extract, "link": x.url, "title": x.title} for x in contents + ] return docs -SEARCH_ENGINES = {"bing": bing_search, - "duckduckgo": duckduckgo_search, - "metaphor": metaphor_search, - } +SEARCH_ENGINES = { + "bing": bing_search, + 
"duckduckgo": duckduckgo_search, + "metaphor": metaphor_search, +} def search_result2docs(search_results): docs = [] for result in search_results: - doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "", - metadata={"source": result["link"] if "link" in result.keys() else "", - "filename": result["title"] if "title" in result.keys() else ""}) + doc = Document( + page_content=result["snippet"] if "snippet" in result.keys() else "", + metadata={ + "source": result["link"] if "link" in result.keys() else "", + "filename": result["title"] if "title" in result.keys() else "", + }, + ) docs.append(doc) return docs -def search_engine(query: str, - config: dict): +def search_engine(query: str, config: dict): search_engine_use = SEARCH_ENGINES[config["search_engine_name"]] - results = search_engine_use(text=query, - config=config["search_engine_config"][ - config["search_engine_name"]]) + results = search_engine_use( + text=query, config=config["search_engine_config"][config["search_engine_name"]] + ) docs = search_result2docs(results) context = "" docs = [ @@ -97,6 +111,6 @@ def search_engine(query: str, @regist_tool(title="互联网搜索") def search_internet(query: str = Field(description="query for Internet search")): - '''Use this tool to use bing search engine to search the internet and get information.''' + """Use this tool to use bing search engine to search the internet and get information.""" tool_config = get_tool_config("search_internet") return BaseToolOutput(search_engine(query=query, config=tool_config)) diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py index 7cd393124..6a3786098 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py @@ -1,15 +1,20 @@ from urllib.parse import urlencode -from chatchat.server.utils import get_tool_config -from chatchat.server.pydantic_v1 import Field -from chatchat.server.agent.tools_factory.tools_registry import regist_tool, BaseToolOutput -from chatchat.server.knowledge_base.kb_api import list_kbs -from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId -from chatchat.configs import KB_INFO +from chatchat.configs import KB_INFO +from chatchat.server.agent.tools_factory.tools_registry import ( + BaseToolOutput, + regist_tool, +) +from chatchat.server.knowledge_base.kb_api import list_kbs +from chatchat.server.knowledge_base.kb_doc_api import DocumentWithVSId, search_docs +from chatchat.server.pydantic_v1 import Field +from chatchat.server.utils import get_tool_config -template = ("Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on " - "this knowledge use this tool. The 'database' should be one of the above [{key}].") -KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()]) +template = ( + "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on " + "this knowledge use this tool. The 'database' should be one of the above [{key}]." 
+) +KB_info_str = "\n".join([f"{key}: {value}" for key, value in KB_INFO.items()]) template_knowledge = template.format(KB_info=KB_info_str, key="samples") @@ -37,13 +42,17 @@ def search_knowledgebase(query: str, database: str, config: dict): query=query, knowledge_base_name=database, top_k=config["top_k"], - score_threshold=config["score_threshold"]) + score_threshold=config["score_threshold"], + ) return {"knowledge_base": database, "docs": docs} @regist_tool(description=template_knowledge, title="本地知识库") def search_local_knowledgebase( - database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]), + database: str = Field( + description="Database for Knowledge Search", + choices=[kb.kb_name for kb in list_kbs().data], + ), query: str = Field(description="Query for Knowledge Search"), ): """""" diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_youtube.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_youtube.py index ab3cb03e6..353be531c 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_youtube.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_youtube.py @@ -1,10 +1,12 @@ from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool, BaseToolOutput + +from .tools_registry import BaseToolOutput, regist_tool @regist_tool(title="油管视频") def search_youtube(query: str = Field(description="Query for Videos search")): - '''use this tools_factory to search youtube videos''' + """use this tools_factory to search youtube videos""" from langchain_community.tools import YouTubeSearchTool + tool = YouTubeSearchTool() return BaseToolOutput(tool.run(tool_input=query)) diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/shell.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/shell.py index 32d63ac5c..34e6f47e5 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/shell.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/shell.py @@ -2,11 +2,12 @@ from langchain_community.tools import ShellTool from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool, BaseToolOutput + +from .tools_registry import BaseToolOutput, regist_tool @regist_tool(title="系统命令") def shell(query: str = Field(description="The command to execute")): - '''Use Shell to execute system shell commands''' + """Use Shell to execute system shell commands""" tool = ShellTool() return BaseToolOutput(tool.run(tool_input=query)) diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/text2image.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/text2image.py index 6971b0241..0b645eb15 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/text2image.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/text2image.py @@ -1,17 +1,17 @@ import base64 import json import os -from PIL import Image -from typing import List import uuid +from typing import List -from chatchat.server.pydantic_v1 import Field -from chatchat.server.utils import get_tool_config -from .tools_registry import regist_tool, BaseToolOutput import openai +from PIL import Image from chatchat.configs import MEDIA_PATH -from chatchat.server.utils import MsgType +from chatchat.server.pydantic_v1 import Field +from chatchat.server.utils import MsgType, get_tool_config + +from .tools_registry import BaseToolOutput, regist_tool def get_image_model_config() -> dict: @@ -26,6 +26,7 
@@ def get_image_model_config() -> dict: # return config pass + @regist_tool(title="文生图", return_direct=True) def text2images( prompt: str, @@ -33,7 +34,7 @@ def text2images( width: int = Field(512, description="生成图片的宽度"), height: int = Field(512, description="生成图片的高度"), ) -> List[str]: - '''根据用户的描述生成图片''' + """根据用户的描述生成图片""" model_config = get_image_model_config() assert model_config is not None, "请正确配置文生图模型" @@ -43,12 +44,13 @@ def text2images( api_key=model_config["api_key"], timeout=600, ) - resp = client.images.generate(prompt=prompt, - n=n, - size=f"{width}*{height}", - response_format="b64_json", - model=model_config["model_name"], - ) + resp = client.images.generate( + prompt=prompt, + n=n, + size=f"{width}*{height}", + response_format="b64_json", + model=model_config["model_name"], + ) images = [] for x in resp.data: uid = uuid.uuid4().hex @@ -56,14 +58,18 @@ def text2images( with open(os.path.join(MEDIA_PATH, filename), "wb") as fp: fp.write(base64.b64decode(x.b64_json)) images.append(filename) - return BaseToolOutput({"message_type": MsgType.IMAGE, "images": images}, format="json") + return BaseToolOutput( + {"message_type": MsgType.IMAGE, "images": images}, format="json" + ) if __name__ == "__main__": + import sys from io import BytesIO - from matplotlib import pyplot as plt from pathlib import Path - import sys + + from matplotlib import pyplot as plt + sys.path.append(str(Path(__file__).parent.parent.parent.parent)) prompt = "draw a house with trees and river" diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py index d9f4640e1..0902561b8 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py @@ -1,14 +1,16 @@ +from langchain.chains import LLMChain from langchain_community.utilities import SQLDatabase -from langchain_experimental.sql import SQLDatabaseChain,SQLDatabaseSequentialChain -from chatchat.server.utils import get_tool_config -from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool, BaseToolOutput -from sqlalchemy.exc import OperationalError -from sqlalchemy import event from langchain_core.prompts.prompt import PromptTemplate -from langchain.chains import LLMChain +from langchain_experimental.sql import SQLDatabaseChain, SQLDatabaseSequentialChain +from sqlalchemy import event +from sqlalchemy.exc import OperationalError -READ_ONLY_PROMPT_TEMPLATE="""You are a MySQL expert. The database is currently in read-only mode. +from chatchat.server.pydantic_v1 import Field +from chatchat.server.utils import get_tool_config + +from .tools_registry import BaseToolOutput, regist_tool + +READ_ONLY_PROMPT_TEMPLATE = """You are a MySQL expert. The database is currently in read-only mode. Given an input question, determine if the related SQL can be executed in read-only mode. If the SQL can be executed normally, return Answer:'SQL can be executed normally'. If the SQL cannot be executed normally, return Answer: 'SQL cannot be executed normally'. 
@@ -19,16 +21,30 @@ Question: {query} """ + # 定义一个拦截器函数来检查SQL语句,以支持read-only,可修改下面的write_operations,以匹配你使用的数据库写操作关键字 def intercept_sql(conn, cursor, statement, parameters, context, executemany): # List of SQL keywords that indicate a write operation - write_operations = ("insert", "update", "delete", "create", "drop", "alter", "truncate", "rename") + write_operations = ( + "insert", + "update", + "delete", + "create", + "drop", + "alter", + "truncate", + "rename", + ) # Check if the statement starts with any of the write operation keywords if any(statement.strip().lower().startswith(op) for op in write_operations): - raise OperationalError("Database is read-only. Write operations are not allowed.", params=None, orig=None) + raise OperationalError( + "Database is read-only. Write operations are not allowed.", + params=None, + orig=None, + ) + -def query_database(query: str, - config: dict): +def query_database(query: str, config: dict): top_k = config["top_k"] return_intermediate_steps = config["return_intermediate_steps"] sqlalchemy_connect_str = config["sqlalchemy_connect_str"] @@ -36,25 +52,28 @@ def query_database(query: str, db = SQLDatabase.from_uri(sqlalchemy_connect_str) from chatchat.server.api_server.chat_routes import global_model_name - from chatchat.server.utils import get_ChatOpenAI + from chatchat.server.utils import get_ChatOpenAI + llm = get_ChatOpenAI( - model_name=global_model_name, - temperature=0, - streaming=True, - local_wrap=True, - verbose=True - ) - table_names=config["table_names"] - table_comments=config["table_comments"] + model_name=global_model_name, + temperature=0, + streaming=True, + local_wrap=True, + verbose=True, + ) + table_names = config["table_names"] + table_comments = config["table_comments"] result = None - #如果发现大模型判断用什么表出现问题,尝试给langchain提供额外的表说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判 - #由于langchain固定了输入参数,所以只能通过query传递额外的表说明 + # 如果发现大模型判断用什么表出现问题,尝试给langchain提供额外的表说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判 + # 由于langchain固定了输入参数,所以只能通过query传递额外的表说明 if table_comments: - TABLE_COMMNET_PROMPT="\n\nI will provide some special notes for a few tables:\n\n" - table_comments_str="\n".join([f"{k}:{v}" for k,v in table_comments.items()]) - query=query+TABLE_COMMNET_PROMPT+table_comments_str+"\n\n" - + TABLE_COMMNET_PROMPT = ( + "\n\nI will provide some special notes for a few tables:\n\n" + ) + table_comments_str = "\n".join([f"{k}:{v}" for k, v in table_comments.items()]) + query = query + TABLE_COMMNET_PROMPT + table_comments_str + "\n\n" + if read_only: # 在read_only下,先让大模型判断只读模式是否能满足需求,避免后续执行过程报错,返回友好提示。 READ_ONLY_PROMPT = PromptTemplate( @@ -68,36 +87,54 @@ def query_database(query: str, read_only_result = read_only_chain.invoke(query) if "SQL cannot be executed normally" in read_only_result["text"]: return "当前数据库为只读状态,无法满足您的需求!" 
- + # 当然大模型不能保证完全判断准确,为防止大模型判断有误,再从拦截器层面拒绝写操作 event.listen(db._engine, "before_cursor_execute", intercept_sql) - - #如果不指定table_names,优先走SQLDatabaseSequentialChain,这个链会先预测需要哪些表,然后再将相关表输入SQLDatabaseChain - #这是因为如果不指定table_names,直接走SQLDatabaseChain,Langchain会将全量表结构传递给大模型,可能会因token太长从而引发错误,也浪费资源 - #如果指定了table_names,直接走SQLDatabaseChain,将特定表结构传递给大模型进行判断 + + # 如果不指定table_names,优先走SQLDatabaseSequentialChain,这个链会先预测需要哪些表,然后再将相关表输入SQLDatabaseChain + # 这是因为如果不指定table_names,直接走SQLDatabaseChain,Langchain会将全量表结构传递给大模型,可能会因token太长从而引发错误,也浪费资源 + # 如果指定了table_names,直接走SQLDatabaseChain,将特定表结构传递给大模型进行判断 if len(table_names) > 0: - db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True,top_k=top_k,return_intermediate_steps=return_intermediate_steps) - result = db_chain.invoke({"query":query,"table_names_to_use":table_names}) + db_chain = SQLDatabaseChain.from_llm( + llm, + db, + verbose=True, + top_k=top_k, + return_intermediate_steps=return_intermediate_steps, + ) + result = db_chain.invoke({"query": query, "table_names_to_use": table_names}) else: - #先预测会使用哪些表,然后再将问题和预测的表给大模型 - db_chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True,top_k=top_k,return_intermediate_steps=return_intermediate_steps) + # 先预测会使用哪些表,然后再将问题和预测的表给大模型 + db_chain = SQLDatabaseSequentialChain.from_llm( + llm, + db, + verbose=True, + top_k=top_k, + return_intermediate_steps=return_intermediate_steps, + ) result = db_chain.invoke(query) - + context = f"""查询结果:{result['result']}\n\n""" - intermediate_steps=result["intermediate_steps"] - #如果存在intermediate_steps,且这个数组的长度大于2,则保留最后两个元素,因为前面几个步骤存在示例数据,容易引起误解 + intermediate_steps = result["intermediate_steps"] + # 如果存在intermediate_steps,且这个数组的长度大于2,则保留最后两个元素,因为前面几个步骤存在示例数据,容易引起误解 if intermediate_steps: - if len(intermediate_steps)>2: - sql_detail=intermediate_steps[-2:-1][0]["input"] + if len(intermediate_steps) > 2: + sql_detail = intermediate_steps[-2:-1][0]["input"] # sql_detail截取从SQLQuery到Answer:之间的内容 - sql_detail=sql_detail[sql_detail.find("SQLQuery:")+9:sql_detail.find("Answer:")] - context = context+"执行的sql:'"+sql_detail+"'\n\n" + sql_detail = sql_detail[ + sql_detail.find("SQLQuery:") + 9 : sql_detail.find("Answer:") + ] + context = context + "执行的sql:'" + sql_detail + "'\n\n" return context - + @regist_tool(title="Text2Sql") -def text2sql(query: str = Field(description="No need for SQL statements,just input the natural language that you want to chat with database")): - '''Use this tool to chat with database,Input natural language, then it will convert it into SQL and execute it in the database, then return the execution result.''' +def text2sql( + query: str = Field( + description="No need for SQL statements,just input the natural language that you want to chat with database" + ), +): + """Use this tool to chat with database,Input natural language, then it will convert it into SQL and execute it in the database, then return the execution result.""" tool_config = get_tool_config("text2sql") return BaseToolOutput(query_database(query=query, config=tool_config)) diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py index bd2e661b5..5be842bb6 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py @@ -1,13 +1,12 @@ import json import re -from typing import Any, Union, Dict, Tuple, Callable, Optional, Type +from typing import Any, Callable, Dict, 
Optional, Tuple, Type, Union from langchain.agents import tool from langchain_core.tools import BaseTool from chatchat.server.pydantic_v1 import BaseModel, Extra - __all__ = ["regist_tool", "BaseToolOutput"] @@ -20,6 +19,7 @@ ################################### TODO: workaround to langchain #15855 # patch BaseTool to support tool parameters defined using pydantic Field + def _new_parse_input( self, tool_input: Union[str, Dict], @@ -73,10 +73,11 @@ def regist_tool( args_schema: Optional[Type[BaseModel]] = None, infer_schema: bool = True, ) -> Union[Callable, BaseTool]: - ''' + """ wrapper of langchain tool decorator add tool to registry automatically - ''' + """ + def _parse_tool(t: BaseTool): nonlocal description, title @@ -87,7 +88,7 @@ def _parse_tool(t: BaseTool): if t.func is not None: description = t.func.__doc__ elif t.coroutine is not None: - description = t.coroutine.__doc__ + description = t.coroutine.__doc__ t.description = " ".join(re.split(r"\n+\s*", description)) # set a default title for human if not title: @@ -95,11 +96,12 @@ def _parse_tool(t: BaseTool): t.title = title def wrapper(def_func: Callable) -> BaseTool: - partial_ = tool(*args, - return_direct=return_direct, - args_schema=args_schema, - infer_schema=infer_schema, - ) + partial_ = tool( + *args, + return_direct=return_direct, + args_schema=args_schema, + infer_schema=infer_schema, + ) t = partial_(def_func) _parse_tool(t) return t @@ -107,27 +109,29 @@ def wrapper(def_func: Callable) -> BaseTool: if len(args) == 0: return wrapper else: - t = tool(*args, - return_direct=return_direct, - args_schema=args_schema, - infer_schema=infer_schema, - ) + t = tool( + *args, + return_direct=return_direct, + args_schema=args_schema, + infer_schema=infer_schema, + ) _parse_tool(t) return t class BaseToolOutput: - ''' + """ LLM 要求 Tool 的输出为 str,但 Tool 用在别处时希望它正常返回结构化数据。 只需要将 Tool 返回值用该类封装,能同时满足两者的需要。 基类简单的将返回值字符串化,或指定 format="json" 将其转为 json。 用户也可以继承该类定义自己的转换方法。 - ''' + """ + def __init__( self, data: Any, - format: str="", - data_alias: str="", + format: str = "", + data_alias: str = "", **extras: Any, ) -> None: self.data = data @@ -135,7 +139,7 @@ def __init__( self.extras = extras if data_alias: setattr(self, data_alias, property(lambda obj: obj.data)) - + def __str__(self) -> str: if self.format == "json": return json.dumps(self.data, ensure_ascii=False, indent=2) diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/vqa_processor.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/vqa_processor.py index c1269c7c8..a3d3ca368 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/vqa_processor.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/vqa_processor.py @@ -2,13 +2,16 @@ Method Use cogagent to generate response for a given image and query. 
""" import base64 +import re from io import BytesIO + from PIL import Image, ImageDraw + +from chatchat.server.agent.container import container from chatchat.server.pydantic_v1 import Field from chatchat.server.utils import get_tool_config -from .tools_registry import regist_tool, BaseToolOutput -import re -from chatchat.server.agent.container import container + +from .tools_registry import BaseToolOutput, regist_tool def extract_between_markers(text, start_marker, end_marker): @@ -20,12 +23,11 @@ def extract_between_markers(text, start_marker, end_marker): if start != -1 and end != -1: # Extract and return the text between the markers, without including the markers themselves - return text[start + len(start_marker):end].strip() + return text[start + len(start_marker) : end].strip() else: return "Text not found between the specified markers" - def draw_box_on_existing_image(base64_image, text): """ 在已有的Base64编码的图片上根据“Grounded Operation”中的坐标信息绘制矩形框。 @@ -46,7 +48,7 @@ def draw_box_on_existing_image(base64_image, text): int(coords[0] * 0.001 * img.width), int(coords[1] * 0.001 * img.height), int(coords[2] * 0.001 * img.width), - int(coords[3] * 0.001 * img.height) + int(coords[3] * 0.001 * img.height), ) draw.rectangle(scaled_coords, outline="red", width=3) @@ -58,8 +60,17 @@ def draw_box_on_existing_image(base64_image, text): return img_base64 -def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", max_length=2048, top_p=0.9, - temperature=1.0): +def vqa_run( + model, + tokenizer, + image_base_64, + query, + history=[], + device="cuda", + max_length=2048, + top_p=0.9, + temperature=1.0, +): """ Args: image_path (str): path to the image @@ -76,23 +87,28 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m image = Image.open(BytesIO(base64.b64decode(image_base_64))) - inputs = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image]) + inputs = model.build_conversation_input_ids( + tokenizer, query=query, history=history, images=[image] + ) inputs = { - 'input_ids': inputs['input_ids'].unsqueeze(0).to(device), - 'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(device), - 'attention_mask': inputs['attention_mask'].unsqueeze(0).to(device), - 'images': [[inputs['images'][0].to(device).to(torch.bfloat16)]], - 'cross_images': [[inputs['cross_images'][0].to(device).to(torch.bfloat16)]] if inputs[ - 'cross_images'] else None, + "input_ids": inputs["input_ids"].unsqueeze(0).to(device), + "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to(device), + "attention_mask": inputs["attention_mask"].unsqueeze(0).to(device), + "images": [[inputs["images"][0].to(device).to(torch.bfloat16)]], + "cross_images": [[inputs["cross_images"][0].to(device).to(torch.bfloat16)]] + if inputs["cross_images"] + else None, } - gen_kwargs = {"max_length": max_length, - # "temperature": temperature, - "top_p": top_p, - "do_sample": False} + gen_kwargs = { + "max_length": max_length, + # "temperature": temperature, + "top_p": top_p, + "do_sample": False, + } with torch.no_grad(): outputs = model.generate(**inputs, **gen_kwargs) - outputs = outputs[:, inputs['input_ids'].shape[1]:] + outputs = outputs[:, inputs["input_ids"].shape[1] :] response = tokenizer.decode(outputs[0]) response = response.split("")[0] @@ -100,19 +116,25 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m @regist_tool(title="图片对话") -def vqa_processor(query: str = Field(description="The question of the image in English")): 
- '''use this tool to get answer for image question''' +def vqa_processor( + query: str = Field(description="The question of the image in English"), +): + """use this tool to get answer for image question""" tool_config = get_tool_config("vqa_processor") if container.metadata["images"]: image_base64 = container.metadata["images"][0] - ans = vqa_run(model=container.vision_model, - tokenizer=container.vision_tokenizer, - query=query + "(with grounding)", - image_base_64=image_base64, - device=tool_config["device"]) + ans = vqa_run( + model=container.vision_model, + tokenizer=container.vision_tokenizer, + query=query + "(with grounding)", + image_base_64=image_base64, + device=tool_config["device"], + ) print(ans) - image_new_base64 = draw_box_on_existing_image(container.metadata["images"][0], ans) + image_new_base64 = draw_box_on_existing_image( + container.metadata["images"][0], ans + ) # Markers # start_marker = "Next Action:draw_box_on_existing_image diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py index 1f3cef1bf..14b06ec9a 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py @@ -1,15 +1,19 @@ """ 简单的单参数输入工具实现,用于查询现在天气的情况 """ +import requests + from chatchat.server.pydantic_v1 import Field from chatchat.server.utils import get_tool_config -from .tools_registry import regist_tool, BaseToolOutput -import requests + +from .tools_registry import BaseToolOutput, regist_tool @regist_tool(title="天气查询") -def weather_check(city: str = Field(description="City name,include city and county,like '厦门'")): - '''Use this tool to check the weather at a specific city''' +def weather_check( + city: str = Field(description="City name,include city and county,like '厦门'"), +): + """Use this tool to check the weather at a specific city""" tool_config = get_tool_config("weather_check") api_key = tool_config.get("api_key") @@ -23,5 +27,4 @@ def weather_check(city: str = Field(description="City name,include city and coun } return BaseToolOutput(weather) else: - raise Exception( - f"Failed to retrieve weather: {response.status_code}") + raise Exception(f"Failed to retrieve weather: {response.status_code}") diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/wolfram.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/wolfram.py index 0bfa2f710..7d9dcc807 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/wolfram.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/wolfram.py @@ -2,15 +2,18 @@ from chatchat.server.pydantic_v1 import Field from chatchat.server.utils import get_tool_config -from .tools_registry import regist_tool, BaseToolOutput + +from .tools_registry import BaseToolOutput, regist_tool @regist_tool def wolfram(query: str = Field(description="The formula to be calculated")): - '''Useful for when you need to calculate difficult formulas''' + """Useful for when you need to calculate difficult formulas""" from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper - wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=get_tool_config("wolfram").get("appid")) + wolfram = WolframAlphaAPIWrapper( + wolfram_alpha_appid=get_tool_config("wolfram").get("appid") + ) ans = wolfram.run(query) return BaseToolOutput(ans) diff --git a/libs/chatchat-server/chatchat/server/api_allinone_stale.py 
b/libs/chatchat-server/chatchat/server/api_allinone_stale.py index 78a7a6dac..eb0deb064 100644 --- a/libs/chatchat-server/chatchat/server/api_allinone_stale.py +++ b/libs/chatchat-server/chatchat/server/api_allinone_stale.py @@ -9,15 +9,15 @@ python server/api_allinone.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB """ -import sys import os +import sys sys.path.append(os.path.dirname(__file__)) sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from llm_api_stale import launch_all, parser, controller_args, worker_args, server_args -from api import create_app import uvicorn +from api import create_app +from llm_api_stale import controller_args, launch_all, parser, server_args, worker_args parser.add_argument("--api-host", type=str, default="0.0.0.0") parser.add_argument("--api-port", type=int, default=7861) @@ -30,12 +30,13 @@ def run_api(host, port, **kwargs): app = create_app() if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"): - uvicorn.run(app, - host=host, - port=port, - ssl_keyfile=kwargs.get("ssl_keyfile"), - ssl_certfile=kwargs.get("ssl_certfile"), - ) + uvicorn.run( + app, + host=host, + port=port, + ssl_keyfile=kwargs.get("ssl_keyfile"), + ssl_certfile=kwargs.get("ssl_certfile"), + ) else: uvicorn.run(app, host=host, port=port) @@ -46,7 +47,12 @@ def run_api(host, port, **kwargs): # 初始化消息 args = parser.parse_args() args_dict = vars(args) - launch_all(args=args, controller_args=controller_args, worker_args=worker_args, server_args=server_args) + launch_all( + args=args, + controller_args=controller_args, + worker_args=worker_args, + server_args=server_args, + ) run_api( host=args.api_host, port=args.api_port, diff --git a/libs/chatchat-server/chatchat/server/api_server/api_schemas.py b/libs/chatchat-server/chatchat/server/api_server/api_schemas.py index 965be90d2..73eba217b 100644 --- a/libs/chatchat-server/chatchat/server/api_server/api_schemas.py +++ b/libs/chatchat-server/chatchat/server/api_server/api_schemas.py @@ -13,10 +13,11 @@ ) from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE -from chatchat.server.callback_handler.agent_callback_handler import AgentStatus # noaq -from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl +from chatchat.server.callback_handler.agent_callback_handler import AgentStatus # noaq +from chatchat.server.pydantic_v2 import AnyUrl, BaseModel, Field from chatchat.server.utils import MsgType + class OpenAIBaseInput(BaseModel): user: Optional[str] = None # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 
@@ -63,7 +64,9 @@ class OpenAIImageBaseInput(OpenAIBaseInput): model: str n: int = 1 response_format: Optional[Literal["url", "b64_json"]] = None - size: Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]] = "256x256" + size: Optional[ + Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] + ] = "256x256" class OpenAIImageGenerationsInput(OpenAIImageBaseInput): @@ -98,7 +101,9 @@ class OpenAIAudioSpeechInput(OpenAIBaseInput): input: str model: str voice: str - response_format: Optional[Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]] = None + response_format: Optional[ + Literal["mp3", "opus", "aac", "flac", "pcm", "wav"] + ] = None speed: Optional[float] = None @@ -111,16 +116,18 @@ class OpenAIBaseOutput(BaseModel): id: Optional[str] = None content: Optional[str] = None model: Optional[str] = None - object: Literal["chat.completion", "chat.completion.chunk"] = "chat.completion.chunk" + object: Literal[ + "chat.completion", "chat.completion.chunk" + ] = "chat.completion.chunk" role: Literal["assistant"] = "assistant" finish_reason: Optional[str] = None - created: int = Field(default_factory=lambda : int(time.time())) + created: int = Field(default_factory=lambda: int(time.time())) tool_calls: List[Dict] = [] - status: Optional[int] = None # AgentStatus + status: Optional[int] = None  # AgentStatus message_type: int = MsgType.TEXT - message_id: Optional[str] = None # id in database table - is_ref: bool = False # whether to show in a separate expander + message_id: Optional[str] = None  # id in database table + is_ref: bool = False  # whether to show in a separate expander class Config: extra = "allow" @@ -131,7 +138,6 @@ def model_dump(self) -> dict: "object": self.object, "model": self.model, "created": self.created, - "status": self.status, "message_type": self.message_type, "message_id": self.message_id, @@ -140,22 +146,26 @@ def model_dump(self) -> dict: } if self.object == "chat.completion.chunk": - result["choices"] = [{ - "delta": { - "content": self.content, - "tool_calls": self.tool_calls, - }, - "role": self.role, - }] - elif self.object == "chat.completion": - result["choices"] = [{ - "message": { + result["choices"] = [ + { + "delta": { + "content": self.content, + "tool_calls": self.tool_calls, + }, "role": self.role, - "content": self.content, - "finish_reason": self.finish_reason, - "tool_calls": self.tool_calls, } - }] + ] + elif self.object == "chat.completion": + result["choices"] = [ + { + "message": { + "role": self.role, + "content": self.content, + "finish_reason": self.finish_reason, + "tool_calls": self.tool_calls, + } + } + ] return result def model_dump_json(self): diff --git a/libs/chatchat-server/chatchat/server/api_server/chat_routes.py b/libs/chatchat-server/chatchat/server/api_server/chat_routes.py index 698b1ab8f..c77c14ba7 100644 --- a/libs/chatchat-server/chatchat/server/api_server/chat_routes.py +++ b/libs/chatchat-server/chatchat/server/api_server/chat_routes.py @@ -1,42 +1,48 @@ from __future__ import annotations -from typing import List, Dict +from typing import Dict, List from fastapi import APIRouter, Request from langchain.prompts.prompt import PromptTemplate -from chatchat.server.api_server.api_schemas import OpenAIChatInput, MsgType, AgentStatus +from chatchat.server.api_server.api_schemas import AgentStatus, MsgType, OpenAIChatInput from chatchat.server.chat.chat import chat from chatchat.server.chat.feedback import chat_feedback from chatchat.server.chat.file_chat import file_chat from 
chatchat.server.db.repository import add_message_to_db -from chatchat.server.utils import get_OpenAIClient, get_tool, get_tool_config, get_prompt_template -from .openai_routes import openai_request +from chatchat.server.utils import ( + get_OpenAIClient, + get_prompt_template, + get_tool, + get_tool_config, +) +from .openai_routes import openai_request chat_router = APIRouter(prefix="/chat", tags=["ChatChat 对话"]) -chat_router.post("/chat", - summary="与llm模型对话(通过LLMChain)", - )(chat) +chat_router.post( + "/chat", + summary="与llm模型对话(通过LLMChain)", +)(chat) + +chat_router.post( + "/feedback", + summary="返回llm模型对话评分", +)(chat_feedback) -chat_router.post("/feedback", - summary="返回llm模型对话评分", - )(chat_feedback) +chat_router.post("/file_chat", summary="文件对话")(file_chat) -chat_router.post("/file_chat", - summary="文件对话" - )(file_chat) +# 定义全局model信息,用于给Text2Sql中的get_ChatOpenAI提供model_name +global_model_name = None -#定义全局model信息,用于给Text2Sql中的get_ChatOpenAI提供model_name -global_model_name=None @chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口") async def chat_completions( request: Request, body: OpenAIChatInput, ) -> Dict: - ''' + """ 请求参数与 openai.chat.completions.create 一致,可以通过 extra_body 传入额外参数 tools 和 tool_choice 可以直接传工具名称,会根据项目里包含的 tools 进行转换 通过不同的参数组合调用不同的 chat 功能: @@ -47,14 +53,14 @@ async def chat_completions( - 其它:LLM 对话 以后还要考虑其它的组合(如文件对话) 返回与 openai 兼容的 Dict - ''' + """ client = get_OpenAIClient(model_name=body.model, is_async=True) extra = {**body.model_extra} or {} for key in list(extra): delattr(body, key) global global_model_name - global_model_name=body.model + global_model_name = body.model # check tools & tool_choice in request body if isinstance(body.tool_choice, str): if t := get_tool(body.tool_choice): @@ -69,7 +75,7 @@ async def chat_completions( "name": t.name, "description": t.description, "parameters": t.args, - } + }, } conversation_id = extra.get("conversation_id") @@ -78,67 +84,96 @@ async def chat_completions( if body.tool_choice: tool = get_tool(body.tool_choice["function"]["name"]) if not body.tools: - body.tools = [{ - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.args, - } - }] + body.tools = [ + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.args, + }, + } + ] if tool_input := extra.get("tool_input"): - message_id = add_message_to_db( - chat_type="tool_call", - query=body.messages[-1]["content"], - conversation_id=conversation_id - ) if conversation_id else None + message_id = ( + add_message_to_db( + chat_type="tool_call", + query=body.messages[-1]["content"], + conversation_id=conversation_id, + ) + if conversation_id + else None + ) tool_result = await tool.ainvoke(tool_input) - prompt_template = PromptTemplate.from_template(get_prompt_template("llm_model", "rag"), template_format="jinja2") - body.messages[-1]["content"] = prompt_template.format(context=tool_result, question=body.messages[-1]["content"]) + prompt_template = PromptTemplate.from_template( + get_prompt_template("llm_model", "rag"), template_format="jinja2" + ) + body.messages[-1]["content"] = prompt_template.format( + context=tool_result, question=body.messages[-1]["content"] + ) del body.tools del body.tool_choice extra_json = { "message_id": message_id, "status": None, } - header = [{**extra_json, - "content": f"{tool_result}", - "tool_output":tool_result.data, - "is_ref": True, - }] - return await openai_request(client.chat.completions.create, 
body, extra_json=extra_json, header=header) + header = [ + { + **extra_json, + "content": f"{tool_result}", + "tool_output": tool_result.data, + "is_ref": True, + } + ] + return await openai_request( + client.chat.completions.create, + body, + extra_json=extra_json, + header=header, + ) # agent chat with tool calls if body.tools: - message_id = add_message_to_db( - chat_type="agent_chat", - query=body.messages[-1]["content"], - conversation_id=conversation_id - ) if conversation_id else None + message_id = ( + add_message_to_db( + chat_type="agent_chat", + query=body.messages[-1]["content"], + conversation_id=conversation_id, + ) + if conversation_id + else None + ) - chat_model_config = {} # TODO: 前端支持配置模型 + chat_model_config = {} # TODO: 前端支持配置模型 tool_names = [x["function"]["name"] for x in body.tools] tool_config = {name: get_tool_config(name) for name in tool_names} - result = await chat(query=body.messages[-1]["content"], - metadata=extra.get("metadata", {}), - conversation_id=extra.get("conversation_id", ""), - message_id=message_id, - history_len=-1, - history=body.messages[:-1], - stream=body.stream, - chat_model_config=extra.get("chat_model_config", chat_model_config), - tool_config=extra.get("tool_config", tool_config), - ) - return result - else: # LLM chat directly - message_id = add_message_to_db( - chat_type="llm_chat", + result = await chat( query=body.messages[-1]["content"], - conversation_id=conversation_id - ) if conversation_id else None + metadata=extra.get("metadata", {}), + conversation_id=extra.get("conversation_id", ""), + message_id=message_id, + history_len=-1, + history=body.messages[:-1], + stream=body.stream, + chat_model_config=extra.get("chat_model_config", chat_model_config), + tool_config=extra.get("tool_config", tool_config), + ) + return result + else: # LLM chat directly + message_id = ( + add_message_to_db( + chat_type="llm_chat", + query=body.messages[-1]["content"], + conversation_id=conversation_id, + ) + if conversation_id + else None + ) extra_json = { "message_id": message_id, "status": None, } - return await openai_request(client.chat.completions.create, body, extra_json=extra_json) + return await openai_request( + client.chat.completions.create, body, extra_json=extra_json + ) diff --git a/libs/chatchat-server/chatchat/server/api_server/kb_routes.py b/libs/chatchat-server/chatchat/server/api_server/kb_routes.py index 0d743cef0..ecf18ce8e 100644 --- a/libs/chatchat-server/chatchat/server/api_server/kb_routes.py +++ b/libs/chatchat-server/chatchat/server/api_server/kb_routes.py @@ -5,84 +5,87 @@ from fastapi import APIRouter, Request from chatchat.server.chat.file_chat import upload_temp_docs -from chatchat.server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb -from chatchat.server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs, - update_docs, download_doc, recreate_vector_store, - search_docs, update_info) -from chatchat.server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store, - summary_doc_ids_to_vector_store) +from chatchat.server.knowledge_base.kb_api import create_kb, delete_kb, list_kbs +from chatchat.server.knowledge_base.kb_doc_api import ( + delete_docs, + download_doc, + list_files, + recreate_vector_store, + search_docs, + update_docs, + update_info, + upload_docs, +) +from chatchat.server.knowledge_base.kb_summary_api import ( + recreate_summary_vector_store, + summary_doc_ids_to_vector_store, + summary_file_to_vector_store, +) from 
chatchat.server.utils import BaseResponse, ListResponse - kb_router = APIRouter(prefix="/knowledge_base", tags=["Knowledge Base Management"]) -kb_router.get("/list_knowledge_bases", - response_model=ListResponse, - summary="获取知识库列表")(list_kbs) +kb_router.get( + "/list_knowledge_bases", response_model=ListResponse, summary="获取知识库列表" +)(list_kbs) -kb_router.post("/create_knowledge_base", - response_model=BaseResponse, - summary="创建知识库" - )(create_kb) +kb_router.post( + "/create_knowledge_base", response_model=BaseResponse, summary="创建知识库" +)(create_kb) -kb_router.post("/delete_knowledge_base", - response_model=BaseResponse, - summary="删除知识库" - )(delete_kb) +kb_router.post( + "/delete_knowledge_base", response_model=BaseResponse, summary="删除知识库" +)(delete_kb) -kb_router.get("/list_files", - response_model=ListResponse, - summary="获取知识库内的文件列表" - )(list_files) +kb_router.get( + "/list_files", response_model=ListResponse, summary="获取知识库内的文件列表" +)(list_files) -kb_router.post("/search_docs", - response_model=List[dict], - summary="搜索知识库" - )(search_docs) +kb_router.post("/search_docs", response_model=List[dict], summary="搜索知识库")( + search_docs +) -kb_router.post("/upload_docs", - response_model=BaseResponse, - summary="上传文件到知识库,并/或进行向量化" - )(upload_docs) +kb_router.post( + "/upload_docs", + response_model=BaseResponse, + summary="上传文件到知识库,并/或进行向量化", +)(upload_docs) -kb_router.post("/delete_docs", - response_model=BaseResponse, - summary="删除知识库内指定文件" - )(delete_docs) +kb_router.post( + "/delete_docs", response_model=BaseResponse, summary="删除知识库内指定文件" +)(delete_docs) -kb_router.post("/update_info", - response_model=BaseResponse, - summary="更新知识库介绍" - )(update_info) +kb_router.post("/update_info", response_model=BaseResponse, summary="更新知识库介绍")( + update_info +) -kb_router.post("/update_docs", - response_model=BaseResponse, - summary="更新现有文件到知识库" - )(update_docs) +kb_router.post( + "/update_docs", response_model=BaseResponse, summary="更新现有文件到知识库" +)(update_docs) -kb_router.get("/download_doc", - summary="下载对应的知识文件")(download_doc) +kb_router.get("/download_doc", summary="下载对应的知识文件")(download_doc) -kb_router.post("/recreate_vector_store", - summary="根据content中文档重建向量库,流式输出处理进度。" - )(recreate_vector_store) +kb_router.post( + "/recreate_vector_store", summary="根据content中文档重建向量库,流式输出处理进度。" +)(recreate_vector_store) -kb_router.post("/upload_temp_docs", - summary="上传文件到临时目录,用于文件对话。" - )(upload_temp_docs) +kb_router.post("/upload_temp_docs", summary="上传文件到临时目录,用于文件对话。")( + upload_temp_docs +) summary_router = APIRouter(prefix="/kb_summary_api") -summary_router.post("/summary_file_to_vector_store", - summary="单个知识库根据文件名称摘要" - )(summary_file_to_vector_store) -summary_router.post("/summary_doc_ids_to_vector_store", - summary="单个知识库根据doc_ids摘要", - response_model=BaseResponse, - )(summary_doc_ids_to_vector_store) -summary_router.post("/recreate_summary_vector_store", - summary="重建单个知识库文件摘要" - )(recreate_summary_vector_store) +summary_router.post( + "/summary_file_to_vector_store", summary="单个知识库根据文件名称摘要" +)(summary_file_to_vector_store) +summary_router.post( + "/summary_doc_ids_to_vector_store", + summary="单个知识库根据doc_ids摘要", + response_model=BaseResponse, +)(summary_doc_ids_to_vector_store) +summary_router.post("/recreate_summary_vector_store", summary="重建单个知识库文件摘要")( + recreate_summary_vector_store +) kb_router.include_router(summary_router) diff --git a/libs/chatchat-server/chatchat/server/api_server/openai_routes.py b/libs/chatchat-server/chatchat/server/api_server/openai_routes.py index 108e144fe..e5fe96153 
100644 --- a/libs/chatchat-server/chatchat/server/api_server/openai_routes.py +++ b/libs/chatchat-server/chatchat/server/api_server/openai_routes.py @@ -2,12 +2,13 @@ import asyncio import base64 +import logging +import os +import shutil from contextlib import asynccontextmanager from datetime import datetime -import os from pathlib import Path -import shutil -from typing import Dict, Tuple, AsyncGenerator, Iterable +from typing import AsyncGenerator, Dict, Iterable, Tuple from fastapi import APIRouter, Request from fastapi.responses import FileResponse @@ -15,25 +16,26 @@ from openai.types.file_object import FileObject from sse_starlette.sse import EventSourceResponse -from .api_schemas import * from chatchat.configs import BASE_TEMP_DIR, log_verbose -from chatchat.server.utils import get_model_info, get_config_platforms, get_OpenAIClient +from chatchat.server.utils import get_config_platforms, get_model_info, get_OpenAIClient -import logging +from .api_schemas import * logger = logging.getLogger() -DEFAULT_API_CONCURRENCIES = 5 # 默认单个模型最大并发数 -model_semaphores: Dict[Tuple[str, str], asyncio.Semaphore] = {} # key: (model_name, platform) +DEFAULT_API_CONCURRENCIES = 5 # 默认单个模型最大并发数 +model_semaphores: Dict[ + Tuple[str, str], asyncio.Semaphore +] = {} # key: (model_name, platform) openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"]) @asynccontextmanager async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]: - ''' + """ 对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型 - ''' + """ max_semaphore = 0 selected_platform = "" model_infos = get_model_info(model_name=model_name, multiple=True) @@ -60,10 +62,13 @@ async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]: semaphore.release() -async def openai_request(method, body, extra_json: Dict={}, header: Iterable=[], tail: Iterable=[]): - ''' +async def openai_request( + method, body, extra_json: Dict = {}, header: Iterable = [], tail: Iterable = [] +): + """ helper function to make openai request with extra fields - ''' + """ + async def generator(): for x in header: if isinstance(x, str): @@ -105,9 +110,10 @@ async def generator(): @openai_router.get("/models") async def list_models() -> Dict: - ''' + """ 整合所有平台的模型列表。 - ''' + """ + async def task(name: str, config: Dict): try: client = get_OpenAIClient(name, is_async=True) @@ -118,9 +124,12 @@ async def task(name: str, config: Dict): return [] result = [] - tasks = [asyncio.create_task(task(name, config)) for name, config in get_config_platforms().items()] + tasks = [ + asyncio.create_task(task(name, config)) + for name, config in get_config_platforms().items() + ] for t in asyncio.as_completed(tasks): - result += (await t) + result += await t return {"object": "list", "data": result} @@ -182,7 +191,7 @@ async def create_image_edit( async with get_model_client(body.model) as client: return await openai_request(client.images.edit, body) - + @openai_router.post("/audio/translations", deprecated="暂不支持") async def create_audio_translations( request: Request, @@ -248,7 +257,9 @@ async def files( purpose: str = "assistants", ) -> Dict: created_at = int(datetime.now().timestamp()) - file_id = _get_file_id(purpose=purpose, created_at=created_at, filename=file.filename) + file_id = _get_file_id( + purpose=purpose, created_at=created_at, filename=file.filename + ) file_path = _get_file_path(file_id) file_dir = os.path.dirname(file_path) os.makedirs(file_dir, exist_ok=True) @@ -273,9 +284,13 @@ def list_files(purpose: str) -> Dict[str, List[Dict]]: for dir, sub_dirs, 
files in os.walk(root_path): dir = Path(dir).relative_to(root_path).as_posix() for file in files: - file_id = base64.urlsafe_b64encode(f"{purpose}/{dir}/{file}".encode()).decode() + file_id = base64.urlsafe_b64encode( + f"{purpose}/{dir}/{file}".encode() + ).decode() file_ids.append(file_id) - return {"data": [{**_get_file_info(x), "id":x, "object": "file"} for x in file_ids]} + return { + "data": [{**_get_file_info(x), "id": x, "object": "file"} for x in file_ids] + } @openai_router.get("/files/{file_id}") diff --git a/libs/chatchat-server/chatchat/server/api_server/server_app.py b/libs/chatchat-server/chatchat/server/api_server/server_app.py index fcd5d2e07..ed30055fe 100644 --- a/libs/chatchat-server/chatchat/server/api_server/server_app.py +++ b/libs/chatchat-server/chatchat/server/api_server/server_app.py @@ -2,14 +2,13 @@ import os from typing import Literal -from fastapi import FastAPI, Body +import uvicorn +from fastapi import Body, FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from starlette.responses import RedirectResponse -import uvicorn -from chatchat.configs import VERSION, MEDIA_PATH, CHATCHAT_ROOT -from chatchat.configs import OPEN_CROSS_DOMAIN +from chatchat.configs import CHATCHAT_ROOT, MEDIA_PATH, OPEN_CROSS_DOMAIN, VERSION from chatchat.server.api_server.chat_routes import chat_router from chatchat.server.api_server.kb_routes import kb_router from chatchat.server.api_server.openai_routes import openai_router @@ -19,11 +18,8 @@ from chatchat.server.utils import MakeFastAPIOffline -def create_app(run_mode: str=None): - app = FastAPI( - title="Langchain-Chatchat API Server", - version=VERSION - ) +def create_app(run_mode: str = None): + app = FastAPI(title="Langchain-Chatchat API Server", version=VERSION) MakeFastAPIOffline(app) # Add CORS middleware to allow all origins # 在config.py中设置OPEN_DOMAIN=True,允许跨域 @@ -48,10 +44,11 @@ async def document(): app.include_router(server_router) # 其它接口 - app.post("/other/completion", - tags=["Other"], - summary="要求llm模型补全(通过LLMChain)", - )(completion) + app.post( + "/other/completion", + tags=["Other"], + summary="要求llm模型补全(通过LLMChain)", + )(completion) # 媒体文件 app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media") @@ -65,22 +62,26 @@ async def document(): def run_api(host, port, **kwargs): if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"): - uvicorn.run(app, - host=host, - port=port, - ssl_keyfile=kwargs.get("ssl_keyfile"), - ssl_certfile=kwargs.get("ssl_certfile"), - ) + uvicorn.run( + app, + host=host, + port=port, + ssl_keyfile=kwargs.get("ssl_keyfile"), + ssl_certfile=kwargs.get("ssl_certfile"), + ) else: uvicorn.run(app, host=host, port=port) + app = create_app() if __name__ == "__main__": - parser = argparse.ArgumentParser(prog='langchain-ChatGLM', - description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain' - ' | 基于本地知识库的 ChatGLM 问答') + parser = argparse.ArgumentParser( + prog="langchain-ChatGLM", + description="About langchain-ChatGLM, local knowledge based ChatGLM with langchain" + " | 基于本地知识库的 ChatGLM 问答", + ) parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default=7861) parser.add_argument("--ssl_keyfile", type=str) @@ -89,8 +90,9 @@ def run_api(host, port, **kwargs): args = parser.parse_args() args_dict = vars(args) - run_api(host=args.host, - port=args.port, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ) + run_api( + host=args.host, + 
port=args.port, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ) diff --git a/libs/chatchat-server/chatchat/server/api_server/server_routes.py b/libs/chatchat-server/chatchat/server/api_server/server_routes.py index aefd0d5c7..f487419af 100644 --- a/libs/chatchat-server/chatchat/server/api_server/server_routes.py +++ b/libs/chatchat-server/chatchat/server/api_server/server_routes.py @@ -2,22 +2,23 @@ from fastapi import APIRouter, Body -from chatchat.server.utils import get_server_configs, get_prompt_template - +from chatchat.server.utils import get_prompt_template, get_server_configs server_router = APIRouter(prefix="/server", tags=["Server State"]) # 服务器相关接口 -server_router.post("/configs", - summary="获取服务器原始配置信息", - )(get_server_configs) +server_router.post( + "/configs", + summary="获取服务器原始配置信息", +)(get_server_configs) -@server_router.post("/get_prompt_template", - summary="获取服务区配置的 prompt 模板") +@server_router.post("/get_prompt_template", summary="获取服务区配置的 prompt 模板") def get_server_prompt_template( - type: Literal["llm_chat", "knowledge_base_chat"]=Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat"), + type: Literal["llm_chat", "knowledge_base_chat"] = Body( + "llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat" + ), name: str = Body("default", description="模板名称"), ) -> str: return get_prompt_template(type=type, name=name) diff --git a/libs/chatchat-server/chatchat/server/api_server/tool_routes.py b/libs/chatchat-server/chatchat/server/api_server/tool_routes.py index 0a55299ce..7907f13bd 100644 --- a/libs/chatchat-server/chatchat/server/api_server/tool_routes.py +++ b/libs/chatchat-server/chatchat/server/api_server/tool_routes.py @@ -1,13 +1,12 @@ from __future__ import annotations +import logging from typing import List -from fastapi import APIRouter, Request, Body +from fastapi import APIRouter, Body, Request from chatchat.server.utils import BaseResponse, get_tool, get_tool_config -import logging - logger = logging.getLogger() tool_router = APIRouter(prefix="/tools", tags=["Toolkits"]) @@ -16,13 +15,16 @@ @tool_router.get("", response_model=BaseResponse) async def list_tools(): tools = get_tool() - data = {t.name: { - "name": t.name, - "title": t.title, - "description": t.description, - "args": t.args, - "config": get_tool_config(t.name), - } for t in tools.values()} + data = { + t.name: { + "name": t.name, + "title": t.title, + "description": t.description, + "args": t.args, + "config": get_tool_config(t.name), + } + for t in tools.values() + } return {"data": data} diff --git a/libs/chatchat-server/chatchat/server/callback_handler/agent_callback_handler.py b/libs/chatchat-server/chatchat/server/callback_handler/agent_callback_handler.py index 6f01662c3..157f9681c 100644 --- a/libs/chatchat-server/chatchat/server/callback_handler/agent_callback_handler.py +++ b/libs/chatchat-server/chatchat/server/callback_handler/agent_callback_handler.py @@ -1,11 +1,12 @@ from __future__ import annotations -from uuid import UUID -import json + import asyncio +import json from typing import Any, Dict, List, Optional +from uuid import UUID from langchain.callbacks import AsyncIteratorCallbackHandler -from langchain.schema import AgentFinish, AgentAction +from langchain.schema import AgentAction, AgentFinish from langchain_core.outputs import LLMResult @@ -31,7 +32,9 @@ def __init__(self): self.done = asyncio.Event() self.out = True - async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None: + async def 
on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: data = { "status": AgentStatus.llm_start, "text": "", @@ -60,15 +63,15 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: self.queue.put_nowait(dumps(data)) async def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, + self, + serialized: Dict[str, Any], + messages: List[List], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> None: data = { "status": AgentStatus.llm_start, @@ -84,7 +87,9 @@ async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: } self.queue.put_nowait(dumps(data)) - async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None: + async def on_llm_error( + self, error: Exception | KeyboardInterrupt, **kwargs: Any + ) -> None: data = { "status": AgentStatus.error, "text": str(error), @@ -92,33 +97,32 @@ async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any self.queue.put_nowait(dumps(data)) async def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> None: - data = { - "run_id": str(run_id), - "status": AgentStatus.tool_start, - "tool": serialized["name"], - "tool_input": input_str, - } - self.queue.put_nowait(dumps(data)) - + data = { + "run_id": str(run_id), + "status": AgentStatus.tool_start, + "tool": serialized["name"], + "tool_input": input_str, + } + self.queue.put_nowait(dumps(data)) async def on_tool_end( - self, - output: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, + self, + output: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, ) -> None: """Run when tool ends running.""" data = { @@ -130,13 +134,13 @@ async def on_tool_end( self.queue.put_nowait(dumps(data)) async def on_tool_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, ) -> None: """Run when tool errors.""" data = { @@ -149,13 +153,13 @@ async def on_tool_error( self.queue.put_nowait(dumps(data)) async def on_agent_action( - self, - action: AgentAction, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, ) -> None: data = { "status": AgentStatus.agent_action, @@ -166,12 +170,18 @@ async def on_agent_action( self.queue.put_nowait(dumps(data)) async def on_agent_finish( - self, finish: AgentFinish, *, run_id: UUID, parent_run_id: 
Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, + self, + finish: AgentFinish, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, ) -> None: if "Thought:" in finish.return_values["output"]: - finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "") + finish.return_values["output"] = finish.return_values["output"].replace( + "Thought:", "" + ) data = { "status": AgentStatus.agent_finish, @@ -179,7 +189,14 @@ async def on_agent_finish( } self.queue.put_nowait(dumps(data)) - - async def on_chain_end(self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None: + async def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: List[str] | None = None, + **kwargs: Any, + ) -> None: self.done.set() self.out = True diff --git a/libs/chatchat-server/chatchat/server/callback_handler/conversation_callback_handler.py b/libs/chatchat-server/chatchat/server/callback_handler/conversation_callback_handler.py index b9bd5fe44..d53f6a9ce 100644 --- a/libs/chatchat-server/chatchat/server/callback_handler/conversation_callback_handler.py +++ b/libs/chatchat-server/chatchat/server/callback_handler/conversation_callback_handler.py @@ -2,13 +2,16 @@ from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import LLMResult + from chatchat.server.db.repository import update_message class ConversationCallbackHandler(BaseCallbackHandler): raise_error: bool = True - def __init__(self, conversation_id: str, message_id: str, chat_type: str, query: str): + def __init__( + self, conversation_id: str, message_id: str, chat_type: str, query: str + ): self.conversation_id = conversation_id self.message_id = message_id self.chat_type = chat_type @@ -21,7 +24,7 @@ def always_verbose(self) -> bool: return True def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: # TODO 如果想存更多信息,则 prompts 也需要持久化,不用的提示词需要特殊支持 pass diff --git a/libs/chatchat-server/chatchat/server/chat/chat.py b/libs/chatchat-server/chatchat/server/chat/chat.py index 27d18758a..2494c254a 100644 --- a/libs/chatchat-server/chatchat/server/chat/chat.py +++ b/libs/chatchat-server/chatchat/server/chat/chat.py @@ -1,26 +1,36 @@ import asyncio import json import time -from typing import AsyncIterable, List import uuid +from typing import AsyncIterable, List from fastapi import Body -from sse_starlette.sse import EventSourceResponse -from langchain_core.output_parsers import StrOutputParser -from langchain_core.messages import AIMessage, HumanMessage, convert_to_messages - from langchain.chains import LLMChain -from langchain.prompts.chat import ChatPromptTemplate from langchain.prompts import PromptTemplate +from langchain.prompts.chat import ChatPromptTemplate +from langchain_core.messages import AIMessage, HumanMessage, convert_to_messages +from langchain_core.output_parsers import StrOutputParser +from sse_starlette.sse import EventSourceResponse from chatchat.configs import LLM_MODEL_CONFIG from chatchat.server.agent.agent_factory.agents_registry import agents_registry from chatchat.server.agent.container import container from chatchat.server.api_server.api_schemas import OpenAIChatOutput -from chatchat.server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, 
MsgType, get_tool +from chatchat.server.callback_handler.agent_callback_handler import ( + AgentExecutorAsyncIteratorCallbackHandler, + AgentStatus, +) from chatchat.server.chat.utils import History -from chatchat.server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory -from chatchat.server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus +from chatchat.server.memory.conversation_db_buffer_memory import ( + ConversationBufferDBMemory, +) +from chatchat.server.utils import ( + MsgType, + get_ChatOpenAI, + get_prompt_template, + get_tool, + wrap_done, +) def create_models_from_config(configs, callbacks, stream): @@ -29,86 +39,85 @@ def create_models_from_config(configs, callbacks, stream): prompts = {} for model_type, model_configs in configs.items(): for model_name, params in model_configs.items(): - callbacks = callbacks if params.get('callbacks', False) else None + callbacks = callbacks if params.get("callbacks", False) else None model_instance = get_ChatOpenAI( model_name=model_name, - temperature=params.get('temperature', 0.5), - max_tokens=params.get('max_tokens', 1000), + temperature=params.get("temperature", 0.5), + max_tokens=params.get("max_tokens", 1000), callbacks=callbacks, streaming=stream, local_wrap=True, ) models[model_type] = model_instance - prompt_name = params.get('prompt_name', 'default') + prompt_name = params.get("prompt_name", "default") prompt_template = get_prompt_template(type=model_type, name=prompt_name) prompts[model_type] = prompt_template return models, prompts -def create_models_chains(history, history_len, prompts, models, tools, callbacks, conversation_id, metadata): +def create_models_chains( + history, history_len, prompts, models, tools, callbacks, conversation_id, metadata +): memory = None chat_prompt = None container.metadata = metadata if history: history = [History.from_data(h) for h in history] - input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False) + input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template( + False + ) chat_prompt = ChatPromptTemplate.from_messages( - [i.to_msg_template() for i in history] + [input_msg]) + [i.to_msg_template() for i in history] + [input_msg] + ) elif conversation_id and history_len > 0: memory = ConversationBufferDBMemory( conversation_id=conversation_id, llm=models["llm_model"], - message_limit=history_len + message_limit=history_len, ) else: - input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False) + input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template( + False + ) chat_prompt = ChatPromptTemplate.from_messages([input_msg]) - llm=models["llm_model"] + llm = models["llm_model"] llm.callbacks = callbacks - chain = LLMChain( - prompt=chat_prompt, - llm=llm, - memory=memory - ) + chain = LLMChain(prompt=chat_prompt, llm=llm, memory=memory) if "action_model" in models and tools is not None: agent_executor = agents_registry( - llm=llm, - callbacks=callbacks, - tools=tools, - prompt=None, - verbose=True + llm=llm, callbacks=callbacks, tools=tools, prompt=None, verbose=True ) - full_chain = ({"input": lambda x: x["input"]} | agent_executor) + full_chain = {"input": lambda x: x["input"]} | agent_executor else: chain.llm.callbacks = callbacks - full_chain = ({"input": lambda x: x["input"]} | chain) + full_chain = {"input": lambda x: x["input"]} | chain return full_chain -async def chat(query: str = Body(..., description="用户输入", 
examples=["恼羞成怒"]), - metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]), - conversation_id: str = Body("", description="对话框ID"), - message_id: str = Body(None, description="数据库消息ID"), - history_len: int = Body(-1, description="从数据库中取历史消息的数量"), - history: List[History] = Body( - [], - description="历史对话,设为一个整数可以从数据库中读取历史消息", - examples=[ - [ - {"role": "user", - "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", "content": "虎头虎脑"} - ] - ] - ), - stream: bool = Body(True, description="流式输出"), - chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]), - tool_config: dict = Body({}, description="工具配置", examples=[]), - ): - '''Agent 对话''' +async def chat( + query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), + metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]), + conversation_id: str = Body("", description="对话框ID"), + message_id: str = Body(None, description="数据库消息ID"), + history_len: int = Body(-1, description="从数据库中取历史消息的数量"), + history: List[History] = Body( + [], + description="历史对话,设为一个整数可以从数据库中读取历史消息", + examples=[ + [ + {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", "content": "虎头虎脑"}, + ] + ], + ), + stream: bool = Body(True, description="流式输出"), + chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]), + tool_config: dict = Body({}, description="工具配置", examples=[]), +): + """Agent 对话""" async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]: callback = AgentExecutorAsyncIteratorCallbackHandler() @@ -116,41 +125,50 @@ async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]: # Enable langchain-chatchat to support langfuse import os - langfuse_secret_key = os.environ.get('LANGFUSE_SECRET_KEY') - langfuse_public_key = os.environ.get('LANGFUSE_PUBLIC_KEY') - langfuse_host = os.environ.get('LANGFUSE_HOST') - if langfuse_secret_key and langfuse_public_key and langfuse_host : + + langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY") + langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY") + langfuse_host = os.environ.get("LANGFUSE_HOST") + if langfuse_secret_key and langfuse_public_key and langfuse_host: from langfuse import Langfuse from langfuse.callback import CallbackHandler + langfuse_handler = CallbackHandler() callbacks.append(langfuse_handler) - models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config, - stream=stream) + models, prompts = create_models_from_config( + callbacks=callbacks, configs=chat_model_config, stream=stream + ) all_tools = get_tool().values() tools = [tool for tool in all_tools if tool.name in tool_config] tools = [t.copy(update={"callbacks": callbacks}) for t in tools] - full_chain = create_models_chains(prompts=prompts, - models=models, - conversation_id=conversation_id, - tools=tools, - callbacks=callbacks, - history=history, - history_len=history_len, - metadata=metadata) + full_chain = create_models_chains( + prompts=prompts, + models=models, + conversation_id=conversation_id, + tools=tools, + callbacks=callbacks, + history=history, + history_len=history_len, + metadata=metadata, + ) _history = [History.from_data(h) for h in history] chat_history = [h.to_msg_tuple() for h in _history] history_message = convert_to_messages(chat_history) - task = asyncio.create_task(wrap_done( - full_chain.ainvoke( - { - "input": query, - "chat_history": history_message, - } - ), callback.done)) + task = asyncio.create_task( + wrap_done( + full_chain.ainvoke( + { + "input": query, + "chat_history": 
history_message, + } + ), + callback.done, + ) + ) last_tool = {} async for chunk in callback.aiter(): @@ -174,7 +192,7 @@ async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]: if data["status"] in [AgentStatus.tool_end]: last_tool.update( tool_output=data["tool_output"], - is_error=data.get("is_error", False) + is_error=data.get("is_error", False), ) data["tool_calls"] = [last_tool] last_tool = {} @@ -199,8 +217,8 @@ async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]: role="assistant", tool_calls=data["tool_calls"], model=models["llm_model"].model_name, - status = data["status"], - message_type = data["message_type"], + status=data["status"], + message_type=data["message_type"], message_id=message_id, ) yield ret.model_dump_json() @@ -212,7 +230,7 @@ async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]: # model=models["llm_model"].model_name, # status = data["status"], # message_type = data["message_type"], - # message_id=message_id, + # message_id=message_id, # ) await task @@ -226,8 +244,8 @@ async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]: role="assistant", finish_reason="stop", tool_calls=[], - status = AgentStatus.agent_finish, - message_type = MsgType.TEXT, + status=AgentStatus.agent_finish, + message_type=MsgType.TEXT, message_id=message_id, ) diff --git a/libs/chatchat-server/chatchat/server/chat/completion.py b/libs/chatchat-server/chatchat/server/chat/completion.py index 8d67f3eab..69f7abdc8 100644 --- a/libs/chatchat-server/chatchat/server/chat/completion.py +++ b/libs/chatchat-server/chatchat/server/chat/completion.py @@ -1,32 +1,36 @@ +import asyncio +from typing import AsyncIterable, Optional + from fastapi import Body -from sse_starlette.sse import EventSourceResponse -from chatchat.server.utils import wrap_done, get_OpenAI -from langchain.chains import LLMChain from langchain.callbacks import AsyncIteratorCallbackHandler -from typing import AsyncIterable, Optional -import asyncio +from langchain.chains import LLMChain from langchain.prompts import PromptTemplate +from sse_starlette.sse import EventSourceResponse -from chatchat.server.utils import get_prompt_template - +from chatchat.server.utils import get_OpenAI, get_prompt_template, wrap_done -async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), - stream: bool = Body(False, description="流式输出"), - echo: bool = Body(False, description="除了输出之外,还回显输入"), - model_name: str = Body(None, description="LLM 模型名称。"), - temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"), - # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), - prompt_name: str = Body("default", - description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), - ): - #TODO: 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理 - async def completion_iterator(query: str, - model_name: str = None, - prompt_name: str = prompt_name, - echo: bool = echo, - ) -> AsyncIterable[str]: +async def completion( + query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), + stream: bool = Body(False, description="流式输出"), + echo: bool = Body(False, description="除了输出之外,还回显输入"), + model_name: str = Body(None, description="LLM 模型名称。"), + temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), + max_tokens: Optional[int] = Body( + 1024, description="限制LLM生成Token数量,默认None代表模型最大值" + ), + # top_p: float = Body(TOP_P, description="LLM 
核采样。勿与temperature同时设置", gt=0.0, lt=1.0), + prompt_name: str = Body( + "default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)" + ), +): + # TODO: 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理 + async def completion_iterator( + query: str, + model_name: str = None, + prompt_name: str = prompt_name, + echo: bool = echo, + ) -> AsyncIterable[str]: nonlocal max_tokens callback = AsyncIteratorCallbackHandler() if isinstance(max_tokens, int) and max_tokens <= 0: @@ -46,9 +50,8 @@ async def completion_iterator(query: str, chain = LLMChain(prompt=prompt, llm=model) # Begin a task that runs in the background. - task = asyncio.create_task(wrap_done( - chain.acall({"input": query}), - callback.done), + task = asyncio.create_task( + wrap_done(chain.acall({"input": query}), callback.done), ) if stream: @@ -63,7 +66,8 @@ async def completion_iterator(query: str, await task - return EventSourceResponse(completion_iterator(query=query, - model_name=model_name, - prompt_name=prompt_name), - ) + return EventSourceResponse( + completion_iterator( + query=query, model_name=model_name, prompt_name=prompt_name + ), + ) diff --git a/libs/chatchat-server/chatchat/server/chat/feedback.py b/libs/chatchat-server/chatchat/server/chat/feedback.py index 56227ff0a..c184eabbf 100644 --- a/libs/chatchat-server/chatchat/server/chat/feedback.py +++ b/libs/chatchat-server/chatchat/server/chat/feedback.py @@ -1,23 +1,26 @@ +import logging + from fastapi import Body + from chatchat.configs import log_verbose -from chatchat.server.utils import BaseResponse from chatchat.server.db.repository import feedback_message_to_db - -import logging +from chatchat.server.utils import BaseResponse logger = logging.getLogger() -def chat_feedback(message_id: str = Body("", max_length=32, description="聊天记录id"), - score: int = Body(0, max=100, description="用户评分,满分100,越大表示评价越高"), - reason: str = Body("", description="用户评分理由,比如不符合事实等") - ): +def chat_feedback( + message_id: str = Body("", max_length=32, description="聊天记录id"), + score: int = Body(0, max=100, description="用户评分,满分100,越大表示评价越高"), + reason: str = Body("", description="用户评分理由,比如不符合事实等"), +): try: feedback_message_to_db(message_id, score, reason) except Exception as e: msg = f"反馈聊天记录出错: {e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e if log_verbose else None + ) return BaseResponse(code=500, msg=msg) return BaseResponse(code=200, msg=f"已反馈聊天记录 {message_id}") diff --git a/libs/chatchat-server/chatchat/server/chat/file_chat.py b/libs/chatchat-server/chatchat/server/chat/file_chat.py index f74e5a7aa..2f1843e86 100644 --- a/libs/chatchat-server/chatchat/server/chat/file_chat.py +++ b/libs/chatchat-server/chatchat/server/chat/file_chat.py @@ -1,31 +1,45 @@ -from fastapi import Body, File, Form, UploadFile -from sse_starlette.sse import EventSourceResponse +import asyncio +import json +import logging +import os +from typing import AsyncIterable, List, Optional -from chatchat.configs import (VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE) -from chatchat.server.utils import (wrap_done, get_ChatOpenAI, get_Embeddings, - BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool) -from chatchat.server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool -from langchain.chains import LLMChain +import nest_asyncio +from fastapi import Body, File, Form, UploadFile from langchain.callbacks 
import AsyncIteratorCallbackHandler -from typing import AsyncIterable, List, Optional -import asyncio +from langchain.chains import LLMChain from langchain.prompts.chat import ChatPromptTemplate +from sse_starlette.sse import EventSourceResponse + +from chatchat.configs import ( + CHUNK_SIZE, + OVERLAP_SIZE, + SCORE_THRESHOLD, + VECTOR_SEARCH_TOP_K, + ZH_TITLE_ENHANCE, +) from chatchat.server.chat.utils import History +from chatchat.server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool from chatchat.server.knowledge_base.utils import KnowledgeFile -import json -import os -import nest_asyncio -import logging +from chatchat.server.utils import ( + BaseResponse, + get_ChatOpenAI, + get_Embeddings, + get_prompt_template, + get_temp_dir, + run_in_thread_pool, + wrap_done, +) logger = logging.getLogger(__name__) def _parse_files_in_thread( - files: List[UploadFile], - dir: str, - zh_title_enhance: bool, - chunk_size: int, - chunk_overlap: int, + files: List[UploadFile], + dir: str, + zh_title_enhance: bool, + chunk_size: int, + chunk_overlap: int, ): """ 通过多线程将上传的文件保存到对应目录内。 @@ -33,9 +47,9 @@ def _parse_files_in_thread( """ def parse_file(file: UploadFile) -> dict: - ''' + """ 保存单个文件。 - ''' + """ try: filename = file.filename file_path = os.path.join(dir, filename) @@ -47,9 +61,11 @@ def parse_file(file: UploadFile) -> dict: f.write(file_content) kb_file = KnowledgeFile(filename=filename, knowledge_base_name="temp") kb_file.filepath = file_path - docs = kb_file.file2text(zh_title_enhance=zh_title_enhance, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap) + docs = kb_file.file2text( + zh_title_enhance=zh_title_enhance, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) return True, filename, f"成功上传文件 {filename}", docs except Exception as e: msg = f"{filename} 文件上传失败,报错信息为: {e}" @@ -61,11 +77,11 @@ def parse_file(file: UploadFile) -> dict: def upload_temp_docs( - files: List[UploadFile] = File(..., description="上传文件,支持多文件"), - prev_id: str = Form(None, description="前知识库ID"), - chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"), - chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), - zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), + files: List[UploadFile] = File(..., description="上传文件,支持多文件"), + prev_id: str = Form(None, description="前知识库ID"), + chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"), + chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), + zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), ) -> BaseResponse: """ 将文件保存到临时目录,并进行向量化。 @@ -77,17 +93,18 @@ def upload_temp_docs( failed_files = [] documents = [] path, id = get_temp_dir(prev_id) - for success, file, msg, docs in _parse_files_in_thread(files=files, - dir=path, - zh_title_enhance=zh_title_enhance, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap): + for success, file, msg, docs in _parse_files_in_thread( + files=files, + dir=path, + zh_title_enhance=zh_title_enhance, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ): if success: documents += docs else: failed_files.append({file: msg}) try: - with memo_faiss_pool.load_vector_store(kb_name=id).acquire() as vs: vs.add_documents(documents) except Exception as e: @@ -96,31 +113,44 @@ def upload_temp_docs( return BaseResponse(data={"id": id, "failed_files": failed_files}) -async def file_chat(query: str = Body(..., description="用户输入", examples=["你好"]), - knowledge_id: str = Body(..., description="临时知识库ID"), - top_k: int = 
Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), - score_threshold: float = Body(SCORE_THRESHOLD, - description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", - ge=0, le=2), - history: List[History] = Body([], - description="历史对话", - examples=[[ - {"role": "user", - "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", - "content": "虎头虎脑"}]] - ), - stream: bool = Body(False, description="流式输出"), - model_name: str = Body(None, description="LLM 模型名称。"), - temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), - prompt_name: str = Body("rag_default", - description="使用的prompt模板名称(在configs/_prompt_config.py中配置)"), - ): +async def file_chat( + query: str = Body(..., description="用户输入", examples=["你好"]), + knowledge_id: str = Body(..., description="临时知识库ID"), + top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), + score_threshold: float = Body( + SCORE_THRESHOLD, + description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", + ge=0, + le=2, + ), + history: List[History] = Body( + [], + description="历史对话", + examples=[ + [ + {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", "content": "虎头虎脑"}, + ] + ], + ), + stream: bool = Body(False, description="流式输出"), + model_name: str = Body(None, description="LLM 模型名称。"), + temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), + max_tokens: Optional[int] = Body( + None, description="限制LLM生成Token数量,默认None代表模型最大值" + ), + prompt_name: str = Body( + "rag_default", + description="使用的prompt模板名称(在configs/_prompt_config.py中配置)", + ), +): if knowledge_id not in memo_faiss_pool.keys(): # return BaseResponse(code=404, msg=f"未找到临时知识库 {knowledge_id},请先上传文件") - return BaseResponse(code=404, msg=f"""[冲!]欢迎试用【环评查特助手】\r\n -请先上传环评报告等文件以启用编制行为监督检查、项目风险评测分析等功能!""") + return BaseResponse( + code=404, + msg=f"""[冲!]欢迎试用【环评查特助手】\r\n +请先上传环评报告等文件以启用编制行为监督检查、项目风险评测分析等功能!""", + ) history = [History.from_data(h) for h in history] @@ -133,12 +163,14 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]: callbacks = [callback] # Enable langchain-chatchat to support langfuse import os - langfuse_secret_key = os.environ.get('LANGFUSE_SECRET_KEY') - langfuse_public_key = os.environ.get('LANGFUSE_PUBLIC_KEY') - langfuse_host = os.environ.get('LANGFUSE_HOST') + + langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY") + langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY") + langfuse_host = os.environ.get("LANGFUSE_HOST") if langfuse_secret_key and langfuse_public_key and langfuse_host: from langfuse import Langfuse from langfuse.callback import CallbackHandler + langfuse_handler = CallbackHandler() callbacks.append(langfuse_handler) @@ -151,24 +183,31 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]: embed_func = get_Embeddings() embeddings = await embed_func.aembed_query(query) with memo_faiss_pool.acquire(knowledge_id) as vs: - docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) + docs = vs.similarity_search_with_score_by_vector( + embeddings, k=top_k, score_threshold=score_threshold + ) docs = [x[0] for x in docs] context = "\n".join([doc.page_content for doc in docs]) if len(docs) == 0: # 如果没有找到相关文档,使用Empty模板 - prompt_template = get_prompt_template("llm_model", "rag_default" if prompt_name == "rag_default" else prompt_name) + prompt_template = get_prompt_template( + "llm_model", + "rag_default" if prompt_name == 
"rag_default" else prompt_name, + ) else: prompt_template = get_prompt_template("llm_model", "rag") input_msg = History(role="user", content=prompt_template).to_msg_template(False) chat_prompt = ChatPromptTemplate.from_messages( - [i.to_msg_template() for i in history] + [input_msg]) + [i.to_msg_template() for i in history] + [input_msg] + ) chain = LLMChain(prompt=chat_prompt, llm=model) # Begin a task that runs in the background. - task = asyncio.create_task(wrap_done( - chain.acall({"context": context, "question": query}), - callback.done), + task = asyncio.create_task( + wrap_done( + chain.acall({"context": context, "question": query}), callback.done + ), ) source_documents = [] @@ -178,7 +217,9 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]: source_documents.append(text) if len(source_documents) == 0: # 没有找到相关文档 - source_documents.append(f"""未找到相关文档,该回答为大模型自身能力解答!""") + source_documents.append( + f"""未找到相关文档,该回答为大模型自身能力解答!""" + ) if stream: async for token in callback.aiter(): @@ -189,9 +230,9 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]: answer = "" async for token in callback.aiter(): answer += token - yield json.dumps({"answer": answer, - "docs": source_documents}, - ensure_ascii=False) + yield json.dumps( + {"answer": answer, "docs": source_documents}, ensure_ascii=False + ) await task return EventSourceResponse(knowledge_base_chat_iterator()) diff --git a/libs/chatchat-server/chatchat/server/chat/utils.py b/libs/chatchat-server/chatchat/server/chat/utils.py index 38634cc85..5a5b5d41f 100644 --- a/libs/chatchat-server/chatchat/server/chat/utils.py +++ b/libs/chatchat-server/chatchat/server/chat/utils.py @@ -1,9 +1,10 @@ +import logging from functools import lru_cache -from chatchat.server.pydantic_v2 import BaseModel, Field +from typing import Dict, List, Tuple, Union + from langchain.prompts.chat import ChatMessagePromptTemplate -from typing import List, Tuple, Dict, Union -import logging +from chatchat.server.pydantic_v2 import BaseModel, Field logger = logging.getLogger() @@ -16,11 +17,12 @@ class History(BaseModel): 也可转换为tuple,如 h.to_msy_tuple = ("human", "你好") """ + role: str = Field(...) content: str = Field(...) 
     def to_msg_tuple(self):
-        return "ai" if self.role=="assistant" else "human", self.content
+        return "ai" if self.role == "assistant" else "human", self.content

     def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
         role_maps = {
@@ -28,7 +30,7 @@ def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
             "human": "user",
         }
         role = role_maps.get(self.role, self.role)
-        if is_raw: # 当前默认历史消息都是没有input_variable的文本。
+        if is_raw:  # 当前默认历史消息都是没有input_variable的文本。
             content = "{% raw %}" + self.content + "{% endraw %}"
         else:
             content = self.content
@@ -41,7 +43,7 @@ def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:

     @classmethod
     def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
-        if isinstance(h, (list,tuple)) and len(h) >= 2:
+        if isinstance(h, (list, tuple)) and len(h) >= 2:
             h = cls(role=h[0], content=h[1])
         elif isinstance(h, dict):
             h = cls(**h)
diff --git a/libs/chatchat-server/chatchat/server/db/base.py b/libs/chatchat-server/chatchat/server/db/base.py
index 1b9cba9c4..7790b533e 100644
--- a/libs/chatchat-server/chatchat/server/db/base.py
+++ b/libs/chatchat-server/chatchat/server/db/base.py
@@ -1,10 +1,10 @@
+import json
+
 from sqlalchemy import create_engine
-from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
+from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
 from sqlalchemy.orm import sessionmaker

 from chatchat.configs import SQLALCHEMY_DATABASE_URI
-import json
-

 engine = create_engine(
     SQLALCHEMY_DATABASE_URI,
diff --git a/libs/chatchat-server/chatchat/server/db/models/base.py b/libs/chatchat-server/chatchat/server/db/models/base.py
index 706464f75..17e96a4c0 100644
--- a/libs/chatchat-server/chatchat/server/db/models/base.py
+++ b/libs/chatchat-server/chatchat/server/db/models/base.py
@@ -1,13 +1,17 @@
 from datetime import datetime
-from sqlalchemy import Column, DateTime, String, Integer
+
+from sqlalchemy import Column, DateTime, Integer, String


 class BaseModel:
     """
     基础模型
     """
+
     id = Column(Integer, primary_key=True, index=True, comment="主键ID")
     create_time = Column(DateTime, default=datetime.utcnow, comment="创建时间")
-    update_time = Column(DateTime, default=None, onupdate=datetime.utcnow, comment="更新时间")
+    update_time = Column(
+        DateTime, default=None, onupdate=datetime.utcnow, comment="更新时间"
+    )
     create_by = Column(String, default=None, comment="创建者")
     update_by = Column(String, default=None, comment="更新者")
diff --git a/libs/chatchat-server/chatchat/server/db/models/conversation_model.py b/libs/chatchat-server/chatchat/server/db/models/conversation_model.py
index 5edbabe7e..f1e6749b0 100644
--- a/libs/chatchat-server/chatchat/server/db/models/conversation_model.py
+++ b/libs/chatchat-server/chatchat/server/db/models/conversation_model.py
@@ -1,4 +1,5 @@
-from sqlalchemy import Column, Integer, String, DateTime, JSON, func
+from sqlalchemy import JSON, Column, DateTime, Integer, String, func
+
 from chatchat.server.db.base import Base


@@ -6,11 +7,12 @@ class ConversationModel(Base):
     """
     聊天记录模型
     """
-    __tablename__ = 'conversation'
-    id = Column(String(32), primary_key=True, comment='对话框ID')
-    name = Column(String(50), comment='对话框名称')
-    chat_type = Column(String(50), comment='聊天类型')
-    create_time = Column(DateTime, default=func.now(), comment='创建时间')
+
+    __tablename__ = "conversation"
+    id = Column(String(32), primary_key=True, comment="对话框ID")
+    name = Column(String(50), comment="对话框名称")
+    chat_type = Column(String(50), comment="聊天类型")
+    create_time = Column(DateTime, default=func.now(),
comment="创建时间") def __repr__(self): return f"" diff --git a/libs/chatchat-server/chatchat/server/db/models/knowledge_base_model.py b/libs/chatchat-server/chatchat/server/db/models/knowledge_base_model.py index c22b71543..1ef7f0aa1 100644 --- a/libs/chatchat-server/chatchat/server/db/models/knowledge_base_model.py +++ b/libs/chatchat-server/chatchat/server/db/models/knowledge_base_model.py @@ -1,7 +1,9 @@ -from sqlalchemy import Column, Integer, String, DateTime, func -from pydantic import BaseModel -from typing import Optional from datetime import datetime +from typing import Optional + +from pydantic import BaseModel +from sqlalchemy import Column, DateTime, Integer, String, func + from chatchat.server.db.base import Base @@ -9,18 +11,20 @@ class KnowledgeBaseModel(Base): """ 知识库模型 """ - __tablename__ = 'knowledge_base' - id = Column(Integer, primary_key=True, autoincrement=True, comment='知识库ID') - kb_name = Column(String(50), comment='知识库名称') - kb_info = Column(String(200), comment='知识库简介(用于Agent)') - vs_type = Column(String(50), comment='向量库类型') - embed_model = Column(String(50), comment='嵌入模型名称') - file_count = Column(Integer, default=0, comment='文件数量') - create_time = Column(DateTime, default=func.now(), comment='创建时间') + + __tablename__ = "knowledge_base" + id = Column(Integer, primary_key=True, autoincrement=True, comment="知识库ID") + kb_name = Column(String(50), comment="知识库名称") + kb_info = Column(String(200), comment="知识库简介(用于Agent)") + vs_type = Column(String(50), comment="向量库类型") + embed_model = Column(String(50), comment="嵌入模型名称") + file_count = Column(Integer, default=0, comment="文件数量") + create_time = Column(DateTime, default=func.now(), comment="创建时间") def __repr__(self): return f"" + # 创建一个对应的 Pydantic 模型 class KnowledgeBaseSchema(BaseModel): id: int @@ -32,4 +36,4 @@ class KnowledgeBaseSchema(BaseModel): create_time: Optional[datetime] class Config: - from_attributes = True # 确保可以从 ORM 实例进行验证 \ No newline at end of file + from_attributes = True # 确保可以从 ORM 实例进行验证 diff --git a/libs/chatchat-server/chatchat/server/db/models/knowledge_file_model.py b/libs/chatchat-server/chatchat/server/db/models/knowledge_file_model.py index aa7c4612e..67fdff27e 100644 --- a/libs/chatchat-server/chatchat/server/db/models/knowledge_file_model.py +++ b/libs/chatchat-server/chatchat/server/db/models/knowledge_file_model.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func +from sqlalchemy import JSON, Boolean, Column, DateTime, Float, Integer, String, func from chatchat.server.db.base import Base @@ -7,19 +7,20 @@ class KnowledgeFileModel(Base): """ 知识文件模型 """ - __tablename__ = 'knowledge_file' - id = Column(Integer, primary_key=True, autoincrement=True, comment='知识文件ID') - file_name = Column(String(255), comment='文件名') - file_ext = Column(String(10), comment='文件扩展名') - kb_name = Column(String(50), comment='所属知识库名称') - document_loader_name = Column(String(50), comment='文档加载器名称') - text_splitter_name = Column(String(50), comment='文本分割器名称') - file_version = Column(Integer, default=1, comment='文件版本') + + __tablename__ = "knowledge_file" + id = Column(Integer, primary_key=True, autoincrement=True, comment="知识文件ID") + file_name = Column(String(255), comment="文件名") + file_ext = Column(String(10), comment="文件扩展名") + kb_name = Column(String(50), comment="所属知识库名称") + document_loader_name = Column(String(50), comment="文档加载器名称") + text_splitter_name = Column(String(50), comment="文本分割器名称") + file_version = Column(Integer, default=1, comment="文件版本") 
file_mtime = Column(Float, default=0.0, comment="文件修改时间") file_size = Column(Integer, default=0, comment="文件大小") custom_docs = Column(Boolean, default=False, comment="是否自定义docs") docs_count = Column(Integer, default=0, comment="切分文档数量") - create_time = Column(DateTime, default=func.now(), comment='创建时间') + create_time = Column(DateTime, default=func.now(), comment="创建时间") def __repr__(self): return f"" @@ -29,10 +30,11 @@ class FileDocModel(Base): """ 文件-向量库文档模型 """ - __tablename__ = 'file_doc' - id = Column(Integer, primary_key=True, autoincrement=True, comment='ID') - kb_name = Column(String(50), comment='知识库名称') - file_name = Column(String(255), comment='文件名称') + + __tablename__ = "file_doc" + id = Column(Integer, primary_key=True, autoincrement=True, comment="ID") + kb_name = Column(String(50), comment="知识库名称") + file_name = Column(String(255), comment="文件名称") doc_id = Column(String(50), comment="向量库文档ID") meta_data = Column(JSON, default={}) diff --git a/libs/chatchat-server/chatchat/server/db/models/knowledge_metadata_model.py b/libs/chatchat-server/chatchat/server/db/models/knowledge_metadata_model.py index 0fa7e0441..3f4b982a8 100644 --- a/libs/chatchat-server/chatchat/server/db/models/knowledge_metadata_model.py +++ b/libs/chatchat-server/chatchat/server/db/models/knowledge_metadata_model.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func +from sqlalchemy import JSON, Boolean, Column, DateTime, Float, Integer, String, func from chatchat.server.db.base import Base @@ -15,14 +15,17 @@ class SummaryChunkModel(Base): 语义相似度 """ - __tablename__ = 'summary_chunk' - id = Column(Integer, primary_key=True, autoincrement=True, comment='ID') - kb_name = Column(String(50), comment='知识库名称') - summary_context = Column(String(255), comment='总结文本') - summary_id = Column(String(255), comment='总结矢量id') + + __tablename__ = "summary_chunk" + id = Column(Integer, primary_key=True, autoincrement=True, comment="ID") + kb_name = Column(String(50), comment="知识库名称") + summary_context = Column(String(255), comment="总结文本") + summary_id = Column(String(255), comment="总结矢量id") doc_ids = Column(String(1024), comment="向量库id关联列表") meta_data = Column(JSON, default={}) def __repr__(self): - return (f"") + return ( + f"" + ) diff --git a/libs/chatchat-server/chatchat/server/db/models/message_model.py b/libs/chatchat-server/chatchat/server/db/models/message_model.py index 3382bd29c..c58747c8d 100644 --- a/libs/chatchat-server/chatchat/server/db/models/message_model.py +++ b/libs/chatchat-server/chatchat/server/db/models/message_model.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Integer, String, DateTime, JSON, func +from sqlalchemy import JSON, Column, DateTime, Integer, String, func from chatchat.server.db.base import Base @@ -7,18 +7,19 @@ class MessageModel(Base): """ 聊天记录模型 """ - __tablename__ = 'message' - id = Column(String(32), primary_key=True, comment='聊天记录ID') - conversation_id = Column(String(32), default=None, index=True, comment='对话框ID') - chat_type = Column(String(50), comment='聊天类型') - query = Column(String(4096), comment='用户问题') - response = Column(String(4096), comment='模型回答') + + __tablename__ = "message" + id = Column(String(32), primary_key=True, comment="聊天记录ID") + conversation_id = Column(String(32), default=None, index=True, comment="对话框ID") + chat_type = Column(String(50), comment="聊天类型") + query = Column(String(4096), comment="用户问题") + response = Column(String(4096), comment="模型回答") # 记录知识库id等,以便后续扩展 meta_data = Column(JSON, 
default={}) # 满分100 越高表示评价越好 - feedback_score = Column(Integer, default=-1, comment='用户评分') - feedback_reason = Column(String(255), default="", comment='用户评分理由') - create_time = Column(DateTime, default=func.now(), comment='创建时间') + feedback_score = Column(Integer, default=-1, comment="用户评分") + feedback_reason = Column(String(255), default="", comment="用户评分理由") + create_time = Column(DateTime, default=func.now(), comment="创建时间") def __repr__(self): return f"" diff --git a/libs/chatchat-server/chatchat/server/db/repository/__init__.py b/libs/chatchat-server/chatchat/server/db/repository/__init__.py index 0bb2cc0c0..5ec2ca123 100644 --- a/libs/chatchat-server/chatchat/server/db/repository/__init__.py +++ b/libs/chatchat-server/chatchat/server/db/repository/__init__.py @@ -1,4 +1,4 @@ from .conversation_repository import * -from .message_repository import * from .knowledge_base_repository import * -from .knowledge_file_repository import * \ No newline at end of file +from .knowledge_file_repository import * +from .message_repository import * diff --git a/libs/chatchat-server/chatchat/server/db/repository/conversation_repository.py b/libs/chatchat-server/chatchat/server/db/repository/conversation_repository.py index aede91889..ce40b9502 100644 --- a/libs/chatchat-server/chatchat/server/db/repository/conversation_repository.py +++ b/libs/chatchat-server/chatchat/server/db/repository/conversation_repository.py @@ -1,6 +1,7 @@ -from chatchat.server.db.session import with_session import uuid + from chatchat.server.db.models.conversation_model import ConversationModel +from chatchat.server.db.session import with_session @with_session diff --git a/libs/chatchat-server/chatchat/server/db/repository/knowledge_base_repository.py b/libs/chatchat-server/chatchat/server/db/repository/knowledge_base_repository.py index 785d95b94..bed6cd29b 100644 --- a/libs/chatchat-server/chatchat/server/db/repository/knowledge_base_repository.py +++ b/libs/chatchat-server/chatchat/server/db/repository/knowledge_base_repository.py @@ -1,14 +1,22 @@ -from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseModel -from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema +from chatchat.server.db.models.knowledge_base_model import ( + KnowledgeBaseModel, + KnowledgeBaseSchema, +) from chatchat.server.db.session import with_session @with_session def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model): # 创建知识库实例 - kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() + kb = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.kb_name.ilike(kb_name)) + .first() + ) if not kb: - kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model) + kb = KnowledgeBaseModel( + kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model + ) session.add(kb) else: # update kb with new vs_type and embed_model kb.kb_info = kb_info @@ -19,21 +27,33 @@ def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model): @with_session def list_kbs_from_db(session, min_file_count: int = -1): - kbs = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.file_count > min_file_count).all() + kbs = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.file_count > min_file_count) + .all() + ) kbs = [KnowledgeBaseSchema.model_validate(kb) for kb in kbs] return kbs @with_session def kb_exists(session, kb_name): - kb = 
session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() + kb = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.kb_name.ilike(kb_name)) + .first() + ) status = True if kb else False return status @with_session def load_kb_from_db(session, kb_name): - kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() + kb = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.kb_name.ilike(kb_name)) + .first() + ) if kb: kb_name, vs_type, embed_model = kb.kb_name, kb.vs_type, kb.embed_model else: @@ -43,7 +63,11 @@ def load_kb_from_db(session, kb_name): @with_session def delete_kb_from_db(session, kb_name): - kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() + kb = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.kb_name.ilike(kb_name)) + .first() + ) if kb: session.delete(kb) return True @@ -51,7 +75,11 @@ def delete_kb_from_db(session, kb_name): @with_session def get_kb_detail(session, kb_name: str) -> dict: - kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() + kb: KnowledgeBaseModel = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.kb_name.ilike(kb_name)) + .first() + ) if kb: return { "kb_name": kb.kb_name, diff --git a/libs/chatchat-server/chatchat/server/db/repository/knowledge_file_repository.py b/libs/chatchat-server/chatchat/server/db/repository/knowledge_file_repository.py index 6ca51c403..d22d87b43 100644 --- a/libs/chatchat-server/chatchat/server/db/repository/knowledge_file_repository.py +++ b/libs/chatchat-server/chatchat/server/db/repository/knowledge_file_repository.py @@ -1,33 +1,43 @@ +from typing import Dict, List + from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseModel -from chatchat.server.db.models.knowledge_file_model import KnowledgeFileModel, FileDocModel +from chatchat.server.db.models.knowledge_file_model import ( + FileDocModel, + KnowledgeFileModel, +) from chatchat.server.db.session import with_session from chatchat.server.knowledge_base.utils import KnowledgeFile -from typing import List, Dict @with_session -def list_file_num_docs_id_by_kb_name_and_file_name(session, - kb_name: str, - file_name: str, - ) -> List[int]: - ''' +def list_file_num_docs_id_by_kb_name_and_file_name( + session, + kb_name: str, + file_name: str, +) -> List[int]: + """ 列出某知识库某文件对应的所有Document的id。 返回形式:[str, ...] - ''' - doc_ids = session.query(FileDocModel.doc_id).filter_by(kb_name=kb_name, file_name=file_name).all() + """ + doc_ids = ( + session.query(FileDocModel.doc_id) + .filter_by(kb_name=kb_name, file_name=file_name) + .all() + ) return [int(_id[0]) for _id in doc_ids] @with_session -def list_docs_from_db(session, - kb_name: str, - file_name: str = None, - metadata: Dict = {}, - ) -> List[Dict]: - ''' +def list_docs_from_db( + session, + kb_name: str, + file_name: str = None, + metadata: Dict = {}, +) -> List[Dict]: + """ 列出某知识库某文件对应的所有Document。 返回形式:[{"id": str, "metadata": dict}, ...] 
- ''' + """ docs = session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(kb_name)) if file_name: docs = docs.filter(FileDocModel.file_name.ilike(file_name)) @@ -38,14 +48,15 @@ def list_docs_from_db(session, @with_session -def delete_docs_from_db(session, - kb_name: str, - file_name: str = None, - ) -> List[Dict]: - ''' +def delete_docs_from_db( + session, + kb_name: str, + file_name: str = None, +) -> List[Dict]: + """ 删除某知识库某文件对应的所有Document,并返回被删除的Document。 返回形式:[{"id": str, "metadata": dict}, ...] - ''' + """ docs = list_docs_from_db(kb_name=kb_name, file_name=file_name) query = session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(kb_name)) if file_name: @@ -56,17 +67,16 @@ def delete_docs_from_db(session, @with_session -def add_docs_to_db(session, - kb_name: str, - file_name: str, - doc_infos: List[Dict]): - ''' +def add_docs_to_db(session, kb_name: str, file_name: str, doc_infos: List[Dict]): + """ 将某知识库某文件对应的所有Document信息添加到数据库。 doc_infos形式:[{"id": str, "metadata": dict}, ...] - ''' + """ # ! 这里会出现doc_infos为None的情况,需要进一步排查 if doc_infos is None: - print("输入的server.db.repository.knowledge_file_repository.add_docs_to_db的doc_infos参数为None") + print( + "输入的server.db.repository.knowledge_file_repository.add_docs_to_db的doc_infos参数为None" + ) return False for d in doc_infos: obj = FileDocModel( @@ -81,30 +91,43 @@ def add_docs_to_db(session, @with_session def count_files_from_db(session, kb_name: str) -> int: - return session.query(KnowledgeFileModel).filter(KnowledgeFileModel.kb_name.ilike(kb_name)).count() + return ( + session.query(KnowledgeFileModel) + .filter(KnowledgeFileModel.kb_name.ilike(kb_name)) + .count() + ) @with_session def list_files_from_db(session, kb_name): - files = session.query(KnowledgeFileModel).filter(KnowledgeFileModel.kb_name.ilike(kb_name)).all() + files = ( + session.query(KnowledgeFileModel) + .filter(KnowledgeFileModel.kb_name.ilike(kb_name)) + .all() + ) docs = [f.file_name for f in files] return docs @with_session -def add_file_to_db(session, - kb_file: KnowledgeFile, - docs_count: int = 0, - custom_docs: bool = False, - doc_infos: List[Dict] = [], # 形式:[{"id": str, "metadata": dict}, ...] - ): +def add_file_to_db( + session, + kb_file: KnowledgeFile, + docs_count: int = 0, + custom_docs: bool = False, + doc_infos: List[Dict] = [], # 形式:[{"id": str, "metadata": dict}, ...] 
+): kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first() if kb: # 如果已经存在该文件,则更新文件信息与版本号 - existing_file: KnowledgeFileModel = (session.query(KnowledgeFileModel) - .filter(KnowledgeFileModel.kb_name.ilike(kb_file.kb_name), - KnowledgeFileModel.file_name.ilike(kb_file.filename)) - .first()) + existing_file: KnowledgeFileModel = ( + session.query(KnowledgeFileModel) + .filter( + KnowledgeFileModel.kb_name.ilike(kb_file.kb_name), + KnowledgeFileModel.file_name.ilike(kb_file.filename), + ) + .first() + ) mtime = kb_file.get_mtime() size = kb_file.get_size() @@ -129,22 +152,32 @@ def add_file_to_db(session, ) kb.file_count += 1 session.add(new_file) - add_docs_to_db(kb_name=kb_file.kb_name, file_name=kb_file.filename, doc_infos=doc_infos) + add_docs_to_db( + kb_name=kb_file.kb_name, file_name=kb_file.filename, doc_infos=doc_infos + ) return True @with_session def delete_file_from_db(session, kb_file: KnowledgeFile): - existing_file = (session.query(KnowledgeFileModel) - .filter(KnowledgeFileModel.file_name.ilike(kb_file.filename), - KnowledgeFileModel.kb_name.ilike(kb_file.kb_name)) - .first()) + existing_file = ( + session.query(KnowledgeFileModel) + .filter( + KnowledgeFileModel.file_name.ilike(kb_file.filename), + KnowledgeFileModel.kb_name.ilike(kb_file.kb_name), + ) + .first() + ) if existing_file: session.delete(existing_file) delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename) session.commit() - kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_file.kb_name)).first() + kb = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.kb_name.ilike(kb_file.kb_name)) + .first() + ) if kb: kb.file_count -= 1 session.commit() @@ -153,11 +186,17 @@ def delete_file_from_db(session, kb_file: KnowledgeFile): @with_session def delete_files_from_db(session, knowledge_base_name: str): - session.query(KnowledgeFileModel).filter(KnowledgeFileModel.kb_name.ilike(knowledge_base_name)).delete( - synchronize_session=False) - session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(knowledge_base_name)).delete( - synchronize_session=False) - kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(knowledge_base_name)).first() + session.query(KnowledgeFileModel).filter( + KnowledgeFileModel.kb_name.ilike(knowledge_base_name) + ).delete(synchronize_session=False) + session.query(FileDocModel).filter( + FileDocModel.kb_name.ilike(knowledge_base_name) + ).delete(synchronize_session=False) + kb = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.kb_name.ilike(knowledge_base_name)) + .first() + ) if kb: kb.file_count = 0 @@ -167,19 +206,27 @@ def delete_files_from_db(session, knowledge_base_name: str): @with_session def file_exists_in_db(session, kb_file: KnowledgeFile): - existing_file = (session.query(KnowledgeFileModel) - .filter(KnowledgeFileModel.file_name.ilike(kb_file.filename), - KnowledgeFileModel.kb_name.ilike(kb_file.kb_name)) - .first()) + existing_file = ( + session.query(KnowledgeFileModel) + .filter( + KnowledgeFileModel.file_name.ilike(kb_file.filename), + KnowledgeFileModel.kb_name.ilike(kb_file.kb_name), + ) + .first() + ) return True if existing_file else False @with_session def get_file_detail(session, kb_name: str, filename: str) -> dict: - file: KnowledgeFileModel = (session.query(KnowledgeFileModel) - .filter(KnowledgeFileModel.file_name.ilike(filename), - KnowledgeFileModel.kb_name.ilike(kb_name)) - .first()) + file: KnowledgeFileModel = ( + 
session.query(KnowledgeFileModel) + .filter( + KnowledgeFileModel.file_name.ilike(filename), + KnowledgeFileModel.kb_name.ilike(kb_name), + ) + .first() + ) if file: return { "kb_name": file.kb_name, diff --git a/libs/chatchat-server/chatchat/server/db/repository/knowledge_metadata_repository.py b/libs/chatchat-server/chatchat/server/db/repository/knowledge_metadata_repository.py index b0b74a20f..0463d8f35 100644 --- a/libs/chatchat-server/chatchat/server/db/repository/knowledge_metadata_repository.py +++ b/libs/chatchat-server/chatchat/server/db/repository/knowledge_metadata_repository.py @@ -1,52 +1,59 @@ +from typing import Dict, List + from chatchat.server.db.models.knowledge_metadata_model import SummaryChunkModel from chatchat.server.db.session import with_session -from typing import List, Dict @with_session -def list_summary_from_db(session, - kb_name: str, - metadata: Dict = {}, - ) -> List[Dict]: - ''' +def list_summary_from_db( + session, + kb_name: str, + metadata: Dict = {}, +) -> List[Dict]: + """ 列出某知识库chunk summary。 返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...] - ''' - docs = session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name)) + """ + docs = session.query(SummaryChunkModel).filter( + SummaryChunkModel.kb_name.ilike(kb_name) + ) for k, v in metadata.items(): docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v)) - return [{"id": x.id, - "summary_context": x.summary_context, - "summary_id": x.summary_id, - "doc_ids": x.doc_ids, - "metadata": x.metadata} for x in docs.all()] + return [ + { + "id": x.id, + "summary_context": x.summary_context, + "summary_id": x.summary_id, + "doc_ids": x.doc_ids, + "metadata": x.metadata, + } + for x in docs.all() + ] @with_session -def delete_summary_from_db(session, - kb_name: str - ) -> List[Dict]: - ''' +def delete_summary_from_db(session, kb_name: str) -> List[Dict]: + """ 删除知识库chunk summary,并返回被删除的Dchunk summary。 返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...] - ''' + """ docs = list_summary_from_db(kb_name=kb_name) - query = session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name)) + query = session.query(SummaryChunkModel).filter( + SummaryChunkModel.kb_name.ilike(kb_name) + ) query.delete(synchronize_session=False) session.commit() return docs @with_session -def add_summary_to_db(session, - kb_name: str, - summary_infos: List[Dict]): - ''' +def add_summary_to_db(session, kb_name: str, summary_infos: List[Dict]): + """ 将总结信息添加到数据库。 summary_infos形式:[{"summary_context": str, "doc_ids": str}, ...] 
- ''' + """ for summary in summary_infos: obj = SummaryChunkModel( kb_name=kb_name, @@ -63,4 +70,8 @@ def add_summary_to_db(session, @with_session def count_summary_from_db(session, kb_name: str) -> int: - return session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name)).count() + return ( + session.query(SummaryChunkModel) + .filter(SummaryChunkModel.kb_name.ilike(kb_name)) + .count() + ) diff --git a/libs/chatchat-server/chatchat/server/db/repository/message_repository.py b/libs/chatchat-server/chatchat/server/db/repository/message_repository.py index 1453a23c3..75bdd9d80 100644 --- a/libs/chatchat-server/chatchat/server/db/repository/message_repository.py +++ b/libs/chatchat-server/chatchat/server/db/repository/message_repository.py @@ -1,20 +1,33 @@ -from chatchat.server.db.session import with_session -from typing import Dict, List import uuid +from typing import Dict, List + from chatchat.server.db.models.message_model import MessageModel +from chatchat.server.db.session import with_session @with_session -def add_message_to_db(session, conversation_id: str, chat_type, query, response="", message_id=None, - metadata: Dict = {}): +def add_message_to_db( + session, + conversation_id: str, + chat_type, + query, + response="", + message_id=None, + metadata: Dict = {}, +): """ 新增聊天记录 """ if not message_id: message_id = uuid.uuid4().hex - m = MessageModel(id=message_id, chat_type=chat_type, query=query, response=response, - conversation_id=conversation_id, - meta_data=metadata) + m = MessageModel( + id=message_id, + chat_type=chat_type, + query=query, + response=response, + conversation_id=conversation_id, + meta_data=metadata, + ) session.add(m) session.commit() return m.id @@ -60,11 +73,18 @@ def feedback_message_to_db(session, message_id, feedback_score, feedback_reason) @with_session def filter_message(session, conversation_id: str, limit: int = 10): - messages = (session.query(MessageModel).filter_by(conversation_id=conversation_id). - # 用户最新的query 也会插入到db,忽略这个message record - filter(MessageModel.response != ''). - # 返回最近的limit 条记录 - order_by(MessageModel.create_time.desc()).limit(limit).all()) + messages = ( + session.query(MessageModel) + .filter_by(conversation_id=conversation_id) + . + # 用户最新的query 也会插入到db,忽略这个message record + filter(MessageModel.response != "") + . 
+ # 返回最近的limit 条记录 + order_by(MessageModel.create_time.desc()) + .limit(limit) + .all() + ) # 直接返回 List[MessageModel] 报错 data = [] for m in messages: diff --git a/libs/chatchat-server/chatchat/server/db/session.py b/libs/chatchat-server/chatchat/server/db/session.py index ce8680596..7b8f835da 100644 --- a/libs/chatchat-server/chatchat/server/db/session.py +++ b/libs/chatchat-server/chatchat/server/db/session.py @@ -1,8 +1,10 @@ -from functools import wraps from contextlib import contextmanager -from chatchat.server.db.base import SessionLocal +from functools import wraps + from sqlalchemy.orm import Session +from chatchat.server.db.base import SessionLocal + @contextmanager def session_scope() -> Session: diff --git a/libs/chatchat-server/chatchat/server/file_rag/document_loaders/FilteredCSVloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/FilteredCSVloader.py index 1279495f5..c1cc6b6bd 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/document_loaders/FilteredCSVloader.py +++ b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/FilteredCSVloader.py @@ -1,23 +1,24 @@ ## 指定制定列的csv文件加载器 -from langchain_community.document_loaders import CSVLoader import csv from io import TextIOWrapper from typing import Dict, List, Optional + from langchain.docstore.document import Document +from langchain_community.document_loaders import CSVLoader from langchain_community.document_loaders.helpers import detect_file_encodings class FilteredCSVLoader(CSVLoader): def __init__( - self, - file_path: str, - columns_to_read: List[str], - source_column: Optional[str] = None, - metadata_columns: List[str] = [], - csv_args: Optional[Dict] = None, - encoding: Optional[str] = None, - autodetect_encoding: bool = False, + self, + file_path: str, + columns_to_read: List[str], + source_column: Optional[str] = None, + metadata_columns: List[str] = [], + csv_args: Optional[Dict] = None, + encoding: Optional[str] = None, + autodetect_encoding: bool = False, ): super().__init__( file_path=file_path, @@ -62,10 +63,12 @@ def __read_file(self, csvfile: TextIOWrapper) -> List[Document]: content = [] for col in self.columns_to_read: if col in row: - content.append(f'{col}:{str(row[col])}') + content.append(f"{col}:{str(row[col])}") else: - raise ValueError(f"Column '{self.columns_to_read[0]}' not found in CSV file.") - content = '\n'.join(content) + raise ValueError( + f"Column '{self.columns_to_read[0]}' not found in CSV file." 
+ ) + content = "\n".join(content) # Extract the source if available source = ( row.get(self.source_column, None) diff --git a/libs/chatchat-server/chatchat/server/file_rag/document_loaders/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/__init__.py index 88cfeae81..eb99c8276 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/document_loaders/__init__.py +++ b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/__init__.py @@ -1,4 +1,4 @@ -from .mypdfloader import RapidOCRPDFLoader -from .myimgloader import RapidOCRLoader from .mydocloader import RapidOCRDocLoader +from .myimgloader import RapidOCRLoader +from .mypdfloader import RapidOCRPDFLoader from .mypptloader import RapidOCRPPTLoader diff --git a/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mydocloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mydocloader.py index d10dd49b8..82d71edf3 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mydocloader.py +++ b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mydocloader.py @@ -1,26 +1,30 @@ -from langchain_community.document_loaders.unstructured import UnstructuredFileLoader from typing import List + import tqdm +from langchain_community.document_loaders.unstructured import UnstructuredFileLoader class RapidOCRDocLoader(UnstructuredFileLoader): def _get_elements(self) -> List: def doc2text(filepath): - from docx.table import _Cell, Table + from io import BytesIO + + import numpy as np + from docx import Document, ImagePart from docx.oxml.table import CT_Tbl from docx.oxml.text.paragraph import CT_P + from docx.table import Table, _Cell from docx.text.paragraph import Paragraph - from docx import Document, ImagePart from PIL import Image - from io import BytesIO - import numpy as np from rapidocr_onnxruntime import RapidOCR + ocr = RapidOCR() doc = Document(filepath) resp = "" def iter_block_items(parent): from docx.document import Document + if isinstance(parent, Document): parent_elm = parent.element.body elif isinstance(parent, _Cell): @@ -34,18 +38,21 @@ def iter_block_items(parent): elif isinstance(child, CT_Tbl): yield Table(child, parent) - b_unit = tqdm.tqdm(total=len(doc.paragraphs)+len(doc.tables), - desc="RapidOCRDocLoader block index: 0") + b_unit = tqdm.tqdm( + total=len(doc.paragraphs) + len(doc.tables), + desc="RapidOCRDocLoader block index: 0", + ) for i, block in enumerate(iter_block_items(doc)): - b_unit.set_description( - "RapidOCRDocLoader block index: {}".format(i)) + b_unit.set_description("RapidOCRDocLoader block index: {}".format(i)) b_unit.refresh() if isinstance(block, Paragraph): resp += block.text.strip() + "\n" - images = block._element.xpath('.//pic:pic') # 获取所有图片 + images = block._element.xpath(".//pic:pic") # 获取所有图片 for image in images: - for img_id in image.xpath('.//a:blip/@r:embed'): # 获取图片id - part = doc.part.related_parts[img_id] # 根据图片id获取对应的图片 + for img_id in image.xpath(".//a:blip/@r:embed"): # 获取图片id + part = doc.part.related_parts[ + img_id + ] # 根据图片id获取对应的图片 if isinstance(part, ImagePart): image = Image.open(BytesIO(part._blob)) result, _ = ocr(np.array(image)) @@ -62,10 +69,11 @@ def iter_block_items(parent): text = doc2text(self.file_path) from unstructured.partition.text import partition_text + return partition_text(text=text, **self.unstructured_kwargs) -if __name__ == '__main__': +if __name__ == "__main__": loader = RapidOCRDocLoader(file_path="../tests/samples/ocr_test.docx") docs = loader.load() print(docs) 
diff --git a/libs/chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py index c6fda01e3..f11b6c57c 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py +++ b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py @@ -1,5 +1,7 @@ from typing import List + from langchain_community.document_loaders.unstructured import UnstructuredFileLoader + from chatchat.server.file_rag.document_loaders.ocr import get_ocr @@ -16,6 +18,7 @@ def img2text(filepath): text = img2text(self.file_path) from unstructured.partition.text import partition_text + return partition_text(text=text, **self.unstructured_kwargs) diff --git a/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py index c6a178f8a..aa981a23d 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py +++ b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py @@ -1,49 +1,56 @@ from typing import List -from langchain_community.document_loaders.unstructured import UnstructuredFileLoader + import cv2 -from PIL import Image import numpy as np +import tqdm +from langchain_community.document_loaders.unstructured import UnstructuredFileLoader +from PIL import Image + from chatchat.configs import PDF_OCR_THRESHOLD from chatchat.server.file_rag.document_loaders.ocr import get_ocr -import tqdm class RapidOCRPDFLoader(UnstructuredFileLoader): def _get_elements(self) -> List: def rotate_img(img, angle): - ''' + """ img --image angle --rotation angle return--rotated img - ''' - + """ + h, w = img.shape[:2] - rotate_center = (w/2, h/2) - #获取旋转矩阵 + rotate_center = (w / 2, h / 2) + # 获取旋转矩阵 # 参数1为旋转中心点; # 参数2为旋转角度,正值-逆时针旋转;负值-顺时针旋转 # 参数3为各向同性的比例因子,1.0原图,2.0变成原来的2倍,0.5变成原来的0.5倍 M = cv2.getRotationMatrix2D(rotate_center, angle, 1.0) - #计算图像新边界 + # 计算图像新边界 new_w = int(h * np.abs(M[0, 1]) + w * np.abs(M[0, 0])) new_h = int(h * np.abs(M[0, 0]) + w * np.abs(M[0, 1])) - #调整旋转矩阵以考虑平移 + # 调整旋转矩阵以考虑平移 M[0, 2] += (new_w - w) / 2 M[1, 2] += (new_h - h) / 2 rotated_img = cv2.warpAffine(img, M, (new_w, new_h)) return rotated_img - + def pdf2text(filepath): - import fitz # pyMuPDF里面的fitz包,不要与pip install fitz混淆 + import fitz # pyMuPDF里面的fitz包,不要与pip install fitz混淆 import numpy as np + ocr = get_ocr() doc = fitz.open(filepath) resp = "" - b_unit = tqdm.tqdm(total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0") + b_unit = tqdm.tqdm( + total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0" + ) for i, page in enumerate(doc): - b_unit.set_description("RapidOCRPDFLoader context page index: {}".format(i)) + b_unit.set_description( + "RapidOCRPDFLoader context page index: {}".format(i) + ) b_unit.refresh() text = page.get_text("") resp += text + "\n" @@ -53,19 +60,26 @@ def pdf2text(filepath): if xref := img.get("xref"): bbox = img["bbox"] # 检查图片尺寸是否超过设定的阈值 - if ((bbox[2] - bbox[0]) / (page.rect.width) < PDF_OCR_THRESHOLD[0] - or (bbox[3] - bbox[1]) / (page.rect.height) < PDF_OCR_THRESHOLD[1]): + if (bbox[2] - bbox[0]) / (page.rect.width) < PDF_OCR_THRESHOLD[ + 0 + ] or (bbox[3] - bbox[1]) / ( + page.rect.height + ) < PDF_OCR_THRESHOLD[1]: continue pix = fitz.Pixmap(doc, xref) samples = pix.samples - if int(page.rotation)!=0: #如果Page有旋转角度,则旋转图片 - img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1) - 
tmp_img = Image.fromarray(img_array); - ori_img = cv2.cvtColor(np.array(tmp_img),cv2.COLOR_RGB2BGR) - rot_img = rotate_img(img=ori_img, angle=360-page.rotation) + if int(page.rotation) != 0: # 如果Page有旋转角度,则旋转图片 + img_array = np.frombuffer( + pix.samples, dtype=np.uint8 + ).reshape(pix.height, pix.width, -1) + tmp_img = Image.fromarray(img_array) + ori_img = cv2.cvtColor(np.array(tmp_img), cv2.COLOR_RGB2BGR) + rot_img = rotate_img(img=ori_img, angle=360 - page.rotation) img_array = cv2.cvtColor(rot_img, cv2.COLOR_RGB2BGR) else: - img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1) + img_array = np.frombuffer( + pix.samples, dtype=np.uint8 + ).reshape(pix.height, pix.width, -1) result, _ = ocr(img_array) if result: @@ -78,6 +92,7 @@ def pdf2text(filepath): text = pdf2text(self.file_path) from unstructured.partition.text import partition_text + return partition_text(text=text, **self.unstructured_kwargs) diff --git a/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypptloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypptloader.py index 309ffdcca..7b00df075 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypptloader.py +++ b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypptloader.py @@ -1,16 +1,19 @@ -from langchain_community.document_loaders.unstructured import UnstructuredFileLoader from typing import List + import tqdm +from langchain_community.document_loaders.unstructured import UnstructuredFileLoader class RapidOCRPPTLoader(UnstructuredFileLoader): def _get_elements(self) -> List: def ppt2text(filepath): - from pptx import Presentation - from PIL import Image - import numpy as np from io import BytesIO + + import numpy as np + from PIL import Image + from pptx import Presentation from rapidocr_onnxruntime import RapidOCR + ocr = RapidOCR() prs = Presentation(filepath) resp = "" @@ -34,15 +37,18 @@ def extract_text(shape): for child_shape in shape.shapes: extract_text(child_shape) - b_unit = tqdm.tqdm(total=len(prs.slides), - desc="RapidOCRPPTLoader slide index: 1") + b_unit = tqdm.tqdm( + total=len(prs.slides), desc="RapidOCRPPTLoader slide index: 1" + ) # 遍历所有幻灯片 for slide_number, slide in enumerate(prs.slides, start=1): b_unit.set_description( - "RapidOCRPPTLoader slide index: {}".format(slide_number)) + "RapidOCRPPTLoader slide index: {}".format(slide_number) + ) b_unit.refresh() - sorted_shapes = sorted(slide.shapes, - key=lambda x: (x.top, x.left)) # 从上到下、从左到右遍历 + sorted_shapes = sorted( + slide.shapes, key=lambda x: (x.top, x.left) + ) # 从上到下、从左到右遍历 for shape in sorted_shapes: extract_text(shape) b_unit.update(1) @@ -50,10 +56,11 @@ def extract_text(shape): text = ppt2text(self.file_path) from unstructured.partition.text import partition_text + return partition_text(text=text, **self.unstructured_kwargs) -if __name__ == '__main__': +if __name__ == "__main__": loader = RapidOCRPPTLoader(file_path="../tests/samples/ocr_test.pptx") docs = loader.load() print(docs) diff --git a/libs/chatchat-server/chatchat/server/file_rag/document_loaders/ocr.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/ocr.py index 2b66dd357..49160287f 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/document_loaders/ocr.py +++ b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/ocr.py @@ -1,6 +1,5 @@ from typing import TYPE_CHECKING - if TYPE_CHECKING: try: from rapidocr_paddle import RapidOCR @@ -11,8 +10,12 @@ def get_ocr(use_cuda: bool = 
True) -> "RapidOCR": try: from rapidocr_paddle import RapidOCR - ocr = RapidOCR(det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda) + + ocr = RapidOCR( + det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda + ) except ImportError: from rapidocr_onnxruntime import RapidOCR + ocr = RapidOCR() return ocr diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py index 2cf3617f0..ef78fac77 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py @@ -1,3 +1,3 @@ from chatchat.server.file_rag.retrievers.base import BaseRetrieverService +from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService -from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService \ No newline at end of file diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py index 7e4d06465..6cda59582 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py @@ -1,6 +1,7 @@ -from langchain.vectorstores import VectorStore from abc import ABCMeta, abstractmethod +from langchain.vectorstores import VectorStore + class BaseRetrieverService(metaclass=ABCMeta): def __init__(self, **kwargs): @@ -10,12 +11,11 @@ def __init__(self, **kwargs): def do_init(self, **kwargs): pass - @abstractmethod def from_vectorstore( - vectorstore: VectorStore, - top_k: int, - score_threshold: int or float, + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, ): pass diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py index 5d6b17a60..31a6aaea7 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py @@ -1,37 +1,35 @@ -from chatchat.server.file_rag.retrievers.base import BaseRetrieverService +from langchain.retrievers import EnsembleRetriever from langchain.vectorstores import VectorStore -from langchain_core.retrievers import BaseRetriever from langchain_community.retrievers import BM25Retriever -from langchain.retrievers import EnsembleRetriever +from langchain_core.retrievers import BaseRetriever + +from chatchat.server.file_rag.retrievers.base import BaseRetrieverService class EnsembleRetrieverService(BaseRetrieverService): def do_init( - self, - retriever: BaseRetriever = None, - top_k: int = 5, + self, + retriever: BaseRetriever = None, + top_k: int = 5, ): self.vs = None self.top_k = top_k self.retriever = retriever - @staticmethod def from_vectorstore( - vectorstore: VectorStore, - top_k: int, - score_threshold: int or float, + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, ): faiss_retriever = vectorstore.as_retriever( search_type="similarity_score_threshold", - search_kwargs={ - "score_threshold": score_threshold, - "k": top_k - } + search_kwargs={"score_threshold": score_threshold, "k": top_k}, ) # TODO: 换个不用torch的实现方式 # from cutword.cutword import Cutter import jieba + # cutter = Cutter() docs = list(vectorstore.docstore._dict.values()) bm25_retriever = 
BM25Retriever.from_documents( @@ -45,4 +43,4 @@ def from_vectorstore( return EnsembleRetrieverService(retriever=ensemble_retriever) def get_relevant_documents(self, query: str): - return self.retriever.get_relevant_documents(query)[:self.top_k] + return self.retriever.get_relevant_documents(query)[: self.top_k] diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py index b6d382fa5..ba65353c3 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py @@ -1,33 +1,30 @@ -from chatchat.server.file_rag.retrievers.base import BaseRetrieverService from langchain.vectorstores import VectorStore from langchain_core.retrievers import BaseRetriever +from chatchat.server.file_rag.retrievers.base import BaseRetrieverService + class VectorstoreRetrieverService(BaseRetrieverService): def do_init( - self, - retriever: BaseRetriever = None, - top_k: int = 5, + self, + retriever: BaseRetriever = None, + top_k: int = 5, ): self.vs = None self.top_k = top_k self.retriever = retriever - @staticmethod def from_vectorstore( - vectorstore: VectorStore, - top_k: int, - score_threshold: int or float, + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, ): retriever = vectorstore.as_retriever( search_type="similarity_score_threshold", - search_kwargs={ - "score_threshold": score_threshold, - "k": top_k - } + search_kwargs={"score_threshold": score_threshold, "k": top_k}, ) return VectorstoreRetrieverService(retriever=retriever) def get_relevant_documents(self, query: str): - return self.retriever.get_relevant_documents(query)[:self.top_k] + return self.retriever.get_relevant_documents(query)[: self.top_k] diff --git a/libs/chatchat-server/chatchat/server/file_rag/text_splitter/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/__init__.py index dc0641206..c0e418ab4 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/text_splitter/__init__.py +++ b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/__init__.py @@ -1,4 +1,4 @@ -from .chinese_text_splitter import ChineseTextSplitter from .ali_text_splitter import AliTextSplitter +from .chinese_recursive_text_splitter import ChineseRecursiveTextSplitter +from .chinese_text_splitter import ChineseTextSplitter from .zh_title_enhance import zh_title_enhance -from .chinese_recursive_text_splitter import ChineseRecursiveTextSplitter \ No newline at end of file diff --git a/libs/chatchat-server/chatchat/server/file_rag/text_splitter/ali_text_splitter.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/ali_text_splitter.py index 93846d190..9def31a9b 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/text_splitter/ali_text_splitter.py +++ b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/ali_text_splitter.py @@ -1,7 +1,8 @@ -from langchain.text_splitter import CharacterTextSplitter import re from typing import List +from langchain.text_splitter import CharacterTextSplitter + class AliTextSplitter(CharacterTextSplitter): def __init__(self, pdf: bool = False, **kwargs): @@ -14,7 +15,7 @@ def split_text(self, text: str) -> List[str]: # 考虑到使用了三个模型,可能对于低配置gpu不太友好,因此这里将模型load进cpu计算,有需要的话可以替换device为自己的显卡id if self.pdf: text = re.sub(r"\n{3,}", r"\n", text) - text = re.sub('\s', " ", text) + text = re.sub("\s", " ", text) text = re.sub("\n\n", "", text) try: from modelscope.pipelines 
import pipeline @@ -24,11 +25,11 @@ def split_text(self, text: str) -> List[str]: "Please install modelscope with `pip install modelscope`. " ) - p = pipeline( task="document-segmentation", - model='damo/nlp_bert_document-segmentation_chinese-base', - device="cpu") + model="damo/nlp_bert_document-segmentation_chinese-base", + device="cpu", + ) result = p(documents=text) sent_list = [i for i in result["text"].split("\n\t") if i] return sent_list diff --git a/libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_recursive_text_splitter.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_recursive_text_splitter.py index 70b4b29c2..cedbb3b3e 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_recursive_text_splitter.py +++ b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_recursive_text_splitter.py @@ -1,13 +1,14 @@ +import logging import re -from typing import List, Optional, Any +from typing import Any, List, Optional + from langchain.text_splitter import RecursiveCharacterTextSplitter -import logging logger = logging.getLogger(__name__) def _split_text_with_regex_from_end( - text: str, separator: str, keep_separator: bool + text: str, separator: str, keep_separator: bool ) -> List[str]: # Now that we have the separator, split the text if separator: @@ -27,11 +28,11 @@ def _split_text_with_regex_from_end( class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter): def __init__( - self, - separators: Optional[List[str]] = None, - keep_separator: bool = True, - is_separator_regex: bool = True, - **kwargs: Any, + self, + separators: Optional[List[str]] = None, + keep_separator: bool = True, + is_separator_regex: bool = True, + **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(keep_separator=keep_separator, **kwargs) @@ -41,7 +42,7 @@ def __init__( "。|!|?", "\.\s|\!\s|\?\s", ";|;\s", - ",|,\s" + ",|,\s", ] self._is_separator_regex = is_separator_regex @@ -58,7 +59,7 @@ def _split_text(self, text: str, separators: List[str]) -> List[str]: break if re.search(_separator, text): separator = _s - new_separators = separators[i + 1:] + new_separators = separators[i + 1 :] break _separator = separator if self._is_separator_regex else re.escape(separator) @@ -83,19 +84,20 @@ def _split_text(self, text: str, separators: List[str]) -> List[str]: if _good_splits: merged_text = self._merge_splits(_good_splits, _separator) final_chunks.extend(merged_text) - return [re.sub(r"\n{2,}", "\n", chunk.strip()) for chunk in final_chunks if chunk.strip()!=""] + return [ + re.sub(r"\n{2,}", "\n", chunk.strip()) + for chunk in final_chunks + if chunk.strip() != "" + ] if __name__ == "__main__": text_splitter = ChineseRecursiveTextSplitter( - keep_separator=True, - is_separator_regex=True, - chunk_size=50, - chunk_overlap=0 + keep_separator=True, is_separator_regex=True, chunk_size=50, chunk_overlap=0 ) ls = [ """中国对外贸易形势报告(75页)。前 10 个月,一般贸易进出口 19.5 万亿元,增长 25.1%, 比整体进出口增速高出 2.9 个百分点,占进出口总额的 61.7%,较去年同期提升 1.6 个百分点。其中,一般贸易出口 10.6 万亿元,增长 25.3%,占出口总额的 60.9%,提升 1.5 个百分点;进口8.9万亿元,增长24.9%,占进口总额的62.7%, 提升 1.8 个百分点。加工贸易进出口 6.8 万亿元,增长 11.8%, 占进出口总额的 21.5%,减少 2.0 个百分点。其中,出口增 长 10.4%,占出口总额的 24.3%,减少 2.6 个百分点;进口增 长 14.2%,占进口总额的 18.0%,减少 1.2 个百分点。此外, 以保税物流方式进出口 3.96 万亿元,增长 27.9%。其中,出 口 1.47 万亿元,增长 38.9%;进口 2.49 万亿元,增长 22.2%。前三季度,中国服务贸易继续保持快速增长态势。服务 进出口总额 37834.3 亿元,增长 11.6%;其中服务出口 17820.9 亿元,增长 27.3%;进口 20013.4 亿元,增长 0.5%,进口增 速实现了疫情以来的首次转正。服务出口增幅大于进口 26.8 个百分点,带动服务贸易逆差下降 62.9%至 2192.5 亿元。服 务贸易结构持续优化,知识密集型服务进出口 16917.7 
亿元, 增长 13.3%,占服务进出口总额的比重达到 44.7%,提升 0.7 个百分点。 二、中国对外贸易发展环境分析和展望 全球疫情起伏反复,经济复苏分化加剧,大宗商品价格 上涨、能源紧缺、运力紧张及发达经济体政策调整外溢等风 险交织叠加。同时也要看到,我国经济长期向好的趋势没有 改变,外贸企业韧性和活力不断增强,新业态新模式加快发 展,创新转型步伐提速。产业链供应链面临挑战。美欧等加快出台制造业回迁计 划,加速产业链供应链本土布局,跨国公司调整产业链供应 链,全球双链面临新一轮重构,区域化、近岸化、本土化、 短链化趋势凸显。疫苗供应不足,制造业“缺芯”、物流受限、 运价高企,全球产业链供应链面临压力。 全球通胀持续高位运行。能源价格上涨加大主要经济体 的通胀压力,增加全球经济复苏的不确定性。世界银行今年 10 月发布《大宗商品市场展望》指出,能源价格在 2021 年 大涨逾 80%,并且仍将在 2022 年小幅上涨。IMF 指出,全 球通胀上行风险加剧,通胀前景存在巨大不确定性。""", - ] + ] # text = """""" for inum, text in enumerate(ls): print(inum) diff --git a/libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_text_splitter.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_text_splitter.py index 4107b25f0..9d4e2e286 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_text_splitter.py +++ b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_text_splitter.py @@ -1,7 +1,8 @@ -from langchain.text_splitter import CharacterTextSplitter import re from typing import List +from langchain.text_splitter import CharacterTextSplitter + class ChineseTextSplitter(CharacterTextSplitter): def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs): @@ -12,9 +13,11 @@ def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs): def split_text1(self, text: str) -> List[str]: if self.pdf: text = re.sub(r"\n{3,}", "\n", text) - text = re.sub('\s', ' ', text) + text = re.sub("\s", " ", text) text = text.replace("\n\n", "") - sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))') # del :; + sent_sep_pattern = re.compile( + '([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))' + ) # del :; sent_list = [] for ele in sent_sep_pattern.split(text): if sent_sep_pattern.match(ele) and sent_list: @@ -23,37 +26,52 @@ def split_text1(self, text: str) -> List[str]: sent_list.append(ele) return sent_list - def split_text(self, text: str) -> List[str]: ##此处需要进一步优化逻辑 + def split_text(self, text: str) -> List[str]: ##此处需要进一步优化逻辑 if self.pdf: text = re.sub(r"\n{3,}", r"\n", text) - text = re.sub('\s', " ", text) + text = re.sub("\s", " ", text) text = re.sub("\n\n", "", text) - text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text) # 单字符断句符 + text = re.sub(r"([;;.!?。!?\?])([^”’])", r"\1\n\2", text) # 单字符断句符 text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text) # 英文省略号 text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) # 中文省略号 - text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text) + text = re.sub( + r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r"\1\n\2", text + ) # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号 text = text.rstrip() # 段尾如果有多余的\n就去掉它 # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。 ls = [i for i in text.split("\n") if i] for ele in ls: if len(ele) > self.sentence_size: - ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele) + ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r"\1\n\2", ele) ele1_ls = ele1.split("\n") for ele_ele1 in ele1_ls: if len(ele_ele1) > self.sentence_size: - ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1) + ele_ele2 = re.sub( + r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', + r"\1\n\2", + ele_ele1, + ) ele2_ls = ele_ele2.split("\n") for ele_ele2 in ele2_ls: if len(ele_ele2) > self.sentence_size: - ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2) + ele_ele3 = re.sub( + '( ["’”」』]{0,2})([^ ])', r"\1\n\2", ele_ele2 + ) ele2_id = ele2_ls.index(ele_ele2) - ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") 
if i] + ele2_ls[ - ele2_id + 1:] + ele2_ls = ( + ele2_ls[:ele2_id] + + [i for i in ele_ele3.split("\n") if i] + + ele2_ls[ele2_id + 1 :] + ) ele_id = ele1_ls.index(ele_ele1) - ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:] + ele1_ls = ( + ele1_ls[:ele_id] + + [i for i in ele2_ls if i] + + ele1_ls[ele_id + 1 :] + ) id = ls.index(ele) - ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:] + ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1 :] return ls diff --git a/libs/chatchat-server/chatchat/server/file_rag/text_splitter/zh_title_enhance.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/zh_title_enhance.py index 7f8c54843..793e0ba94 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/text_splitter/zh_title_enhance.py +++ b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/zh_title_enhance.py @@ -1,6 +1,7 @@ -from langchain.docstore.document import Document import re +from langchain.docstore.document import Document + def under_non_alpha_ratio(text: str, threshold: float = 0.5): """Checks if the proportion of non-alpha characters in the text snippet exceeds a given @@ -28,9 +29,9 @@ def under_non_alpha_ratio(text: str, threshold: float = 0.5): def is_possible_title( - text: str, - title_max_word_length: int = 20, - non_alpha_threshold: float = 0.5, + text: str, + title_max_word_length: int = 20, + non_alpha_threshold: float = 0.5, ) -> bool: """Checks to see if the text passes all of the checks for a valid title. @@ -90,7 +91,7 @@ def zh_title_enhance(docs: Document) -> Document: if len(docs) > 0: for doc in docs: if is_possible_title(doc.page_content): - doc.metadata['category'] = 'cn_Title' + doc.metadata["category"] = "cn_Title" title = doc.page_content elif title: doc.page_content = f"下文与({title})有关。{doc.page_content}" diff --git a/libs/chatchat-server/chatchat/server/file_rag/utils.py b/libs/chatchat-server/chatchat/server/file_rag/utils.py index ddf64e3d1..cd767451a 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/utils.py +++ b/libs/chatchat-server/chatchat/server/file_rag/utils.py @@ -1,7 +1,7 @@ from chatchat.server.file_rag.retrievers import ( BaseRetrieverService, - VectorstoreRetrieverService, EnsembleRetrieverService, + VectorstoreRetrieverService, ) Retrivals = { @@ -9,5 +9,6 @@ "ensemble": EnsembleRetrieverService, } + def get_Retriever(type: str = "vectorstore") -> BaseRetrieverService: - return Retrivals[type] \ No newline at end of file + return Retrivals[type] diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_api.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_api.py index 0ccc0704f..36a917fea 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_api.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_api.py @@ -1,22 +1,25 @@ import urllib -from chatchat.server.utils import BaseResponse, ListResponse -from chatchat.server.knowledge_base.utils import validate_kb_name -from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory -from chatchat.server.db.repository.knowledge_base_repository import list_kbs_from_db -from chatchat.configs import DEFAULT_EMBEDDING_MODEL, logger, log_verbose + from fastapi import Body +from chatchat.configs import DEFAULT_EMBEDDING_MODEL, log_verbose, logger +from chatchat.server.db.repository.knowledge_base_repository import list_kbs_from_db +from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory +from chatchat.server.knowledge_base.utils import validate_kb_name +from 
chatchat.server.utils import BaseResponse, ListResponse + def list_kbs(): # Get List of Knowledge Base return ListResponse(data=list_kbs_from_db()) -def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), - vector_store_type: str = Body("faiss"), - kb_info: str = Body("", description="知识库内容简介,用于Agent选择知识库。"), - embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), - ) -> BaseResponse: +def create_kb( + knowledge_base_name: str = Body(..., examples=["samples"]), + vector_store_type: str = Body("faiss"), + kb_info: str = Body("", description="知识库内容简介,用于Agent选择知识库。"), + embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), +) -> BaseResponse: # Create selected knowledge base if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -27,20 +30,23 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), if kb is not None: return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") - kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model, kb_info=kb_info) + kb = KBServiceFactory.get_service( + knowledge_base_name, vector_store_type, embed_model, kb_info=kb_info + ) try: kb.create_kb() except Exception as e: msg = f"创建知识库出错: {e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e if log_verbose else None + ) return BaseResponse(code=500, msg=msg) return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") def delete_kb( - knowledge_base_name: str = Body(..., examples=["samples"]) + knowledge_base_name: str = Body(..., examples=["samples"]), ) -> BaseResponse: # Delete selected knowledge base if not validate_kb_name(knowledge_base_name): @@ -59,8 +65,9 @@ def delete_kb( return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") except Exception as e: msg = f"删除知识库时出现意外: {e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e if log_verbose else None + ) return BaseResponse(code=500, msg=msg) return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}") diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/base.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/base.py index a11e0054d..877c05ed4 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/base.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/base.py @@ -1,15 +1,18 @@ -from langchain.embeddings.base import Embeddings -from langchain.vectorstores.faiss import FAISS import threading -from chatchat.configs import (DEFAULT_EMBEDDING_MODEL, CHUNK_SIZE, - logger, log_verbose) -from contextlib import contextmanager from collections import OrderedDict -from typing import List, Any, Union, Tuple +from contextlib import contextmanager +from typing import Any, List, Tuple, Union + +from langchain.embeddings.base import Embeddings +from langchain.vectorstores.faiss import FAISS + +from chatchat.configs import CHUNK_SIZE, DEFAULT_EMBEDDING_MODEL, log_verbose, logger class ThreadSafeObject: - def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None): + def __init__( + self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None + ): self._obj = obj self._key = key self._pool = pool @@ -62,7 +65,7 @@ def __init__(self, cache_num: int = -1): self._cache_num = cache_num self._cache = OrderedDict() self.atomic = 
threading.RLock() - + def keys(self) -> List[str]: return list(self._cache.keys()) @@ -96,4 +99,3 @@ def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""): return cache.acquire(owner=owner, msg=msg) else: return cache - diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py index eb1a77729..6232a4824 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py @@ -1,12 +1,13 @@ -from chatchat.configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM -from chatchat.server.knowledge_base.kb_cache.base import * -from chatchat.server.utils import get_Embeddings -from chatchat.server.knowledge_base.utils import get_vs_path -from langchain.vectorstores.faiss import FAISS -from langchain.docstore.in_memory import InMemoryDocstore -from langchain.schema import Document import os + +from langchain.docstore.in_memory import InMemoryDocstore from langchain.schema import Document +from langchain.vectorstores.faiss import FAISS + +from chatchat.configs import CACHED_MEMO_VS_NUM, CACHED_VS_NUM +from chatchat.server.knowledge_base.kb_cache.base import * +from chatchat.server.knowledge_base.utils import get_vs_path +from chatchat.server.utils import get_Embeddings # patch FAISS to include doc id in Document.metadata @@ -18,6 +19,8 @@ def _new_ds_search(self, search: str) -> Union[str, Document]: if isinstance(doc, Document): doc.metadata["id"] = search return doc + + InMemoryDocstore.search = _new_ds_search @@ -50,11 +53,10 @@ def clear(self): class _FaissPool(CachePool): def new_vector_store( - self, - kb_name: str, - embed_model: str = DEFAULT_EMBEDDING_MODEL, + self, + kb_name: str, + embed_model: str = DEFAULT_EMBEDDING_MODEL, ) -> FAISS: - # create an empty vector store embeddings = get_Embeddings(embed_model=embed_model) doc = Document(page_content="init", metadata={}) @@ -64,10 +66,9 @@ def new_vector_store( return vector_store def new_temp_vector_store( - self, - embed_model: str = DEFAULT_EMBEDDING_MODEL, + self, + embed_model: str = DEFAULT_EMBEDDING_MODEL, ) -> FAISS: - # create an empty vector store embeddings = get_Embeddings(embed_model=embed_model) doc = Document(page_content="init", metadata={}) @@ -88,11 +89,11 @@ def unload_vector_store(self, kb_name: str): class KBFaissPool(_FaissPool): def load_vector_store( - self, - kb_name: str, - vector_name: str = None, - create: bool = True, - embed_model: str = DEFAULT_EMBEDDING_MODEL, + self, + kb_name: str, + vector_name: str = None, + create: bool = True, + embed_model: str = DEFAULT_EMBEDDING_MODEL, ) -> ThreadSafeFaiss: self.atomic.acquire() locked = True @@ -105,18 +106,26 @@ def load_vector_store( with item.acquire(msg="初始化"): self.atomic.release() locked = False - logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.") + logger.info( + f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk." 
+ ) vs_path = get_vs_path(kb_name, vector_name) if os.path.isfile(os.path.join(vs_path, "index.faiss")): embeddings = get_Embeddings(embed_model=embed_model) - vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True, - allow_dangerous_deserialization=True) + vector_store = FAISS.load_local( + vs_path, + embeddings, + normalize_L2=True, + allow_dangerous_deserialization=True, + ) elif create: # create an empty vector store if not os.path.exists(vs_path): os.makedirs(vs_path) - vector_store = self.new_vector_store(kb_name=kb_name, embed_model=embed_model) + vector_store = self.new_vector_store( + kb_name=kb_name, embed_model=embed_model + ) vector_store.save_local(vs_path) else: raise RuntimeError(f"knowledge base {kb_name} not exist.") @@ -126,7 +135,7 @@ def load_vector_store( self.atomic.release() locked = False except Exception as e: - if locked: # we don't know exception raised before or after atomic.release + if locked: # we don't know exception raised before or after atomic.release self.atomic.release() logger.error(e, exc_info=True) raise RuntimeError(f"向量库 {kb_name} 加载失败。") @@ -137,10 +146,11 @@ class MemoFaissPool(_FaissPool): r""" 临时向量库的缓存池 """ + def load_vector_store( - self, - kb_name: str, - embed_model: str = DEFAULT_EMBEDDING_MODEL, + self, + kb_name: str, + embed_model: str = DEFAULT_EMBEDDING_MODEL, ) -> ThreadSafeFaiss: self.atomic.acquire() cache = self.get(kb_name) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py index e73627a16..45641a242 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py @@ -1,36 +1,61 @@ import json import os import urllib -from typing import List, Dict +from typing import Dict, List -from fastapi import File, Form, Body, Query, UploadFile +from fastapi import Body, File, Form, Query, UploadFile from fastapi.responses import FileResponse from langchain.docstore.document import Document from sse_starlette import EventSourceResponse -from chatchat.configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL, - VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, - CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, - logger, log_verbose, ) +from chatchat.configs import ( + CHUNK_SIZE, + DEFAULT_EMBEDDING_MODEL, + DEFAULT_VS_TYPE, + OVERLAP_SIZE, + SCORE_THRESHOLD, + VECTOR_SEARCH_TOP_K, + ZH_TITLE_ENHANCE, + log_verbose, + logger, +) from chatchat.server.db.repository.knowledge_file_repository import get_file_detail -from chatchat.server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path, - files2docs_in_thread, KnowledgeFile) -from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory, get_kb_file_details +from chatchat.server.knowledge_base.kb_service.base import ( + KBServiceFactory, + get_kb_file_details, +) from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId -from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool, check_embed_model +from chatchat.server.knowledge_base.utils import ( + KnowledgeFile, + files2docs_in_thread, + get_file_path, + list_files_from_folder, + validate_kb_name, +) +from chatchat.server.utils import ( + BaseResponse, + ListResponse, + check_embed_model, + run_in_thread_pool, +) def search_docs( - query: str = Body("", description="用户输入", examples=["你好"]), - knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), 
- top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), - score_threshold: float = Body(SCORE_THRESHOLD, - description="知识库匹配相关度阈值,取值范围在0-1之间," - "SCORE越小,相关度越高," - "取到1相当于不筛选,建议设置在0.5左右", - ge=0.0, le=1.0), - file_name: str = Body("", description="文件名称,支持 sql 通配符"), - metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"), + query: str = Body("", description="用户输入", examples=["你好"]), + knowledge_base_name: str = Body( + ..., description="知识库名称", examples=["samples"] + ), + top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), + score_threshold: float = Body( + SCORE_THRESHOLD, + description="知识库匹配相关度阈值,取值范围在0-1之间," + "SCORE越小,相关度越高," + "取到1相当于不筛选,建议设置在0.5左右", + ge=0.0, + le=1.0, + ), + file_name: str = Body("", description="文件名称,支持 sql 通配符"), + metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"), ) -> List[Dict]: kb = KBServiceFactory.get_service_by_name(knowledge_base_name) data = [] @@ -47,42 +72,45 @@ def search_docs( return [x.dict() for x in data] -def list_files( - knowledge_base_name: str -) -> ListResponse: +def list_files(knowledge_base_name: str) -> ListResponse: if not validate_kb_name(knowledge_base_name): return ListResponse(code=403, msg="Don't attack me", data=[]) knowledge_base_name = urllib.parse.unquote(knowledge_base_name) kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: - return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]) + return ListResponse( + code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[] + ) else: all_docs = get_kb_file_details(knowledge_base_name) return ListResponse(data=all_docs) -def _save_files_in_thread(files: List[UploadFile], - knowledge_base_name: str, - override: bool): +def _save_files_in_thread( + files: List[UploadFile], knowledge_base_name: str, override: bool +): """ 通过多线程将上传的文件保存到对应知识库目录内。 生成器返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}} """ def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict: - ''' + """ 保存单个文件。 - ''' + """ try: filename = file.filename - file_path = get_file_path(knowledge_base_name=knowledge_base_name, doc_name=filename) + file_path = get_file_path( + knowledge_base_name=knowledge_base_name, doc_name=filename + ) data = {"knowledge_base_name": knowledge_base_name, "file_name": filename} file_content = file.file.read() # 读取上传文件的内容 - if (os.path.isfile(file_path) - and not override - and os.path.getsize(file_path) == len(file_content) + if ( + os.path.isfile(file_path) + and not override + and os.path.getsize(file_path) == len(file_content) ): file_status = f"文件 {filename} 已存在。" logger.warn(file_status) @@ -95,11 +123,15 @@ def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dic return dict(code=200, msg=f"成功上传文件 {filename}", data=data) except Exception as e: msg = f"{filename} 文件上传失败,报错信息为: {e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e if log_verbose else None + ) return dict(code=500, msg=msg, data=data) - params = [{"file": file, "knowledge_base_name": knowledge_base_name, "override": override} for file in files] + params = [ + {"file": file, "knowledge_base_name": knowledge_base_name, "override": override} + for file in files + ] for result in run_in_thread_pool(save_file, params=params): yield result @@ -118,15 +150,17 @@ def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dic def upload_docs( - 
files: List[UploadFile] = File(..., description="上传文件,支持多文件"), - knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), - override: bool = Form(False, description="覆盖已有文件"), - to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"), - chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"), - chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), - zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), - docs: str = Form("", description="自定义的docs,需要转为json字符串"), - not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"), + files: List[UploadFile] = File(..., description="上传文件,支持多文件"), + knowledge_base_name: str = Form( + ..., description="知识库名称", examples=["samples"] + ), + override: bool = Form(False, description="覆盖已有文件"), + to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"), + chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"), + chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), + zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), + docs: str = Form("", description="自定义的docs,需要转为json字符串"), + not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: """ API接口:上传文件,并/或向量化 @@ -143,7 +177,9 @@ def upload_docs( file_names = list(docs.keys()) # 先将上传的文件保存到磁盘 - for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override): + for result in _save_files_in_thread( + files, knowledge_base_name=knowledge_base_name, override=override + ): filename = result["data"]["file_name"] if result["code"] != 200: failed_files[filename] = result["msg"] @@ -167,14 +203,16 @@ def upload_docs( if not not_refresh_vs_cache: kb.save_vector_store() - return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files}) + return BaseResponse( + code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files} + ) def delete_docs( - knowledge_base_name: str = Body(..., examples=["samples"]), - file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]), - delete_content: bool = Body(False), - not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), + knowledge_base_name: str = Body(..., examples=["samples"]), + file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]), + delete_content: bool = Body(False), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -190,24 +228,30 @@ def delete_docs( failed_files[file_name] = f"未找到文件 {file_name}" try: - kb_file = KnowledgeFile(filename=file_name, - knowledge_base_name=knowledge_base_name) + kb_file = KnowledgeFile( + filename=file_name, knowledge_base_name=knowledge_base_name + ) kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True) except Exception as e: msg = f"{file_name} 文件删除失败,错误信息:{e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e if log_verbose else None + ) failed_files[file_name] = msg if not not_refresh_vs_cache: kb.save_vector_store() - return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files}) + return BaseResponse( + code=200, msg=f"文件删除完成", data={"failed_files": failed_files} + ) def update_info( - knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), - kb_info: str = 
Body(..., description="知识库介绍", examples=["这是一个知识库"]), + knowledge_base_name: str = Body( + ..., description="知识库名称", examples=["samples"] + ), + kb_info: str = Body(..., description="知识库介绍", examples=["这是一个知识库"]), ): if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -221,14 +265,18 @@ def update_info( def update_docs( - knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), - file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]), - chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), - chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), - zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), - override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"), - docs: str = Body("", description="自定义的docs,需要转为json字符串"), - not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), + knowledge_base_name: str = Body( + ..., description="知识库名称", examples=["samples"] + ), + file_names: List[str] = Body( + ..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]] + ), + chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), + chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), + zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), + override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"), + docs: str = Body("", description="自定义的docs,需要转为json字符串"), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: """ 更新知识库文档 @@ -252,23 +300,32 @@ def update_docs( continue if file_name not in docs: try: - kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name)) + kb_files.append( + KnowledgeFile( + filename=file_name, knowledge_base_name=knowledge_base_name + ) + ) except Exception as e: msg = f"加载文档 {file_name} 时出错:{e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", + exc_info=e if log_verbose else None, + ) failed_files[file_name] = msg # 从文件生成docs,并进行向量化。 # 这里利用了KnowledgeFile的缓存功能,在多线程中加载Document,然后传给KnowledgeFile - for status, result in files2docs_in_thread(kb_files, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - zh_title_enhance=zh_title_enhance): + for status, result in files2docs_in_thread( + kb_files, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance, + ): if status: kb_name, file_name, new_docs = result - kb_file = KnowledgeFile(filename=file_name, - knowledge_base_name=knowledge_base_name) + kb_file = KnowledgeFile( + filename=file_name, knowledge_base_name=knowledge_base_name + ) kb_file.splited_docs = new_docs kb.update_doc(kb_file, not_refresh_vs_cache=True) else: @@ -279,24 +336,31 @@ def update_docs( for file_name, v in docs.items(): try: v = [x if isinstance(x, Document) else Document(**x) for x in v] - kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name) + kb_file = KnowledgeFile( + filename=file_name, knowledge_base_name=knowledge_base_name + ) kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True) except Exception as e: msg = f"为 {file_name} 添加自定义docs时出错:{e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e if log_verbose else None + ) failed_files[file_name] = msg if not 
not_refresh_vs_cache: kb.save_vector_store() - return BaseResponse(code=200, msg=f"更新文档完成", data={"failed_files": failed_files}) + return BaseResponse( + code=200, msg=f"更新文档完成", data={"failed_files": failed_files} + ) def download_doc( - knowledge_base_name: str = Query(..., description="知识库名称", examples=["samples"]), - file_name: str = Query(..., description="文件名称", examples=["test.txt"]), - preview: bool = Query(False, description="是:浏览器内预览;否:下载"), + knowledge_base_name: str = Query( + ..., description="知识库名称", examples=["samples"] + ), + file_name: str = Query(..., description="文件名称", examples=["test.txt"]), + preview: bool = Query(False, description="是:浏览器内预览;否:下载"), ): """ 下载知识库文档 @@ -314,8 +378,9 @@ def download_doc( content_disposition_type = None try: - kb_file = KnowledgeFile(filename=file_name, - knowledge_base_name=knowledge_base_name) + kb_file = KnowledgeFile( + filename=file_name, knowledge_base_name=knowledge_base_name + ) if os.path.exists(kb_file.filepath): return FileResponse( @@ -326,22 +391,23 @@ def download_doc( ) except Exception as e: msg = f"{kb_file.filename} 读取文件失败,错误信息是:{e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e if log_verbose else None + ) return BaseResponse(code=500, msg=msg) return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败") def recreate_vector_store( - knowledge_base_name: str = Body(..., examples=["samples"]), - allow_empty_kb: bool = Body(True), - vs_type: str = Body(DEFAULT_VS_TYPE), - embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), - chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), - chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), - zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), - not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), + knowledge_base_name: str = Body(..., examples=["samples"]), + allow_empty_kb: bool = Body(True), + vs_type: str = Body(DEFAULT_VS_TYPE), + embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), + chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), + chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), + zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), ): """ recreate vector store from the content. @@ -355,7 +421,9 @@ def output(): if not kb.exists() and not allow_empty_kb: yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"} else: - error_msg = f"could not recreate vector store because failed to access embed model." + error_msg = ( + f"could not recreate vector store because failed to access embed model." 
+ ) if not kb.check_embed_model(error_msg): yield {"code": 404, "msg": error_msg} else: @@ -365,30 +433,39 @@ def output(): files = list_files_from_folder(knowledge_base_name) kb_files = [(file, knowledge_base_name) for file in files] i = 0 - for status, result in files2docs_in_thread(kb_files, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - zh_title_enhance=zh_title_enhance): + for status, result in files2docs_in_thread( + kb_files, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance, + ): if status: kb_name, file_name, docs = result - kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name) + kb_file = KnowledgeFile( + filename=file_name, knowledge_base_name=kb_name + ) kb_file.splited_docs = docs - yield json.dumps({ - "code": 200, - "msg": f"({i + 1} / {len(files)}): {file_name}", - "total": len(files), - "finished": i + 1, - "doc": file_name, - }, ensure_ascii=False) + yield json.dumps( + { + "code": 200, + "msg": f"({i + 1} / {len(files)}): {file_name}", + "total": len(files), + "finished": i + 1, + "doc": file_name, + }, + ensure_ascii=False, + ) kb.add_doc(kb_file, not_refresh_vs_cache=True) else: kb_name, file_name, error = result msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。" logger.error(msg) - yield json.dumps({ - "code": 500, - "msg": msg, - }) + yield json.dumps( + { + "code": 500, + "msg": msg, + } + ) i += 1 if not not_refresh_vs_cache: kb.save_vector_store() diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py index d5d905942..093dc1bb4 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py @@ -1,51 +1,71 @@ -from abc import ABC, abstractmethod - import operator import os +from abc import ABC, abstractmethod from pathlib import Path -from langchain.docstore.document import Document +from typing import Dict, List, Optional, Tuple, Union -from typing import List, Union, Dict, Optional, Tuple +from langchain.docstore.document import Document -from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, - DEFAULT_EMBEDDING_MODEL, KB_INFO, logger) +from chatchat.configs import ( + DEFAULT_EMBEDDING_MODEL, + KB_INFO, + SCORE_THRESHOLD, + VECTOR_SEARCH_TOP_K, + kbs_config, + logger, +) from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema from chatchat.server.db.repository.knowledge_base_repository import ( - add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, - load_kb_from_db, get_kb_detail, + add_kb_to_db, + delete_kb_from_db, + get_kb_detail, + kb_exists, + list_kbs_from_db, + load_kb_from_db, ) from chatchat.server.db.repository.knowledge_file_repository import ( - add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db, - count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db, + add_file_to_db, + count_files_from_db, + delete_file_from_db, + delete_files_from_db, + file_exists_in_db, + get_file_detail, list_docs_from_db, + list_files_from_db, ) +from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId from chatchat.server.knowledge_base.utils import ( - get_kb_path, get_doc_path, KnowledgeFile, - list_kbs_from_folder, list_files_from_folder, + KnowledgeFile, + get_doc_path, + get_kb_path, + list_files_from_folder, + list_kbs_from_folder, ) -from 
chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId from chatchat.server.utils import check_embed_model as _check_embed_model + class SupportedVSType: - FAISS = 'faiss' - MILVUS = 'milvus' - DEFAULT = 'default' - ZILLIZ = 'zilliz' - PG = 'pg' - RELYT = 'relyt' - ES = 'es' - CHROMADB = 'chromadb' + FAISS = "faiss" + MILVUS = "milvus" + DEFAULT = "default" + ZILLIZ = "zilliz" + PG = "pg" + RELYT = "relyt" + ES = "es" + CHROMADB = "chromadb" class KBService(ABC): - - def __init__(self, - knowledge_base_name: str, - kb_info: str = None, - embed_model: str = DEFAULT_EMBEDDING_MODEL, - ): + def __init__( + self, + knowledge_base_name: str, + kb_info: str = None, + embed_model: str = DEFAULT_EMBEDDING_MODEL, + ): self.kb_name = knowledge_base_name - self.kb_info = kb_info or KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库") + self.kb_info = kb_info or KB_INFO.get( + knowledge_base_name, f"关于{knowledge_base_name}的知识库" + ) self.embed_model = embed_model self.kb_path = get_kb_path(self.kb_name) self.doc_path = get_doc_path(self.kb_name) @@ -55,9 +75,9 @@ def __repr__(self) -> str: return f"{self.kb_name} @ {self.embed_model}" def save_vector_store(self): - ''' + """ 保存向量库:FAISS保存到磁盘,milvus保存到数据库。PGVector暂未支持 - ''' + """ pass def check_embed_model(self, error_msg: str) -> bool: @@ -74,7 +94,9 @@ def create_kb(self): if not os.path.exists(self.doc_path): os.makedirs(self.doc_path) - status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model) + status = add_kb_to_db( + self.kb_name, self.kb_info, self.vs_type(), self.embed_model + ) if status: self.do_create_kb() @@ -101,7 +123,9 @@ def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs): 向知识库添加文件 如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True """ - if not self.check_embed_model(f"could not add docs because failed to access embed model."): + if not self.check_embed_model( + f"could not add docs because failed to access embed model." + ): return False if docs: @@ -120,18 +144,24 @@ def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs): rel_path = Path(source).relative_to(self.doc_path) doc.metadata["source"] = str(rel_path.as_posix().strip("/")) except Exception as e: - print(f"cannot convert absolute path ({source}) to relative path. error is : {e}") + print( + f"cannot convert absolute path ({source}) to relative path. 
error is : {e}" + ) self.delete_doc(kb_file) doc_infos = self.do_add_doc(docs, **kwargs) - status = add_file_to_db(kb_file, - custom_docs=custom_docs, - docs_count=len(docs), - doc_infos=doc_infos) + status = add_file_to_db( + kb_file, + custom_docs=custom_docs, + docs_count=len(docs), + doc_infos=doc_infos, + ) else: status = False return status - def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs): + def delete_doc( + self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs + ): """ 从知识库删除文件 """ @@ -146,7 +176,9 @@ def update_info(self, kb_info: str): 更新知识库介绍 """ self.kb_info = kb_info - status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model) + status = add_kb_to_db( + self.kb_name, self.kb_info, self.vs_type(), self.embed_model + ) return status def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs): @@ -154,7 +186,9 @@ def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs 使用content中的文件更新向量库 如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True """ - if not self.check_embed_model(f"could not update docs because failed to access embed model."): + if not self.check_embed_model( + f"could not update docs because failed to access embed model." + ): return False if os.path.exists(kb_file.filepath): @@ -162,8 +196,9 @@ def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs return self.add_doc(kb_file, docs=docs, **kwargs) def exist_doc(self, file_name: str): - return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name, - filename=file_name)) + return file_exists_in_db( + KnowledgeFile(knowledge_base_name=self.kb_name, filename=file_name) + ) def list_files(self): return list_files_from_db(self.kb_name) @@ -171,12 +206,15 @@ def list_files(self): def count_files(self): return count_files_from_db(self.kb_name) - def search_docs(self, - query: str, - top_k: int = VECTOR_SEARCH_TOP_K, - score_threshold: float = SCORE_THRESHOLD, - ) ->List[Document]: - if not self.check_embed_model(f"could not search docs because failed to access embed model."): + def search_docs( + self, + query: str, + top_k: int = VECTOR_SEARCH_TOP_K, + score_threshold: float = SCORE_THRESHOLD, + ) -> List[Document]: + if not self.check_embed_model( + f"could not search docs because failed to access embed model." + ): return [] docs = self.do_search(query, top_k, score_threshold) return docs @@ -188,11 +226,13 @@ def del_doc_by_ids(self, ids: List[str]) -> bool: raise NotImplementedError def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool: - ''' + """ 传入参数为: {doc_id: Document, ...} 如果对应 doc_id 的值为 None,或其 page_content 为空,则删除该文档 - ''' - if not self.check_embed_model(f"could not update docs because failed to access embed model."): + """ + if not self.check_embed_model( + f"could not update docs because failed to access embed model." 
+ ): return False self.del_doc_by_ids(list(docs.keys())) @@ -206,11 +246,15 @@ def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool: self.do_add_doc(docs=pending_docs, ids=ids) return True - def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[DocumentWithVSId]: - ''' + def list_docs( + self, file_name: str = None, metadata: Dict = {} + ) -> List[DocumentWithVSId]: + """ 通过file_name或metadata检索Document - ''' - doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata) + """ + doc_infos = list_docs_from_db( + kb_name=self.kb_name, file_name=file_name, metadata=metadata + ) docs = [] for x in doc_infos: doc_info = self.get_doc_by_ids([x["id"]])[0] @@ -224,19 +268,21 @@ def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document pass return docs - def get_relative_source_path(self,filepath: str): - ''' - 将文件路径转化为相对路径,保证查询时一致 - ''' - relative_path = filepath - if os.path.isabs(relative_path): - try: - relative_path = Path(filepath).relative_to(self.doc_path) - except Exception as e: - print(f"cannot convert absolute path ({relative_path}) to relative path. error is : {e}") - - relative_path = str(relative_path.as_posix().strip("/")) - return relative_path + def get_relative_source_path(self, filepath: str): + """ + 将文件路径转化为相对路径,保证查询时一致 + """ + relative_path = filepath + if os.path.isabs(relative_path): + try: + relative_path = Path(filepath).relative_to(self.doc_path) + except Exception as e: + print( + f"cannot convert absolute path ({relative_path}) to relative path. error is : {e}" + ) + + relative_path = str(relative_path.as_posix().strip("/")) + return relative_path @abstractmethod def do_create_kb(self): @@ -273,29 +319,30 @@ def do_drop_kb(self): pass @abstractmethod - def do_search(self, - query: str, - top_k: int, - score_threshold: float, - ) -> List[Tuple[Document, float]]: + def do_search( + self, + query: str, + top_k: int, + score_threshold: float, + ) -> List[Tuple[Document, float]]: """ 搜索知识库子类实自己逻辑 """ pass @abstractmethod - def do_add_doc(self, - docs: List[Document], - **kwargs, - ) -> List[Dict]: + def do_add_doc( + self, + docs: List[Document], + **kwargs, + ) -> List[Dict]: """ 向知识库添加文档子类实自己逻辑 """ pass @abstractmethod - def do_delete_doc(self, - kb_file: KnowledgeFile): + def do_delete_doc(self, kb_file: KnowledgeFile): """ 从知识库删除文档子类实自己逻辑 """ @@ -310,42 +357,77 @@ def do_clear_vs(self): class KBServiceFactory: - @staticmethod - def get_service(kb_name: str, - vector_store_type: Union[str, SupportedVSType], - embed_model: str = DEFAULT_EMBEDDING_MODEL, - kb_info: str = None, - ) -> KBService: + def get_service( + kb_name: str, + vector_store_type: Union[str, SupportedVSType], + embed_model: str = DEFAULT_EMBEDDING_MODEL, + kb_info: str = None, + ) -> KBService: if isinstance(vector_store_type, str): vector_store_type = getattr(SupportedVSType, vector_store_type.upper()) - params = {"knowledge_base_name": kb_name, "embed_model": embed_model, "kb_info": kb_info} + params = { + "knowledge_base_name": kb_name, + "embed_model": embed_model, + "kb_info": kb_info, + } if SupportedVSType.FAISS == vector_store_type: - from chatchat.server.knowledge_base.kb_service.faiss_kb_service import FaissKBService + from chatchat.server.knowledge_base.kb_service.faiss_kb_service import ( + FaissKBService, + ) + return FaissKBService(**params) elif SupportedVSType.PG == vector_store_type: - from chatchat.server.knowledge_base.kb_service.pg_kb_service import PGKBService + from 
chatchat.server.knowledge_base.kb_service.pg_kb_service import ( + PGKBService, + ) + return PGKBService(**params) elif SupportedVSType.RELYT == vector_store_type: - from chatchat.server.knowledge_base.kb_service.relyt_kb_service import RelytKBService + from chatchat.server.knowledge_base.kb_service.relyt_kb_service import ( + RelytKBService, + ) + return RelytKBService(**params) elif SupportedVSType.MILVUS == vector_store_type: - from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService + from chatchat.server.knowledge_base.kb_service.milvus_kb_service import ( + MilvusKBService, + ) + return MilvusKBService(**params) elif SupportedVSType.ZILLIZ == vector_store_type: - from chatchat.server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService + from chatchat.server.knowledge_base.kb_service.zilliz_kb_service import ( + ZillizKBService, + ) + return ZillizKBService(**params) elif SupportedVSType.DEFAULT == vector_store_type: - from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService - return MilvusKBService(**params) # other milvus parameters are set in model_config.kbs_config + from chatchat.server.knowledge_base.kb_service.milvus_kb_service import ( + MilvusKBService, + ) + + return MilvusKBService( + **params + ) # other milvus parameters are set in model_config.kbs_config elif SupportedVSType.ES == vector_store_type: - from chatchat.server.knowledge_base.kb_service.es_kb_service import ESKBService + from chatchat.server.knowledge_base.kb_service.es_kb_service import ( + ESKBService, + ) + return ESKBService(**params) elif SupportedVSType.CHROMADB == vector_store_type: - from chatchat.server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService + from chatchat.server.knowledge_base.kb_service.chromadb_kb_service import ( + ChromaKBService, + ) + return ChromaKBService(**params) - elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier. - from chatchat.server.knowledge_base.kb_service.default_kb_service import DefaultKBService + elif ( + SupportedVSType.DEFAULT == vector_store_type + ): # kb_exists of default kbservice is False, to make validation easier. 
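+            # NOTE: this branch is unreachable; SupportedVSType.DEFAULT is already matched above and returns MilvusKBService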
+ from chatchat.server.knowledge_base.kb_service.default_kb_service import ( + DefaultKBService, + ) + return DefaultKBService(kb_name) @staticmethod @@ -362,7 +444,7 @@ def get_default(): def get_kb_details() -> List[Dict]: kbs_in_folder = list_kbs_from_folder() - kbs_in_db:List[KnowledgeBaseSchema] = KBService.list_kbs() + kbs_in_db: List[KnowledgeBaseSchema] = KBService.list_kbs() result = {} for kb in kbs_in_folder: @@ -378,19 +460,18 @@ def get_kb_details() -> List[Dict]: } for kb_detail in kbs_in_db: - kb_detail=kb_detail.model_dump() - kb_name=kb_detail["kb_name"] + kb_detail = kb_detail.model_dump() + kb_name = kb_detail["kb_name"] kb_detail["in_db"] = True if kb_name in result: result[kb_name].update(kb_detail) else: kb_detail["in_folder"] = False result[kb_name] = kb_detail - data = [] for i, v in enumerate(result.values()): - v['No'] = i + 1 + v["No"] = i + 1 data.append(v) return data @@ -431,7 +512,7 @@ def get_kb_file_details(kb_name: str) -> List[Dict]: data = [] for i, v in enumerate(result.values()): - v['No'] = i + 1 + v["No"] = i + 1 data.append(v) return data @@ -439,9 +520,7 @@ def get_kb_file_details(kb_name: str) -> List[Dict]: def score_threshold_process(score_threshold, k, docs): if score_threshold is not None: - cmp = ( - operator.le - ) + cmp = operator.le docs = [ (doc, similarity) for doc, similarity in docs diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py index 48801c8b3..c9e28c3d2 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py @@ -2,25 +2,31 @@ from typing import Any, Dict, List, Tuple import chromadb -from chromadb.api.types import (GetResult, QueryResult) +from chromadb.api.types import GetResult, QueryResult from langchain.docstore.document import Document from chatchat.configs import SCORE_THRESHOLD +from chatchat.server.file_rag.utils import get_Retriever from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path from chatchat.server.utils import get_Embeddings -from chatchat.server.file_rag.utils import get_Retriever def _get_result_to_documents(get_result: GetResult) -> List[Document]: - if not get_result['documents']: + if not get_result["documents"]: return [] - _metadatas = get_result['metadatas'] if get_result['metadatas'] else [{}] * len(get_result['documents']) + _metadatas = ( + get_result["metadatas"] + if get_result["metadatas"] + else [{}] * len(get_result["documents"]) + ) document_list = [] - for page_content, metadata in zip(get_result['documents'], _metadatas): - document_list.append(Document(**{'page_content': page_content, 'metadata': metadata})) + for page_content, metadata in zip(get_result["documents"], _metadatas): + document_list.append( + Document(**{"page_content": page_content, "metadata": metadata}) + ) return document_list @@ -74,13 +80,14 @@ def do_drop_kb(self): if not str(e) == f"Collection {self.kb_name} does not exist.": raise e - def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[ - Tuple[Document, float]]: + def do_search( + self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD + ) -> List[Tuple[Document, float]]: retriever = 
get_Retriever("vectorstore").from_vectorstore( - self.collection, - top_k=top_k, - score_threshold=score_threshold, - ) + self.collection, + top_k=top_k, + score_threshold=score_threshold, + ) docs = retriever.get_relevant_documents(query) return docs @@ -92,7 +99,9 @@ def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: embeddings = embed_func.embed_documents(texts=texts) ids = [str(uuid.uuid1()) for _ in range(len(texts))] for _id, text, embedding, metadata in zip(ids, texts, embeddings, metadatas): - self.collection.add(ids=_id, embeddings=embedding, metadatas=metadata, documents=text) + self.collection.add( + ids=_id, embeddings=embedding, metadatas=metadata, documents=text + ) doc_infos.append({"id": _id, "metadata": metadata}) return doc_infos diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py index 9a11baa44..02a5dfd32 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py @@ -1,36 +1,41 @@ -from typing import List +import logging import os import shutil +from typing import List + +from elasticsearch import BadRequestError, Elasticsearch from langchain.schema import Document -from langchain_community.vectorstores.elasticsearch import ElasticsearchStore, ApproxRetrievalStrategy +from langchain_community.vectorstores.elasticsearch import ( + ApproxRetrievalStrategy, + ElasticsearchStore, +) + +from chatchat.configs import KB_ROOT_PATH, kbs_config +from chatchat.server.file_rag.utils import get_Retriever from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.utils import get_Embeddings -from elasticsearch import Elasticsearch, BadRequestError -from chatchat.configs import kbs_config, KB_ROOT_PATH -from chatchat.server.file_rag.utils import get_Retriever - -import logging logger = logging.getLogger() class ESKBService(KBService): - def do_init(self): self.kb_path = self.get_kb_path(self.kb_name) self.index_name = os.path.split(self.kb_path)[-1] - self.IP = kbs_config[self.vs_type()]['host'] - self.PORT = kbs_config[self.vs_type()]['port'] - self.user = kbs_config[self.vs_type()].get("user",'') - self.password = kbs_config[self.vs_type()].get("password",'') - self.dims_length = kbs_config[self.vs_type()].get("dims_length",None) + self.IP = kbs_config[self.vs_type()]["host"] + self.PORT = kbs_config[self.vs_type()]["port"] + self.user = kbs_config[self.vs_type()].get("user", "") + self.password = kbs_config[self.vs_type()].get("password", "") + self.dims_length = kbs_config[self.vs_type()].get("dims_length", None) self.embeddings_model = get_Embeddings(self.embed_model) try: # ES python客户端连接(仅连接) if self.user != "" and self.password != "": - self.es_client_python = Elasticsearch(f"http://{self.IP}:{self.PORT}", - basic_auth=(self.user,self.password)) + self.es_client_python = Elasticsearch( + f"http://{self.IP}:{self.PORT}", + basic_auth=(self.user, self.password), + ) else: logger.warning("ES未配置用户名和密码") self.es_client_python = Elasticsearch(f"http://{self.IP}:{self.PORT}") @@ -47,11 +52,13 @@ def do_init(self): "dense_vector": { "type": "dense_vector", "dims": self.dims_length, - "index": True + "index": True, } } } - self.es_client_python.indices.create(index=self.index_name, mappings=mappings) + 
self.es_client_python.indices.create( + index=self.index_name, mappings=mappings + ) except BadRequestError as e: logger.error("创建索引失败,重新") logger.error(e) @@ -67,13 +74,10 @@ def do_init(self): strategy=ApproxRetrievalStrategy(), es_params={ "timeout": 60, - } + }, ) if self.user != "" and self.password != "": - params.update( - es_user=self.user, - es_password=self.password - ) + params.update(es_user=self.user, es_password=self.password) self.db = ElasticsearchStore(**params) except ConnectionError: logger.error("### 初始化 Elasticsearch 失败!") @@ -84,9 +88,8 @@ def do_init(self): try: # 尝试通过db_init创建索引 self.db._create_index_if_not_exists( - index_name=self.index_name, - dims_length=self.dims_length - ) + index_name=self.index_name, dims_length=self.dims_length + ) except Exception as e: logger.error("创建索引失败...") logger.error(e) @@ -98,7 +101,9 @@ def get_kb_path(knowledge_base_name: str): @staticmethod def get_vs_path(knowledge_base_name: str): - return os.path.join(ESKBService.get_kb_path(knowledge_base_name), "vector_store") + return os.path.join( + ESKBService.get_kb_path(knowledge_base_name), "vector_store" + ) def do_create_kb(self): ... @@ -106,7 +111,7 @@ def do_create_kb(self): def vs_type(self) -> str: return SupportedVSType.ES - def do_search(self, query:str, top_k: int, score_threshold: float): + def do_search(self, query: str, top_k: int, score_threshold: float): # 文本相似性检索 retriever = get_Retriever("vectorstore").from_vectorstore( self.db, @@ -133,9 +138,9 @@ def get_doc_by_ids(self, ids: List[str]) -> List[Document]: def del_doc_by_ids(self, ids: List[str]) -> bool: for doc_id in ids: try: - self.es_client_python.delete(index=self.index_name, - id=doc_id, - refresh=True) + self.es_client_python.delete( + index=self.index_name, id=doc_id, refresh=True + ) except Exception as e: logger.error(f"ES Docs Delete Error! {e}") @@ -145,65 +150,65 @@ def do_delete_doc(self, kb_file, **kwargs): query = { "query": { "term": { - "metadata.source.keyword": self.get_relative_source_path(kb_file.filepath) + "metadata.source.keyword": self.get_relative_source_path( + kb_file.filepath + ) } } } # 注意设置size,默认返回10个。 search_results = self.es_client_python.search(body=query, size=50) - delete_list = [hit["_id"] for hit in search_results['hits']['hits']] + delete_list = [hit["_id"] for hit in search_results["hits"]["hits"]] if len(delete_list) == 0: return None else: for doc_id in delete_list: try: - self.es_client_python.delete(index=self.index_name, - id=doc_id, - refresh=True) + self.es_client_python.delete( + index=self.index_name, id=doc_id, refresh=True + ) except Exception as e: logger.error(f"ES Docs Delete Error! {e}") # self.db.delete(ids=delete_list) - #self.es_client_python.indices.refresh(index=self.index_name) - + # self.es_client_python.indices.refresh(index=self.index_name) def do_add_doc(self, docs: List[Document], **kwargs): - '''向知识库添加文件''' + """向知识库添加文件""" + + print( + f"server.knowledge_base.kb_service.es_kb_service.do_add_doc 输入的docs参数长度为:{len(docs)}" + ) + print("*" * 100) - print(f"server.knowledge_base.kb_service.es_kb_service.do_add_doc 输入的docs参数长度为:{len(docs)}") - print("*"*100) - self.db.add_documents(documents=docs) # 获取 id 和 source , 格式:[{"id": str, "metadata": dict}, ...] 
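        # the return value of add_documents is ignored here; document ids are recovered below by re-querying ES on metadata.source.keyword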
print("写入数据成功.") - print("*"*100) + print("*" * 100) if self.es_client_python.indices.exists(index=self.index_name): file_path = docs[0].metadata.get("source") query = { "query": { - "term": { - "metadata.source.keyword": file_path - }, - "term": { - "_index": self.index_name - } + "term": {"metadata.source.keyword": file_path}, + "term": {"_index": self.index_name}, } } # 注意设置size,默认返回10个。 search_results = self.es_client_python.search(body=query, size=50) if len(search_results["hits"]["hits"]) == 0: raise ValueError("召回元素个数为0") - info_docs = [{"id":hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in search_results["hits"]["hits"]] + info_docs = [ + {"id": hit["_id"], "metadata": hit["_source"]["metadata"]} + for hit in search_results["hits"]["hits"] + ] return info_docs - def do_clear_vs(self): """从知识库删除全部向量""" if self.es_client_python.indices.exists(index=self.kb_name): self.es_client_python.indices.delete(index=self.kb_name) - def do_drop_kb(self): """删除知识库""" # self.kb_file: 知识库路径 @@ -211,14 +216,9 @@ def do_drop_kb(self): shutil.rmtree(self.kb_path) -if __name__ == '__main__': +if __name__ == "__main__": esKBService = ESKBService("test") - #esKBService.clear_vs() - #esKBService.create_kb() + # esKBService.clear_vs() + # esKBService.create_kb() esKBService.add_doc(KnowledgeFile(filename="README.md", knowledge_base_name="test")) print(esKBService.search_docs("如何启动api服务")) - - - - - diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py index 52738ae81..dd4ae3a93 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py @@ -1,13 +1,17 @@ import os import shutil +from typing import Dict, List, Tuple + +from langchain.docstore.document import Document from chatchat.configs import SCORE_THRESHOLD +from chatchat.server.file_rag.utils import get_Retriever +from chatchat.server.knowledge_base.kb_cache.faiss_cache import ( + ThreadSafeFaiss, + kb_faiss_pool, +) from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType -from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path -from langchain.docstore.document import Document -from typing import List, Dict, Tuple -from chatchat.server.file_rag.utils import get_Retriever class FaissKBService(KBService): @@ -25,9 +29,11 @@ def get_kb_path(self): return get_kb_path(self.kb_name) def load_vector_store(self) -> ThreadSafeFaiss: - return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, - vector_name=self.vector_name, - embed_model=self.embed_model) + return kb_faiss_pool.load_vector_store( + kb_name=self.kb_name, + vector_name=self.vector_name, + embed_model=self.embed_model, + ) def save_vector_store(self): self.load_vector_store().save(self.vs_path) @@ -57,41 +63,45 @@ def do_drop_kb(self): except Exception: pass - def do_search(self, - query: str, - top_k: int, - score_threshold: float = SCORE_THRESHOLD, - ) -> List[Tuple[Document, float]]: + def do_search( + self, + query: str, + top_k: int, + score_threshold: float = SCORE_THRESHOLD, + ) -> List[Tuple[Document, float]]: with self.load_vector_store().acquire() as vs: retriever = get_Retriever("ensemble").from_vectorstore( - vs, - top_k=top_k, - score_threshold=score_threshold, - ) + 
vs, + top_k=top_k, + score_threshold=score_threshold, + ) docs = retriever.get_relevant_documents(query) return docs - def do_add_doc(self, - docs: List[Document], - **kwargs, - ) -> List[Dict]: - + def do_add_doc( + self, + docs: List[Document], + **kwargs, + ) -> List[Dict]: texts = [x.page_content for x in docs] metadatas = [x.metadata for x in docs] with self.load_vector_store().acquire() as vs: embeddings = vs.embeddings.embed_documents(texts) - ids = vs.add_embeddings(text_embeddings=zip(texts, embeddings), - metadatas=metadatas) + ids = vs.add_embeddings( + text_embeddings=zip(texts, embeddings), metadatas=metadatas + ) if not kwargs.get("not_refresh_vs_cache"): vs.save_local(self.vs_path) doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] return doc_infos - def do_delete_doc(self, - kb_file: KnowledgeFile, - **kwargs): + def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): with self.load_vector_store().acquire() as vs: - ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source").lower() == kb_file.filename.lower()] + ids = [ + k + for k, v in vs.docstore._dict.items() + if v.metadata.get("source").lower() == kb_file.filename.lower() + ] if len(ids) > 0: vs.delete(ids) if not kwargs.get("not_refresh_vs_cache"): @@ -118,7 +128,7 @@ def exist_doc(self, file_name: str): return False -if __name__ == '__main__': +if __name__ == "__main__": faissService = FaissKBService("test") faissService.add_doc(KnowledgeFile("README.md", "test")) faissService.delete_doc(KnowledgeFile("README.md", "test")) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py index 8eddb5f43..4507be685 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py @@ -1,16 +1,18 @@ -from typing import List, Dict, Optional +import os +from typing import Dict, List, Optional from langchain.schema import Document from langchain.vectorstores.milvus import Milvus -import os from chatchat.configs import kbs_config from chatchat.server.db.repository import list_file_num_docs_id_by_kb_name_and_file_name - -from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \ - score_threshold_process -from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.file_rag.utils import get_Retriever +from chatchat.server.knowledge_base.kb_service.base import ( + KBService, + SupportedVSType, + score_threshold_process, +) +from chatchat.server.knowledge_base.utils import KnowledgeFile class MilvusKBService(KBService): @@ -19,20 +21,23 @@ class MilvusKBService(KBService): @staticmethod def get_collection(milvus_name): from pymilvus import Collection + return Collection(milvus_name) def get_doc_by_ids(self, ids: List[str]) -> List[Document]: result = [] if self.milvus.col: # ids = [int(id) for id in ids] # for milvus if needed #pr 2725 - data_list = self.milvus.col.query(expr=f'pk in {[int(_id) for _id in ids]}', output_fields=["*"]) + data_list = self.milvus.col.query( + expr=f"pk in {[int(_id) for _id in ids]}", output_fields=["*"] + ) for data in data_list: text = data.pop("text") result.append(Document(page_content=text, metadata=data)) return result def del_doc_by_ids(self, ids: List[str]) -> bool: - self.milvus.col.delete(expr=f'pk in {ids}') + self.milvus.col.delete(expr=f"pk in 
{ids}") @staticmethod def search(milvus_name, content, limit=3): @@ -41,7 +46,9 @@ def search(milvus_name, content, limit=3): "params": {"nprobe": 10}, } c = MilvusKBService.get_collection(milvus_name) - return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"]) + return c.search( + content, "embeddings", search_params, limit=limit, output_fields=["content"] + ) def do_create_kb(self): pass @@ -50,12 +57,13 @@ def vs_type(self) -> str: return SupportedVSType.MILVUS def _load_milvus(self): - self.milvus = Milvus(embedding_function=(self.embed_model), - collection_name=self.kb_name, - connection_args=kbs_config.get("milvus"), - index_params=kbs_config.get("milvus_kwargs")["index_params"], - search_params=kbs_config.get("milvus_kwargs")["search_params"] - ) + self.milvus = Milvus( + embedding_function=(self.embed_model), + collection_name=self.kb_name, + connection_args=kbs_config.get("milvus"), + index_params=kbs_config.get("milvus_kwargs")["index_params"], + search_params=kbs_config.get("milvus_kwargs")["search_params"], + ) def do_init(self): self._load_milvus() @@ -92,9 +100,11 @@ def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: return doc_infos def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): - id_list = list_file_num_docs_id_by_kb_name_and_file_name(kb_file.kb_name, kb_file.filename) + id_list = list_file_num_docs_id_by_kb_name_and_file_name( + kb_file.kb_name, kb_file.filename + ) if self.milvus.col: - self.milvus.col.delete(expr=f'pk in {id_list}') + self.milvus.col.delete(expr=f"pk in {id_list}") # Issue 2846, for windows # if self.milvus.col: @@ -110,7 +120,7 @@ def do_clear_vs(self): self.do_init() -if __name__ == '__main__': +if __name__ == "__main__": # 测试建表使用 from chatchat.server.db.base import Base, engine diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py index d63c1ff51..a101c1f41 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py @@ -1,39 +1,50 @@ import json -from typing import List, Dict, Optional +import shutil +from typing import Dict, List, Optional +import sqlalchemy from langchain.schema import Document -from langchain.vectorstores.pgvector import PGVector, DistanceStrategy +from langchain.vectorstores.pgvector import DistanceStrategy, PGVector from sqlalchemy import text +from sqlalchemy.engine.base import Engine +from sqlalchemy.orm import Session from chatchat.configs import kbs_config - -from chatchat.server.knowledge_base.kb_service.base import SupportedVSType, KBService, \ - score_threshold_process +from chatchat.server.file_rag.utils import get_Retriever +from chatchat.server.knowledge_base.kb_service.base import ( + KBService, + SupportedVSType, + score_threshold_process, +) from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.utils import get_Embeddings -import shutil -import sqlalchemy -from sqlalchemy.engine.base import Engine -from sqlalchemy.orm import Session -from chatchat.server.file_rag.utils import get_Retriever class PGKBService(KBService): - engine: Engine = sqlalchemy.create_engine(kbs_config.get("pg").get("connection_uri"), pool_size=10) + engine: Engine = sqlalchemy.create_engine( + kbs_config.get("pg").get("connection_uri"), pool_size=10 + ) def _load_pg_vector(self): - self.pg_vector = 
PGVector(embedding_function=get_Embeddings(self.embed_model), - collection_name=self.kb_name, - distance_strategy=DistanceStrategy.EUCLIDEAN, - connection=PGKBService.engine, - connection_string=kbs_config.get("pg").get("connection_uri")) + self.pg_vector = PGVector( + embedding_function=get_Embeddings(self.embed_model), + collection_name=self.kb_name, + distance_strategy=DistanceStrategy.EUCLIDEAN, + connection=PGKBService.engine, + connection_string=kbs_config.get("pg").get("connection_uri"), + ) def get_doc_by_ids(self, ids: List[str]) -> List[Document]: with Session(PGKBService.engine) as session: - stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE custom_id = ANY(:ids)") - results = [Document(page_content=row[0], metadata=row[1]) for row in - session.execute(stmt, {'ids': ids}).fetchall()] + stmt = text( + "SELECT document, cmetadata FROM langchain_pg_embedding WHERE custom_id = ANY(:ids)" + ) + results = [ + Document(page_content=row[0], metadata=row[1]) + for row in session.execute(stmt, {"ids": ids}).fetchall() + ] return results + def del_doc_by_ids(self, ids: List[str]) -> bool: return super().del_doc_by_ids(ids) @@ -48,7 +59,9 @@ def vs_type(self) -> str: def do_drop_kb(self): with Session(PGKBService.engine) as session: - session.execute(text(f''' + session.execute( + text( + f""" -- 删除 langchain_pg_embedding 表中关联到 langchain_pg_collection 表中 的记录 DELETE FROM langchain_pg_embedding WHERE collection_id IN ( @@ -56,7 +69,9 @@ def do_drop_kb(self): ); -- 删除 langchain_pg_collection 表中 记录 DELETE FROM langchain_pg_collection WHERE name = '{self.kb_name}'; - ''')) + """ + ) + ) session.commit() shutil.rmtree(self.kb_path) @@ -78,8 +93,11 @@ def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): with Session(PGKBService.engine) as session: session.execute( text( - ''' DELETE FROM langchain_pg_embedding WHERE cmetadata::jsonb @> '{"source": "filepath"}'::jsonb;'''.replace( - "filepath", self.get_relative_source_path(kb_file.filepath)))) + """ DELETE FROM langchain_pg_embedding WHERE cmetadata::jsonb @> '{"source": "filepath"}'::jsonb;""".replace( + "filepath", self.get_relative_source_path(kb_file.filepath) + ) + ) + ) session.commit() def do_clear_vs(self): @@ -87,7 +105,7 @@ def do_clear_vs(self): self.pg_vector.create_collection() -if __name__ == '__main__': +if __name__ == "__main__": from chatchat.server.db.base import Base, engine # Base.metadata.create_all(bind=engine) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/relyt_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/relyt_kb_service.py index 7886751b3..0b851024a 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/relyt_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/relyt_kb_service.py @@ -1,18 +1,21 @@ -from typing import List, Dict +from typing import Dict, List +from configs import kbs_config from langchain.schema import Document from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs -from sqlalchemy import text, create_engine +from sqlalchemy import create_engine, text from sqlalchemy.orm import Session -from configs import kbs_config -from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \ - score_threshold_process +from server.knowledge_base.kb_service.base import ( + EmbeddingsFunAdapter, + KBService, + SupportedVSType, + score_threshold_process, +) from server.knowledge_base.utils import KnowledgeFile class 
RelytKBService(KBService): - def _load_relyt_vector(self): embedding_func = EmbeddingsFunAdapter(self.embed_model) sample_embedding = embedding_func.embed_query("Hello relyt!") @@ -25,18 +28,22 @@ def _load_relyt_vector(self): self.engine = create_engine(kbs_config.get("relyt").get("connection_uri")) def get_doc_by_ids(self, ids: List[str]) -> List[Document]: - ids_str = ', '.join([f"{id}" for id in ids]) + ids_str = ", ".join([f"{id}" for id in ids]) with Session(self.engine) as session: - stmt = text(f"SELECT text, meta FROM collection_{self.kb_name} WHERE id in (:ids)") - results = [Document(page_content=row[0], metadata=row[1]) for row in - session.execute(stmt, {'ids': ids_str}).fetchall()] + stmt = text( + f"SELECT text, meta FROM collection_{self.kb_name} WHERE id in (:ids)" + ) + results = [ + Document(page_content=row[0], metadata=row[1]) + for row in session.execute(stmt, {"ids": ids_str}).fetchall() + ] return results def del_doc_by_ids(self, ids: List[str]) -> bool: - ids_str = ', '.join([f"{id}" for id in ids]) + ids_str = ", ".join([f"{id}" for id in ids]) with Session(self.engine) as session: stmt = text(f"DELETE FROM collection_{self.kb_name} WHERE id in (:ids)") - session.execute(stmt, {'ids': ids_str}) + session.execute(stmt, {"ids": ids_str}) session.commit() return True @@ -53,7 +60,8 @@ def do_create_kb(self): SELECT 1 FROM pg_indexes WHERE indexname = '{index_name}'; - """) + """ + ) result = conn.execute(index_query).scalar() if not result: index_statement = text( @@ -69,7 +77,8 @@ def do_create_kb(self): m=30 ef_construction=500 $$); - """) + """ + ) conn.execute(index_statement) def vs_type(self) -> str: @@ -103,8 +112,9 @@ def do_clear_vs(self): self.do_drop_kb() -if __name__ == '__main__': +if __name__ == "__main__": from server.db.base import Base, engine + Base.metadata.create_all(bind=engine) relyt_kb_service = RelytKBService("collection_test") kf = KnowledgeFile("README.md", "test") diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py index 336eaa482..08fb932dd 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py @@ -1,12 +1,17 @@ -from typing import List, Dict +from typing import Dict, List + from langchain.schema import Document from langchain.vectorstores import Zilliz + from chatchat.configs import kbs_config -from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \ - score_threshold_process +from chatchat.server.file_rag.utils import get_Retriever +from chatchat.server.knowledge_base.kb_service.base import ( + KBService, + SupportedVSType, + score_threshold_process, +) from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.utils import get_Embeddings -from chatchat.server.file_rag.utils import get_Retriever class ZillizKBService(KBService): @@ -15,20 +20,21 @@ class ZillizKBService(KBService): @staticmethod def get_collection(zilliz_name): from pymilvus import Collection + return Collection(zilliz_name) def get_doc_by_ids(self, ids: List[str]) -> List[Document]: result = [] if self.zilliz.col: # ids = [int(id) for id in ids] # for zilliz if needed #pr 2725 - data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"]) + data_list = self.zilliz.col.query(expr=f"pk in {ids}", output_fields=["*"]) for data in data_list: 
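                # each query hit keeps the raw chunk text in the "text" field; the remaining fields become the Document metadata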
text = data.pop("text") result.append(Document(page_content=text, metadata=data)) return result def del_doc_by_ids(self, ids: List[str]) -> bool: - self.zilliz.col.delete(expr=f'pk in {ids}') + self.zilliz.col.delete(expr=f"pk in {ids}") @staticmethod def search(zilliz_name, content, limit=3): @@ -37,7 +43,9 @@ def search(zilliz_name, content, limit=3): "params": {}, } c = ZillizKBService.get_collection(zilliz_name) - return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"]) + return c.search( + content, "embeddings", search_params, limit=limit, output_fields=["content"] + ) def do_create_kb(self): pass @@ -47,8 +55,11 @@ def vs_type(self) -> str: def _load_zilliz(self): zilliz_args = kbs_config.get("zilliz") - self.zilliz = Zilliz(embedding_function=get_Embeddings(self.embed_model), - collection_name=self.kb_name, connection_args=zilliz_args) + self.zilliz = Zilliz( + embedding_function=get_Embeddings(self.embed_model), + collection_name=self.kb_name, + connection_args=zilliz_args, + ) def do_init(self): self._load_zilliz() @@ -83,10 +94,14 @@ def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): if self.zilliz.col: - filepath = kb_file.filepath.replace('\\', '\\\\') - delete_list = [item.get("pk") for item in - self.zilliz.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])] - self.zilliz.col.delete(expr=f'pk in {delete_list}') + filepath = kb_file.filepath.replace("\\", "\\\\") + delete_list = [ + item.get("pk") + for item in self.zilliz.col.query( + expr=f'source == "{filepath}"', output_fields=["pk"] + ) + ] + self.zilliz.col.delete(expr=f"pk in {delete_list}") def do_clear_vs(self): if self.zilliz.col: @@ -94,7 +109,7 @@ def do_clear_vs(self): self.do_init() -if __name__ == '__main__': +if __name__ == "__main__": from chatchat.server.db.base import Base, engine Base.metadata.create_all(bind=engine) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary/base.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary/base.py index 2d49259a5..5b68209fa 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary/base.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary/base.py @@ -1,17 +1,20 @@ -from typing import List - -from chatchat.configs import ( - DEFAULT_EMBEDDING_MODEL, - KB_ROOT_PATH) - -from abc import ABC, abstractmethod -from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss import os import shutil -from chatchat.server.db.repository.knowledge_metadata_repository import add_summary_to_db, delete_summary_from_db +from abc import ABC, abstractmethod +from typing import List from langchain.docstore.document import Document +from chatchat.configs import DEFAULT_EMBEDDING_MODEL, KB_ROOT_PATH +from chatchat.server.db.repository.knowledge_metadata_repository import ( + add_summary_to_db, + delete_summary_from_db, +) +from chatchat.server.knowledge_base.kb_cache.faiss_cache import ( + ThreadSafeFaiss, + kb_faiss_pool, +) + class KBSummaryService(ABC): kb_name: str @@ -19,10 +22,9 @@ class KBSummaryService(ABC): vs_path: str kb_path: str - def __init__(self, - knowledge_base_name: str, - embed_model: str = DEFAULT_EMBEDDING_MODEL - ): + def __init__( + self, knowledge_base_name: str, embed_model: str = DEFAULT_EMBEDDING_MODEL + ): self.kb_name = knowledge_base_name self.embed_model = embed_model @@ -32,7 +34,6 @@ def __init__(self, if not 
os.path.exists(self.vs_path): os.makedirs(self.vs_path) - def get_vs_path(self): return os.path.join(self.get_kb_path(), "summary_vector_store") @@ -40,20 +41,27 @@ def get_kb_path(self): return os.path.join(KB_ROOT_PATH, self.kb_name) def load_vector_store(self) -> ThreadSafeFaiss: - return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, - vector_name="summary_vector_store", - embed_model=self.embed_model, - create=True) + return kb_faiss_pool.load_vector_store( + kb_name=self.kb_name, + vector_name="summary_vector_store", + embed_model=self.embed_model, + create=True, + ) def add_kb_summary(self, summary_combine_docs: List[Document]): with self.load_vector_store().acquire() as vs: ids = vs.add_documents(documents=summary_combine_docs) vs.save_local(self.vs_path) - summary_infos = [{"summary_context": doc.page_content, - "summary_id": id, - "doc_ids": doc.metadata.get('doc_ids'), - "metadata": doc.metadata} for id, doc in zip(ids, summary_combine_docs)] + summary_infos = [ + { + "summary_context": doc.page_content, + "summary_id": id, + "doc_ids": doc.metadata.get("doc_ids"), + "metadata": doc.metadata, + } + for id, doc in zip(ids, summary_combine_docs) + ] status = add_summary_to_db(kb_name=self.kb_name, summary_infos=summary_infos) return status diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary/summary_chunk.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary/summary_chunk.py index 1e93317a0..b61a67b8b 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary/summary_chunk.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary/summary_chunk.py @@ -1,19 +1,19 @@ +import asyncio +import logging +import sys from typing import List, Optional -from langchain.schema.language_model import BaseLanguageModel - -from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId -from langchain.chains import StuffDocumentsChain, LLMChain -from langchain.prompts import PromptTemplate - +from langchain.chains import LLMChain, StuffDocumentsChain +from langchain.chains.combine_documents.map_reduce import ( + MapReduceDocumentsChain, + ReduceDocumentsChain, +) from langchain.docstore.document import Document from langchain.output_parsers.regex import RegexParser -from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain, MapReduceDocumentsChain - -import sys -import asyncio +from langchain.prompts import PromptTemplate +from langchain.schema.language_model import BaseLanguageModel -import logging +from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId logger = logging.getLogger() @@ -24,18 +24,21 @@ class SummaryAdapter: _separator: str = "\n\n" chain: MapReduceDocumentsChain - def __init__(self, overlap_size: int, token_max: int, - chain: MapReduceDocumentsChain): + def __init__( + self, overlap_size: int, token_max: int, chain: MapReduceDocumentsChain + ): self._OVERLAP_SIZE = overlap_size self.chain = chain self.token_max = token_max @classmethod - def form_summary(cls, - llm: BaseLanguageModel, - reduce_llm: BaseLanguageModel, - overlap_size: int, - token_max: int = 1300): + def form_summary( + cls, + llm: BaseLanguageModel, + reduce_llm: BaseLanguageModel, + overlap_size: int, + token_max: int = 1300, + ): """ 获取实例 :param reduce_llm: 用于合并摘要的llm @@ -47,22 +50,20 @@ def form_summary(cls, # This controls how each document will be formatted. 
Specifically, document_prompt = PromptTemplate( - input_variables=["page_content"], - template="{page_content}" + input_variables=["page_content"], template="{page_content}" ) # The prompt here should take as an input variable the # `document_variable_name` prompt_template = ( - "根据文本执行任务。以下任务信息" - "{task_briefing}" + "根据文本执行任务。以下任务信息" + "{task_briefing}" "文本内容如下: " "\r\n" "{context}" ) prompt = PromptTemplate( - template=prompt_template, - input_variables=["task_briefing", "context"] + template=prompt_template, input_variables=["task_briefing", "context"] ) llm_chain = LLMChain(llm=llm, prompt=prompt) # We now define how to combine these summaries @@ -75,7 +76,7 @@ def form_summary(cls, combine_documents_chain = StuffDocumentsChain( llm_chain=reduce_llm_chain, document_prompt=document_prompt, - document_variable_name=document_variable_name + document_variable_name=document_variable_name, ) reduce_documents_chain = ReduceDocumentsChain( token_max=token_max, @@ -86,17 +87,13 @@ def form_summary(cls, document_variable_name=document_variable_name, reduce_documents_chain=reduce_documents_chain, # 返回中间步骤 - return_intermediate_steps=True + return_intermediate_steps=True, ) - return cls(overlap_size=overlap_size, - chain=chain, - token_max=token_max) - - def summarize(self, - file_description: str, - docs: List[DocumentWithVSId] = [] - ) -> List[Document]: + return cls(overlap_size=overlap_size, chain=chain, token_max=token_max) + def summarize( + self, file_description: str, docs: List[DocumentWithVSId] = [] + ) -> List[Document]: if sys.version_info < (3, 10): loop = asyncio.get_event_loop() else: @@ -107,13 +104,13 @@ def summarize(self, asyncio.set_event_loop(loop) # 同步调用协程代码 - return loop.run_until_complete(self.asummarize(file_description=file_description, - docs=docs)) - - async def asummarize(self, - file_description: str, - docs: List[DocumentWithVSId] = []) -> List[Document]: + return loop.run_until_complete( + self.asummarize(file_description=file_description, docs=docs) + ) + async def asummarize( + self, file_description: str, docs: List[DocumentWithVSId] = [] + ) -> List[Document]: logger.info("start summary") """ 这个过程分成两个部分: @@ -128,9 +125,11 @@ async def asummarize(self, result_docs, token_max=token_max, callbacks=callbacks, **kwargs ) """ - summary_combine, summary_intermediate_steps = self.chain.combine_docs(docs=docs, - task_briefing="描述不同方法之间的接近度和相似性," - "以帮助读者理解它们之间的关系。") + summary_combine, summary_intermediate_steps = self.chain.combine_docs( + docs=docs, + task_briefing="描述不同方法之间的接近度和相似性," + "以帮助读者理解它们之间的关系。", + ) print(summary_combine) print(summary_intermediate_steps) @@ -152,7 +151,7 @@ async def asummarize(self, _metadata = { "file_description": file_description, "summary_intermediate_steps": summary_intermediate_steps, - "doc_ids": doc_ids + "doc_ids": doc_ids, } summary_combine_doc = Document(page_content=summary_combine, metadata=_metadata) @@ -178,12 +177,14 @@ def _drop_overlap(self, docs: List[DocumentWithVSId]) -> List[str]: # 列表中上一个结尾与下一个开头重叠的部分,删除下一个开头重叠的部分 # 迭代递减pre_doc的长度,每次迭代删除前面的字符, # 查询重叠部分,直到pre_doc的长度小于 self._OVERLAP_SIZE // 2 - 2len(separator) - for i in range(len(pre_doc), self._OVERLAP_SIZE // 2 - 2 * len(self._separator), -1): + for i in range( + len(pre_doc), self._OVERLAP_SIZE // 2 - 2 * len(self._separator), -1 + ): # 每次迭代删除前面的字符 pre_doc = pre_doc[1:] - if doc.page_content[:len(pre_doc)] == pre_doc: + if doc.page_content[: len(pre_doc)] == pre_doc: # 删除下一个开头重叠的部分 - merge_docs.append(doc.page_content[len(pre_doc):]) + 
merge_docs.append(doc.page_content[len(pre_doc) :]) break pre_doc = doc.page_content @@ -199,16 +200,12 @@ def _join_docs(self, docs: List[str]) -> Optional[str]: return text -if __name__ == '__main__': - +if __name__ == "__main__": docs = [ - - '梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的', - - '梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象', - - '使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各', - '值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全' + "梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的", + "梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象", + "使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各", + "值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全", ] _OVERLAP_SIZE = 1 separator: str = "\n\n" @@ -229,9 +226,9 @@ def _join_docs(self, docs: List[str]) -> Optional[str]: for i in range(len(pre_doc), _OVERLAP_SIZE - 2 * len(separator), -1): # 每次迭代删除前面的字符 pre_doc = pre_doc[1:] - if doc[:len(pre_doc)] == pre_doc: + if doc[: len(pre_doc)] == pre_doc: # 删除下一个开头重叠的部分 - page_content = doc[len(pre_doc):] + page_content = doc[len(pre_doc) :] merge_docs.append(page_content) pre_doc = doc diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary_api.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary_api.py index fd18fe175..1d5c4fae8 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary_api.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary_api.py @@ -1,27 +1,35 @@ +import json +from typing import List, Optional + from fastapi import Body -from chatchat.configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL, - OVERLAP_SIZE, - logger, log_verbose, ) -from chatchat.server.knowledge_base.utils import (list_files_from_folder) from sse_starlette import EventSourceResponse -import json + +from chatchat.configs import ( + DEFAULT_EMBEDDING_MODEL, + DEFAULT_VS_TYPE, + OVERLAP_SIZE, + log_verbose, + logger, +) from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory -from typing import List, Optional from chatchat.server.knowledge_base.kb_summary.base import KBSummaryService from chatchat.server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter -from chatchat.server.utils import wrap_done, get_ChatOpenAI, BaseResponse from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId +from chatchat.server.knowledge_base.utils import list_files_from_folder +from chatchat.server.utils import BaseResponse, get_ChatOpenAI, wrap_done def recreate_summary_vector_store( - knowledge_base_name: str = Body(..., examples=["samples"]), - allow_empty_kb: bool = Body(True), - vs_type: str = Body(DEFAULT_VS_TYPE), - embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), - file_description: str = Body(''), - model_name: str = Body(None, description="LLM 模型名称。"), - temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), + knowledge_base_name: str = Body(..., examples=["samples"]), + allow_empty_kb: bool = Body(True), + vs_type: str = Body(DEFAULT_VS_TYPE), + embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), + file_description: str = Body(""), + model_name: str = Body(None, description="LLM 模型名称。"), + temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), + max_tokens: Optional[int] = Body( + None, description="限制LLM生成Token数量,默认None代表模型最大值" + ), ): """ 重建单个知识库文件摘要 @@ -37,7 +45,6 @@ def recreate_summary_vector_store( """ def output(): - kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) if not kb.exists() and not allow_empty_kb: yield {"code": 404, "msg": f"未找到知识库 
‘{knowledge_base_name}’"} @@ -64,51 +71,59 @@ def output(): local_wrap=True, ) # 文本摘要适配器 - summary = SummaryAdapter.form_summary(llm=llm, - reduce_llm=reduce_llm, - overlap_size=OVERLAP_SIZE) + summary = SummaryAdapter.form_summary( + llm=llm, reduce_llm=reduce_llm, overlap_size=OVERLAP_SIZE + ) files = list_files_from_folder(knowledge_base_name) i = 0 for i, file_name in enumerate(files): - doc_infos = kb.list_docs(file_name=file_name) - docs = summary.summarize(file_description=file_description, - docs=doc_infos) + docs = summary.summarize( + file_description=file_description, docs=doc_infos + ) - status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs) + status_kb_summary = kb_summary.add_kb_summary( + summary_combine_docs=docs + ) if status_kb_summary: logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成") - yield json.dumps({ - "code": 200, - "msg": f"({i + 1} / {len(files)}): {file_name}", - "total": len(files), - "finished": i + 1, - "doc": file_name, - }, ensure_ascii=False) + yield json.dumps( + { + "code": 200, + "msg": f"({i + 1} / {len(files)}): {file_name}", + "total": len(files), + "finished": i + 1, + "doc": file_name, + }, + ensure_ascii=False, + ) else: - msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。" logger.error(msg) - yield json.dumps({ - "code": 500, - "msg": msg, - }) + yield json.dumps( + { + "code": 500, + "msg": msg, + } + ) i += 1 return EventSourceResponse(output()) def summary_file_to_vector_store( - knowledge_base_name: str = Body(..., examples=["samples"]), - file_name: str = Body(..., examples=["test.pdf"]), - allow_empty_kb: bool = Body(True), - vs_type: str = Body(DEFAULT_VS_TYPE), - embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), - file_description: str = Body(''), - model_name: str = Body(None, description="LLM 模型名称。"), - temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), + knowledge_base_name: str = Body(..., examples=["samples"]), + file_name: str = Body(..., examples=["test.pdf"]), + allow_empty_kb: bool = Body(True), + vs_type: str = Body(DEFAULT_VS_TYPE), + embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), + file_description: str = Body(""), + model_name: str = Body(None, description="LLM 模型名称。"), + temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), + max_tokens: Optional[int] = Body( + None, description="限制LLM生成Token数量,默认None代表模型最大值" + ), ): """ 单个知识库根据文件名称摘要 @@ -146,43 +161,48 @@ def output(): local_wrap=True, ) # 文本摘要适配器 - summary = SummaryAdapter.form_summary(llm=llm, - reduce_llm=reduce_llm, - overlap_size=OVERLAP_SIZE) + summary = SummaryAdapter.form_summary( + llm=llm, reduce_llm=reduce_llm, overlap_size=OVERLAP_SIZE + ) doc_infos = kb.list_docs(file_name=file_name) - docs = summary.summarize(file_description=file_description, - docs=doc_infos) + docs = summary.summarize(file_description=file_description, docs=doc_infos) status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs) if status_kb_summary: logger.info(f" {file_name} 总结完成") - yield json.dumps({ - "code": 200, - "msg": f"{file_name} 总结完成", - "doc": file_name, - }, ensure_ascii=False) + yield json.dumps( + { + "code": 200, + "msg": f"{file_name} 总结完成", + "doc": file_name, + }, + ensure_ascii=False, + ) else: - msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。" logger.error(msg) - yield json.dumps({ - "code": 500, - "msg": msg, - }) + yield json.dumps( + { + "code": 500, + "msg": msg, + } + ) return 
EventSourceResponse(output()) def summary_doc_ids_to_vector_store( - knowledge_base_name: str = Body(..., examples=["samples"]), - doc_ids: List = Body([], examples=[["uuid"]]), - vs_type: str = Body(DEFAULT_VS_TYPE), - embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), - file_description: str = Body(''), - model_name: str = Body(None, description="LLM 模型名称。"), - temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), + knowledge_base_name: str = Body(..., examples=["samples"]), + doc_ids: List = Body([], examples=[["uuid"]]), + vs_type: str = Body(DEFAULT_VS_TYPE), + embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), + file_description: str = Body(""), + model_name: str = Body(None, description="LLM 模型名称。"), + temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), + max_tokens: Optional[int] = Body( + None, description="限制LLM生成Token数量,默认None代表模型最大值" + ), ) -> BaseResponse: """ 单个知识库根据doc_ids摘要 @@ -198,7 +218,9 @@ def summary_doc_ids_to_vector_store( """ kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) if not kb.exists(): - return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={}) + return BaseResponse( + code=404, msg=f"未找到知识库 {knowledge_base_name}", data={} + ) else: llm = get_ChatOpenAI( model_name=model_name, @@ -213,18 +235,24 @@ def summary_doc_ids_to_vector_store( local_wrap=True, ) # 文本摘要适配器 - summary = SummaryAdapter.form_summary(llm=llm, - reduce_llm=reduce_llm, - overlap_size=OVERLAP_SIZE) + summary = SummaryAdapter.form_summary( + llm=llm, reduce_llm=reduce_llm, overlap_size=OVERLAP_SIZE + ) doc_infos = kb.get_doc_by_ids(ids=doc_ids) # doc_infos转换成DocumentWithVSId包装的对象 - doc_info_with_ids = [DocumentWithVSId(**doc.dict(), id=with_id) for with_id, doc in zip(doc_ids, doc_infos)] + doc_info_with_ids = [ + DocumentWithVSId(**doc.dict(), id=with_id) + for with_id, doc in zip(doc_ids, doc_infos) + ] - docs = summary.summarize(file_description=file_description, - docs=doc_info_with_ids) + docs = summary.summarize( + file_description=file_description, docs=doc_info_with_ids + ) # 将docs转换成dict resp_summarize = [{**doc.dict()} for doc in docs] - return BaseResponse(code=200, msg="总结完成", data={"summarize": resp_summarize}) + return BaseResponse( + code=200, msg="总结完成", data={"summarize": resp_summarize} + ) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/migrate.py b/libs/chatchat-server/chatchat/server/knowledge_base/migrate.py index 06b6cf109..9bc42b241 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/migrate.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/migrate.py @@ -1,25 +1,42 @@ +import os from datetime import datetime +from typing import List, Literal + from dateutil.parser import parse -import os -from typing import Literal, List from chatchat.configs import ( - DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE, - CHUNK_SIZE, OVERLAP_SIZE, logger, log_verbose -) -from chatchat.server.knowledge_base.utils import ( - get_file_path, list_kbs_from_folder, - list_files_from_folder, files2docs_in_thread, - KnowledgeFile + CHUNK_SIZE, + DEFAULT_EMBEDDING_MODEL, + DEFAULT_VS_TYPE, + OVERLAP_SIZE, + ZH_TITLE_ENHANCE, + log_verbose, + logger, ) -from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType +from chatchat.server.db.base import Base, engine from chatchat.server.db.models.conversation_model import ConversationModel from 
chatchat.server.db.models.message_model import MessageModel -from chatchat.server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported -from chatchat.server.db.repository.knowledge_metadata_repository import add_summary_to_db +from chatchat.server.db.repository.knowledge_file_repository import ( + add_file_to_db, +) -from chatchat.server.db.base import Base, engine +# ensure Models are imported +from chatchat.server.db.repository.knowledge_metadata_repository import ( + add_summary_to_db, +) from chatchat.server.db.session import session_scope +from chatchat.server.knowledge_base.kb_service.base import ( + KBServiceFactory, + SupportedVSType, +) +from chatchat.server.knowledge_base.utils import ( + KnowledgeFile, + files2docs_in_thread, + get_file_path, + list_files_from_folder, + list_kbs_from_folder, +) + def create_tables(): Base.metadata.create_all(bind=engine) @@ -31,8 +48,8 @@ def reset_tables(): def import_from_db( - sqlite_path: str = None, - # csv_path: str = None, + sqlite_path: str = None, + # csv_path: str = None, ) -> bool: """ 在知识库与向量库无变化的情况下,从备份数据库中导入数据到 info.db。 @@ -49,7 +66,12 @@ def import_from_db( con = sql.connect(sqlite_path) con.row_factory = sql.Row cur = con.cursor() - tables = [x["name"] for x in cur.execute("select name from sqlite_master where type='table'").fetchall()] + tables = [ + x["name"] + for x in cur.execute( + "select name from sqlite_master where type='table'" + ).fetchall() + ] for model in models: table = model.local_table.fullname if table not in tables: @@ -77,19 +99,20 @@ def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]: kb_files.append(kb_file) except Exception as e: msg = f"{e},已跳过" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e if log_verbose else None + ) return kb_files def folder2db( - kb_names: List[str], - mode: Literal["recreate_vs", "update_in_db", "increment"], - vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, - embed_model: str = DEFAULT_EMBEDDING_MODEL, - chunk_size: int = CHUNK_SIZE, - chunk_overlap: int = OVERLAP_SIZE, - zh_title_enhance: bool = ZH_TITLE_ENHANCE, + kb_names: List[str], + mode: Literal["recreate_vs", "update_in_db", "increment"], + vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, + embed_model: str = DEFAULT_EMBEDDING_MODEL, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + zh_title_enhance: bool = ZH_TITLE_ENHANCE, ): """ use existed files in local folder to populate database and/or vector store. 
@@ -102,13 +125,17 @@ def folder2db( def files2vs(kb_name: str, kb_files: List[KnowledgeFile]) -> List: result = [] - for success, res in files2docs_in_thread(kb_files, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - zh_title_enhance=zh_title_enhance): + for success, res in files2docs_in_thread( + kb_files, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance, + ): if success: _, filename, docs = res - print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档") + print( + f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档" + ) kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) kb_file.splited_docs = docs kb.add_doc(kb_file=kb_file, not_refresh_vs_cache=True) @@ -156,23 +183,27 @@ def files2vs(kb_name: str, kb_files: List[KnowledgeFile]) -> List: else: print(f"unsupported migrate mode: {mode}") end = datetime.now() - kb_path = f"知识库路径\t:{kb.kb_path}\n" if kb.vs_type()==SupportedVSType.FAISS else "" + kb_path = ( + f"知识库路径\t:{kb.kb_path}\n" + if kb.vs_type() == SupportedVSType.FAISS + else "" + ) file_count = len(kb_files) success_count = len(result) - docs_count = sum([len(x['docs']) for x in result]) + docs_count = sum([len(x["docs"]) for x in result]) print("\n" + "-" * 100) print( ( - f"知识库名称\t:{kb_name}\n" - f"知识库类型\t:{kb.vs_type()}\n" - f"向量模型:\t:{kb.embed_model}\n" + f"知识库名称\t:{kb_name}\n" + f"知识库类型\t:{kb.vs_type()}\n" + f"向量模型:\t:{kb.embed_model}\n" ) - +kb_path+ - ( - f"文件总数量\t:{file_count}\n" - f"入库文件数\t:{success_count}\n" - f"知识条目数\t:{docs_count}\n" - f"用时\t\t:{end-start}" + + kb_path + + ( + f"文件总数量\t:{file_count}\n" + f"入库文件数\t:{success_count}\n" + f"知识条目数\t:{docs_count}\n" + f"用时\t\t:{end-start}" ) ) print("-" * 100 + "\n") diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/model/kb_document_model.py b/libs/chatchat-server/chatchat/server/knowledge_base/model/kb_document_model.py index 662929de1..78c9567ef 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/model/kb_document_model.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/model/kb_document_model.py @@ -1,4 +1,3 @@ - from langchain.docstore.document import Document @@ -6,5 +5,6 @@ class DocumentWithVSId(Document): """ 矢量化后的文档 """ + id: str = None score: float = 3.0 diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py index 26dd18bb4..ff950b546 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py @@ -1,27 +1,30 @@ +import importlib +import json +import logging import os from functools import lru_cache +from pathlib import Path +from typing import Dict, Generator, List, Tuple, Union + +import chardet +import langchain_community.document_loaders +from langchain.docstore.document import Document +from langchain.text_splitter import MarkdownHeaderTextSplitter, TextSplitter +from langchain_community.document_loaders import JSONLoader, TextLoader + from chatchat.configs import ( - KB_ROOT_PATH, CHUNK_SIZE, + KB_ROOT_PATH, OVERLAP_SIZE, + TEXT_SPLITTER_NAME, ZH_TITLE_ENHANCE, log_verbose, text_splitter_dict, - TEXT_SPLITTER_NAME, ) -import importlib -from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance -import langchain_community.document_loaders -from langchain.docstore.document import Document -from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter -from pathlib import Path -from 
chatchat.server.utils import run_in_thread_pool, run_in_process_pool -import json -from typing import List, Union, Dict, Tuple, Generator -import chardet -from langchain_community.document_loaders import JSONLoader, TextLoader - -import logging +from chatchat.server.file_rag.text_splitter import ( + zh_title_enhance as func_zh_title_enhance, +) +from chatchat.server.utils import run_in_process_pool, run_in_thread_pool logger = logging.getLogger() @@ -53,8 +56,11 @@ def get_file_path(knowledge_base_name: str, doc_name: str): def list_kbs_from_folder(): - return [f for f in os.listdir(KB_ROOT_PATH) - if os.path.isdir(os.path.join(KB_ROOT_PATH, f))] + return [ + f + for f in os.listdir(KB_ROOT_PATH) + if os.path.isdir(os.path.join(KB_ROOT_PATH, f)) + ] def list_files_from_folder(kb_name: str): @@ -78,7 +84,9 @@ def process_entry(entry): for target_entry in target_it: process_entry(target_entry) elif entry.is_file(): - file_path = (Path(os.path.relpath(entry.path, doc_path)).as_posix()) # 路径统一为 posix 格式 + file_path = Path( + os.path.relpath(entry.path, doc_path) + ).as_posix() # 路径统一为 posix 格式 result.append(file_path) elif entry.is_dir(): with os.scandir(entry.path) as it: @@ -92,37 +100,49 @@ def process_entry(entry): return result -LOADER_DICT = {"UnstructuredHTMLLoader": ['.html', '.htm'], - "MHTMLLoader": ['.mhtml'], - "TextLoader": ['.md'], - "UnstructuredMarkdownLoader": ['.md'], - "JSONLoader": [".json"], - "JSONLinesLoader": [".jsonl"], - "CSVLoader": [".csv"], - # "FilteredCSVLoader": [".csv"], 如果使用自定义分割csv - "RapidOCRPDFLoader": [".pdf"], - "RapidOCRDocLoader": ['.docx', '.doc'], - "RapidOCRPPTLoader": ['.ppt', '.pptx', ], - "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'], - "UnstructuredFileLoader": ['.eml', '.msg', '.rst', - '.rtf', '.txt', '.xml', - '.epub', '.odt','.tsv'], - "UnstructuredEmailLoader": ['.eml', '.msg'], - "UnstructuredEPubLoader": ['.epub'], - "UnstructuredExcelLoader": ['.xlsx', '.xls', '.xlsd'], - "NotebookLoader": ['.ipynb'], - "UnstructuredODTLoader": ['.odt'], - "PythonLoader": ['.py'], - "UnstructuredRSTLoader": ['.rst'], - "UnstructuredRTFLoader": ['.rtf'], - "SRTLoader": ['.srt'], - "TomlLoader": ['.toml'], - "UnstructuredTSVLoader": ['.tsv'], - "UnstructuredWordDocumentLoader": ['.docx', '.doc'], - "UnstructuredXMLLoader": ['.xml'], - "UnstructuredPowerPointLoader": ['.ppt', '.pptx'], - "EverNoteLoader": ['.enex'], - } +LOADER_DICT = { + "UnstructuredHTMLLoader": [".html", ".htm"], + "MHTMLLoader": [".mhtml"], + "TextLoader": [".md"], + "UnstructuredMarkdownLoader": [".md"], + "JSONLoader": [".json"], + "JSONLinesLoader": [".jsonl"], + "CSVLoader": [".csv"], + # "FilteredCSVLoader": [".csv"], 如果使用自定义分割csv + "RapidOCRPDFLoader": [".pdf"], + "RapidOCRDocLoader": [".docx", ".doc"], + "RapidOCRPPTLoader": [ + ".ppt", + ".pptx", + ], + "RapidOCRLoader": [".png", ".jpg", ".jpeg", ".bmp"], + "UnstructuredFileLoader": [ + ".eml", + ".msg", + ".rst", + ".rtf", + ".txt", + ".xml", + ".epub", + ".odt", + ".tsv", + ], + "UnstructuredEmailLoader": [".eml", ".msg"], + "UnstructuredEPubLoader": [".epub"], + "UnstructuredExcelLoader": [".xlsx", ".xls", ".xlsd"], + "NotebookLoader": [".ipynb"], + "UnstructuredODTLoader": [".odt"], + "PythonLoader": [".py"], + "UnstructuredRSTLoader": [".rst"], + "UnstructuredRTFLoader": [".rtf"], + "SRTLoader": [".srt"], + "TomlLoader": [".toml"], + "UnstructuredTSVLoader": [".tsv"], + "UnstructuredWordDocumentLoader": [".docx", ".doc"], + "UnstructuredXMLLoader": [".xml"], + "UnstructuredPowerPointLoader": [".ppt", ".pptx"], + 
"EverNoteLoader": [".enex"], +} SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] @@ -153,22 +173,34 @@ def get_LoaderClass(file_extension): def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): - ''' + """ 根据loader_name和文件路径或内容返回文档加载器。 - ''' + """ loader_kwargs = loader_kwargs or {} try: - if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader", - "RapidOCRDocLoader", "RapidOCRPPTLoader"]: - document_loaders_module = importlib.import_module("chatchat.server.file_rag.document_loaders") + if loader_name in [ + "RapidOCRPDFLoader", + "RapidOCRLoader", + "FilteredCSVLoader", + "RapidOCRDocLoader", + "RapidOCRPPTLoader", + ]: + document_loaders_module = importlib.import_module( + "chatchat.server.file_rag.document_loaders" + ) else: - document_loaders_module = importlib.import_module("langchain_community.document_loaders") + document_loaders_module = importlib.import_module( + "langchain_community.document_loaders" + ) DocumentLoader = getattr(document_loaders_module, loader_name) except Exception as e: msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) - document_loaders_module = importlib.import_module("langchain_community.document_loaders") + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e if log_verbose else None + ) + document_loaders_module = importlib.import_module( + "langchain_community.document_loaders" + ) DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") if loader_name == "UnstructuredFileLoader": @@ -176,7 +208,7 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): elif loader_name == "CSVLoader": if not loader_kwargs.get("encoding"): # 如果未指定 encoding,自动识别文件编码类型,避免langchain loader 加载文件报编码错误 - with open(file_path, 'rb') as struct_file: + with open(file_path, "rb") as struct_file: encode_detect = chardet.detect(struct_file.read()) if encode_detect is None: encode_detect = {"encoding": "utf-8"} @@ -194,77 +226,91 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): @lru_cache() -def make_text_splitter( - splitter_name, - chunk_size, - chunk_overlap -): +def make_text_splitter(splitter_name, chunk_size, chunk_overlap): """ 根据参数获取特定的分词器 """ splitter_name = splitter_name or "SpacyTextSplitter" try: - if splitter_name == "MarkdownHeaderTextSplitter": # MarkdownHeaderTextSplitter特殊判定 - headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on'] + if ( + splitter_name == "MarkdownHeaderTextSplitter" + ): # MarkdownHeaderTextSplitter特殊判定 + headers_to_split_on = text_splitter_dict[splitter_name][ + "headers_to_split_on" + ] text_splitter = MarkdownHeaderTextSplitter( - headers_to_split_on=headers_to_split_on, strip_headers=False) + headers_to_split_on=headers_to_split_on, strip_headers=False + ) else: - try: ## 优先使用用户自定义的text_splitter text_splitter_module = importlib.import_module("server.text_splitter") TextSplitter = getattr(text_splitter_module, splitter_name) except: ## 否则使用langchain的text_splitter - text_splitter_module = importlib.import_module("langchain.text_splitter") + text_splitter_module = importlib.import_module( + "langchain.text_splitter" + ) TextSplitter = getattr(text_splitter_module, splitter_name) - if text_splitter_dict[splitter_name]["source"] == "tiktoken": ## 从tiktoken加载 + if ( + text_splitter_dict[splitter_name]["source"] == "tiktoken" + ): ## 从tiktoken加载 try: text_splitter = 
TextSplitter.from_tiktoken_encoder( - encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"], + encoding_name=text_splitter_dict[splitter_name][ + "tokenizer_name_or_path" + ], pipeline="zh_core_web_sm", chunk_size=chunk_size, - chunk_overlap=chunk_overlap + chunk_overlap=chunk_overlap, ) except: text_splitter = TextSplitter.from_tiktoken_encoder( - encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"], + encoding_name=text_splitter_dict[splitter_name][ + "tokenizer_name_or_path" + ], chunk_size=chunk_size, - chunk_overlap=chunk_overlap + chunk_overlap=chunk_overlap, ) - elif text_splitter_dict[splitter_name]["source"] == "huggingface": ## 从huggingface加载 - - if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2": - from transformers import GPT2TokenizerFast + elif ( + text_splitter_dict[splitter_name]["source"] == "huggingface" + ): ## 从huggingface加载 + if ( + text_splitter_dict[splitter_name]["tokenizer_name_or_path"] + == "gpt2" + ): from langchain.text_splitter import CharacterTextSplitter + from transformers import GPT2TokenizerFast + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") else: ## 字符长度加载 from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained( text_splitter_dict[splitter_name]["tokenizer_name_or_path"], - trust_remote_code=True) + trust_remote_code=True, + ) text_splitter = TextSplitter.from_huggingface_tokenizer( tokenizer=tokenizer, chunk_size=chunk_size, - chunk_overlap=chunk_overlap + chunk_overlap=chunk_overlap, ) else: try: text_splitter = TextSplitter( pipeline="zh_core_web_sm", chunk_size=chunk_size, - chunk_overlap=chunk_overlap + chunk_overlap=chunk_overlap, ) except: text_splitter = TextSplitter( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap + chunk_size=chunk_size, chunk_overlap=chunk_overlap ) except Exception as e: print(e) - text_splitter_module = importlib.import_module('langchain.text_splitter') + text_splitter_module = importlib.import_module("langchain.text_splitter") TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter") text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) - + # If you use SpacyTextSplitter you can use GPU to do split likes Issue #1287 # text_splitter._tokenizer.max_length = 37016792 # text_splitter._tokenizer.prefer_gpu() @@ -273,14 +319,14 @@ def make_text_splitter( class KnowledgeFile: def __init__( - self, - filename: str, - knowledge_base_name: str, - loader_kwargs: Dict = {}, + self, + filename: str, + knowledge_base_name: str, + loader_kwargs: Dict = {}, ): - ''' + """ 对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。 - ''' + """ self.kb_name = knowledge_base_name self.filename = str(Path(filename).as_posix()) self.ext = os.path.splitext(filename)[-1].lower() @@ -296,9 +342,11 @@ def __init__( def file2docs(self, refresh: bool = False): if self.docs is None or refresh: logger.info(f"{self.document_loader_name} used for {self.filepath}") - loader = get_loader(loader_name=self.document_loader_name, - file_path=self.filepath, - loader_kwargs=self.loader_kwargs) + loader = get_loader( + loader_name=self.document_loader_name, + file_path=self.filepath, + loader_kwargs=self.loader_kwargs, + ) if isinstance(loader, TextLoader): loader.encoding = "utf8" self.docs = loader.load() @@ -307,21 +355,24 @@ def file2docs(self, refresh: bool = False): return self.docs def docs2texts( - self, - docs: List[Document] = None, - zh_title_enhance: bool = ZH_TITLE_ENHANCE, - refresh: bool = False, - chunk_size: int = 
CHUNK_SIZE, - chunk_overlap: int = OVERLAP_SIZE, - text_splitter: TextSplitter = None, + self, + docs: List[Document] = None, + zh_title_enhance: bool = ZH_TITLE_ENHANCE, + refresh: bool = False, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + text_splitter: TextSplitter = None, ): docs = docs or self.file2docs(refresh=refresh) if not docs: return [] if self.ext not in [".csv"]: if text_splitter is None: - text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, - chunk_overlap=chunk_overlap) + text_splitter = make_text_splitter( + splitter_name=self.text_splitter_name, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) if self.text_splitter_name == "MarkdownHeaderTextSplitter": docs = text_splitter.split_text(docs[0].page_content) else: @@ -337,21 +388,23 @@ def docs2texts( return self.splited_docs def file2text( - self, - zh_title_enhance: bool = ZH_TITLE_ENHANCE, - refresh: bool = False, - chunk_size: int = CHUNK_SIZE, - chunk_overlap: int = OVERLAP_SIZE, - text_splitter: TextSplitter = None, + self, + zh_title_enhance: bool = ZH_TITLE_ENHANCE, + refresh: bool = False, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + text_splitter: TextSplitter = None, ): if self.splited_docs is None or refresh: docs = self.file2docs() - self.splited_docs = self.docs2texts(docs=docs, - zh_title_enhance=zh_title_enhance, - refresh=refresh, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - text_splitter=text_splitter) + self.splited_docs = self.docs2texts( + docs=docs, + zh_title_enhance=zh_title_enhance, + refresh=refresh, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + text_splitter=text_splitter, + ) return self.splited_docs def file_exist(self): @@ -364,27 +417,30 @@ def get_size(self): return os.path.getsize(self.filepath) -def files2docs_in_thread_file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]: +def files2docs_in_thread_file2docs( + *, file: KnowledgeFile, **kwargs +) -> Tuple[bool, Tuple[str, str, List[Document]]]: try: return True, (file.kb_name, file.filename, file.file2text(**kwargs)) except Exception as e: msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e if log_verbose else None + ) return False, (file.kb_name, file.filename, msg) def files2docs_in_thread( - files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], - chunk_size: int = CHUNK_SIZE, - chunk_overlap: int = OVERLAP_SIZE, - zh_title_enhance: bool = ZH_TITLE_ENHANCE, + files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + zh_title_enhance: bool = ZH_TITLE_ENHANCE, ) -> Generator: - ''' + """ 利用多线程批量将磁盘文件转化成langchain Document. 
如果传入参数是Tuple,形式为(filename, kb_name) 生成器返回值为 status, (kb_name, file_name, docs | error) - ''' + """ kwargs_list = [] for i, file in enumerate(files): @@ -407,7 +463,9 @@ def files2docs_in_thread( except Exception as e: yield False, (kb_name, filename, str(e)) - for result in run_in_thread_pool(func=files2docs_in_thread_file2docs, params=kwargs_list): + for result in run_in_thread_pool( + func=files2docs_in_thread_file2docs, params=kwargs_list + ): yield result @@ -415,12 +473,12 @@ def files2docs_in_thread( from pprint import pprint kb_file = KnowledgeFile( - filename="E:\\LLM\\Data\\Test.md", - knowledge_base_name="samples") + filename="E:\\LLM\\Data\\Test.md", knowledge_base_name="samples" + ) # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter" kb_file.text_splitter_name = "MarkdownHeaderTextSplitter" docs = kb_file.file2docs() # pprint(docs[-1]) texts = kb_file.docs2texts(docs) for text in texts: - print(text) \ No newline at end of file + print(text) diff --git a/libs/chatchat-server/chatchat/server/llm_api_shutdown.py b/libs/chatchat-server/chatchat/server/llm_api_shutdown.py index 1ac1404d2..659c28240 100644 --- a/libs/chatchat-server/chatchat/server/llm_api_shutdown.py +++ b/libs/chatchat-server/chatchat/server/llm_api_shutdown.py @@ -3,16 +3,20 @@ python llm_api_shutdown.py --serve all 可选"all","controller","model_worker","openai_api_server", all表示停止所有服务 """ -import sys import os +import sys sys.path.append(os.path.dirname(os.path.dirname(__file__))) -import subprocess import argparse +import subprocess parser = argparse.ArgumentParser() -parser.add_argument("--serve", choices=["all", "controller", "model_worker", "openai_api_server"], default="all") +parser.add_argument( + "--serve", + choices=["all", "controller", "model_worker", "openai_api_server"], + default="all", +) args = parser.parse_args() diff --git a/libs/chatchat-server/chatchat/server/llm_api_stale.py b/libs/chatchat-server/chatchat/server/llm_api_stale.py index f0ac9a401..29e33b7a0 100644 --- a/libs/chatchat-server/chatchat/server/llm_api_stale.py +++ b/libs/chatchat-server/chatchat/server/llm_api_stale.py @@ -4,15 +4,15 @@ 但少数非关键参数如--worker-address,--allowed-origins,--allowed-methods,--allowed-headers不支持 """ -import sys import os +import sys sys.path.append(os.path.dirname(os.path.dirname(__file__))) -import subprocess -import re -import logging import argparse +import logging +import re +import subprocess LOG_PATH = "./logs/" LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" @@ -22,11 +22,13 @@ parser = argparse.ArgumentParser() # ------multi worker----------------- -parser.add_argument('--model-path-address', - default="THUDM/chatglm2-6b@localhost@20002", - nargs="+", - type=str, - help="model path, host, and port, formatted as model-path@host@port") +parser.add_argument( + "--model-path-address", + default="THUDM/chatglm2-6b@localhost@20002", + nargs="+", + type=str, + help="model path, host, and port, formatted as model-path@host@port", +) # ---------------controller------------------------- parser.add_argument("--controller-host", type=str, default="localhost") @@ -79,9 +81,7 @@ default="20GiB", help="The maximum memory per gpu. 
Use a string like '13Gib'", ) -parser.add_argument( - "--load-8bit", action="store_true", help="Use 8-bit quantization" -) +parser.add_argument("--load-8bit", action="store_true", help="Use 8-bit quantization") parser.add_argument( "--cpu-offloading", action="store_true", @@ -126,13 +126,26 @@ parser.add_argument("--no-register", action="store_true") worker_args = [ - "worker-host", "worker-port", - "model-path", "revision", "device", "gpus", "num-gpus", - "max-gpu-memory", "load-8bit", "cpu-offloading", - "gptq-ckpt", "gptq-wbits", "gptq-groupsize", - "gptq-act-order", "model-names", "limit-worker-concurrency", - "stream-interval", "no-register", - "controller-address", "worker-address" + "worker-host", + "worker-port", + "model-path", + "revision", + "device", + "gpus", + "num-gpus", + "max-gpu-memory", + "load-8bit", + "cpu-offloading", + "gptq-ckpt", + "gptq-wbits", + "gptq-groupsize", + "gptq-act-order", + "model-names", + "limit-worker-concurrency", + "stream-interval", + "no-register", + "controller-address", + "worker-address", ] # -----------------openai server--------------------------- @@ -155,9 +168,13 @@ type=lambda s: s.split(","), help="Optional list of comma separated API keys", ) -server_args = ["server-host", "server-port", "allow-credentials", "api-keys", - "controller-address" - ] +server_args = [ + "server-host", + "server-port", + "allow-credentials", + "api-keys", + "controller-address", +] # 0,controller, model_worker, openai_api_server # 1, 命令行选项 @@ -190,7 +207,11 @@ def string_args(args, args_list): # 1==True -> True elif isinstance(value, bool) and value == True: args_str += f" --{key} " - elif isinstance(value, list) or isinstance(value, tuple) or isinstance(value, set): + elif ( + isinstance(value, list) + or isinstance(value, tuple) + or isinstance(value, set) + ): value = " ".join(value) args_str += f" --{key} {value} " else: @@ -200,7 +221,13 @@ def string_args(args, args_list): def launch_worker(item, args, worker_args=worker_args): - log_name = item.split("/")[-1].split("\\")[-1].replace("-", "_").replace("@", "_").replace(".", "_") + log_name = ( + item.split("/")[-1] + .split("\\")[-1] + .replace("-", "_") + .replace("@", "_") + .replace(".", "_") + ) # 先分割model-path-address,在传到string_args中分析参数 args.model_path, args.worker_host, args.worker_port = item.split("@") args.worker_address = f"http://{args.worker_host}:{args.worker_port}" @@ -208,21 +235,28 @@ def launch_worker(item, args, worker_args=worker_args): print(f"如长时间未启动,请到{LOG_PATH}{log_name}.log下查看日志") worker_str_args = string_args(args, worker_args) print(worker_str_args) - worker_sh = base_launch_sh.format("model_worker", worker_str_args, LOG_PATH, f"worker_{log_name}") - worker_check_sh = base_check_sh.format(LOG_PATH, f"worker_{log_name}", "model_worker") + worker_sh = base_launch_sh.format( + "model_worker", worker_str_args, LOG_PATH, f"worker_{log_name}" + ) + worker_check_sh = base_check_sh.format( + LOG_PATH, f"worker_{log_name}", "model_worker" + ) subprocess.run(worker_sh, shell=True, check=True) subprocess.run(worker_check_sh, shell=True, check=True) -def launch_all(args, - controller_args=controller_args, - worker_args=worker_args, - server_args=server_args - ): +def launch_all( + args, + controller_args=controller_args, + worker_args=worker_args, + server_args=server_args, +): print(f"Launching llm service,logs are located in {LOG_PATH}...") print(f"开始启动LLM服务,请到{LOG_PATH}下监控各模块日志...") controller_str_args = string_args(args, controller_args) - controller_sh = 
base_launch_sh.format("controller", controller_str_args, LOG_PATH, "controller") + controller_sh = base_launch_sh.format( + "controller", controller_str_args, LOG_PATH, "controller" + ) controller_check_sh = base_check_sh.format(LOG_PATH, "controller", "controller") subprocess.run(controller_sh, shell=True, check=True) subprocess.run(controller_check_sh, shell=True, check=True) @@ -235,8 +269,12 @@ def launch_all(args, launch_worker(item, args=args, worker_args=worker_args) server_str_args = string_args(args, server_args) - server_sh = base_launch_sh.format("openai_api_server", server_str_args, LOG_PATH, "openai_api_server") - server_check_sh = base_check_sh.format(LOG_PATH, "openai_api_server", "openai_api_server") + server_sh = base_launch_sh.format( + "openai_api_server", server_str_args, LOG_PATH, "openai_api_server" + ) + server_check_sh = base_check_sh.format( + LOG_PATH, "openai_api_server", "openai_api_server" + ) subprocess.run(server_sh, shell=True, check=True) subprocess.run(server_check_sh, shell=True, check=True) print("Launching LLM service done!") @@ -246,8 +284,12 @@ def launch_all(args, if __name__ == "__main__": args = parser.parse_args() # 必须要加http//:,否则InvalidSchema: No connection adapters were found - args = argparse.Namespace(**vars(args), - **{"controller-address": f"http://{args.controller_host}:{str(args.controller_port)}"}) + args = argparse.Namespace( + **vars(args), + **{ + "controller-address": f"http://{args.controller_host}:{str(args.controller_port)}" + }, + ) if args.gpus: if len(args.gpus.split(",")) < args.num_gpus: diff --git a/libs/chatchat-server/chatchat/server/localai_embeddings.py b/libs/chatchat-server/chatchat/server/localai_embeddings.py index 040c8eb9a..e27041c21 100644 --- a/libs/chatchat-server/chatchat/server/localai_embeddings.py +++ b/libs/chatchat-server/chatchat/server/localai_embeddings.py @@ -15,10 +15,10 @@ Union, ) +from langchain_community.utils.openai import is_openai_v1 from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names -from langchain_community.utils.openai import is_openai_v1 from tenacity import ( AsyncRetrying, before_sleep_log, @@ -27,8 +27,8 @@ stop_after_attempt, wait_exponential, ) -from chatchat.server.utils import run_in_thread_pool +from chatchat.server.utils import run_in_thread_pool logger = logging.getLogger(__name__) @@ -284,11 +284,15 @@ def _embedding_func(self, text: str, *, engine: str) -> List[float]: # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 # replace newlines, which can negatively affect performance. text = text.replace("\n", " ") - return embed_with_retry( - self, - input=[text], - **self._invocation_params, - ).data[0].embedding + return ( + embed_with_retry( + self, + input=[text], + **self._invocation_params, + ) + .data[0] + .embedding + ) async def _aembedding_func(self, text: str, *, engine: str) -> List[float]: """Call out to LocalAI's embedding endpoint.""" @@ -298,12 +302,16 @@ async def _aembedding_func(self, text: str, *, engine: str) -> List[float]: # replace newlines, which can negatively affect performance. 
text = text.replace("\n", " ") return ( - await async_embed_with_retry( - self, - input=[text], - **self._invocation_params, + ( + await async_embed_with_retry( + self, + input=[text], + **self._invocation_params, + ) ) - ).data[0].embedding + .data[0] + .embedding + ) def embed_documents( self, texts: List[str], chunk_size: Optional[int] = 0 @@ -318,6 +326,7 @@ def embed_documents( Returns: List of embeddings, one for each text. """ + # call _embedding_func for each text with multithreads def task(seq, text): return (seq, self._embedding_func(text, engine=self.deployment)) diff --git a/libs/chatchat-server/chatchat/server/memory/conversation_db_buffer_memory.py b/libs/chatchat-server/chatchat/server/memory/conversation_db_buffer_memory.py index 47de79811..66c834bf9 100644 --- a/libs/chatchat-server/chatchat/server/memory/conversation_db_buffer_memory.py +++ b/libs/chatchat-server/chatchat/server/memory/conversation_db_buffer_memory.py @@ -1,11 +1,12 @@ import logging -from typing import Any, List, Dict +from typing import Any, Dict, List from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage +from langchain.schema import AIMessage, BaseMessage, HumanMessage, get_buffer_string from langchain.schema.language_model import BaseLanguageModel -from chatchat.server.db.repository.message_repository import filter_message + from chatchat.server.db.models.message_model import MessageModel +from chatchat.server.db.repository.message_repository import filter_message class ConversationBufferDBMemory(BaseChatMemory): @@ -22,7 +23,9 @@ def buffer(self) -> List[BaseMessage]: """String buffer of memory.""" # fetch limited messages desc, and return reversed - messages = filter_message(conversation_id=self.conversation_id, limit=self.message_limit) + messages = filter_message( + conversation_id=self.conversation_id, limit=self.message_limit + ) # 返回的记录按时间倒序,转为正序 messages = list(reversed(messages)) chat_messages: List[BaseMessage] = [] @@ -39,7 +42,9 @@ def buffer(self) -> List[BaseMessage]: pruned_memory = [] while curr_buffer_length > self.max_token_limit and chat_messages: pruned_memory.append(chat_messages.pop(0)) - curr_buffer_length = self.llm.get_num_tokens(get_buffer_string(chat_messages)) + curr_buffer_length = self.llm.get_num_tokens( + get_buffer_string(chat_messages) + ) return chat_messages @@ -70,4 +75,4 @@ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: def clear(self) -> None: """Nothing to clear, got a memory like a vault.""" - pass \ No newline at end of file + pass diff --git a/libs/chatchat-server/chatchat/server/pydantic_v2.py b/libs/chatchat-server/chatchat/server/pydantic_v2.py index 6ced351c7..6792458c5 100644 --- a/libs/chatchat-server/chatchat/server/pydantic_v2.py +++ b/libs/chatchat-server/chatchat/server/pydantic_v2.py @@ -1,3 +1,3 @@ -from pydantic import * -from pydantic.fields import FieldInfo +from pydantic import * from pydantic import typing +from pydantic.fields import FieldInfo diff --git a/libs/chatchat-server/chatchat/server/reranker/reranker.py b/libs/chatchat-server/chatchat/server/reranker/reranker.py index 0bdcda19e..0a214d07f 100644 --- a/libs/chatchat-server/chatchat/server/reranker/reranker.py +++ b/libs/chatchat-server/chatchat/server/reranker/reranker.py @@ -2,16 +2,18 @@ import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) -from typing import Any, List, Optional -from sentence_transformers import CrossEncoder 
-from typing import Optional, Sequence -from langchain_core.documents import Document +from typing import Any, List, Optional, Sequence + from langchain.callbacks.manager import Callbacks from langchain.retrievers.document_compressors.base import BaseDocumentCompressor +from langchain_core.documents import Document from pydantic import Field, PrivateAttr +from sentence_transformers import CrossEncoder + class LangchainReranker(BaseDocumentCompressor): """Document compressor that uses `Cohere Rerank API`.""" + model_name_or_path: str = Field() _model: Any = PrivateAttr() top_n: int = Field() @@ -24,17 +26,18 @@ class LangchainReranker(BaseDocumentCompressor): # activation_fct = None # apply_softmax = False - def __init__(self, - model_name_or_path: str, - top_n: int = 3, - device: str = "cuda", - max_length: int = 1024, - batch_size: int = 32, - # show_progress_bar: bool = None, - num_workers: int = 0, - # activation_fct = None, - # apply_softmax = False, - ): + def __init__( + self, + model_name_or_path: str, + top_n: int = 3, + device: str = "cuda", + max_length: int = 1024, + batch_size: int = 32, + # show_progress_bar: bool = None, + num_workers: int = 0, + # activation_fct = None, + # apply_softmax = False, + ): # self.top_n=top_n # self.model_name_or_path=model_name_or_path # self.device=device @@ -45,7 +48,9 @@ def __init__(self, # self.activation_fct=activation_fct # self.apply_softmax=apply_softmax - self._model = CrossEncoder(model_name=model_name_or_path, max_length=max_length, device=device) + self._model = CrossEncoder( + model_name=model_name_or_path, max_length=max_length, device=device + ) super().__init__( top_n=top_n, model_name_or_path=model_name_or_path, @@ -59,10 +64,10 @@ def __init__(self, ) def compress_documents( - self, - documents: Sequence[Document], - query: str, - callbacks: Optional[Callbacks] = None, + self, + documents: Sequence[Document], + query: str, + callbacks: Optional[Callbacks] = None, ) -> Sequence[Document]: """ Compress documents using Cohere's rerank API. 
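Purely as an illustrative sketch (not part of this patch): the reranker reformatted above is normally applied to retrieved chunks roughly as follows, using only the constructor and compress_documents() signatures visible in this file; the model path mirrors the "BAAI/bge-reranker-large" fallback from the __main__ block further down, and the query and documents are placeholders.

from langchain_core.documents import Document

from chatchat.server.reranker.reranker import LangchainReranker

# load the cross-encoder once; top_n / max_length follow the defaults shown above
reranker = LangchainReranker(
    model_name_or_path="BAAI/bge-reranker-large",
    top_n=3,
    device="cpu",
    max_length=1024,
)

# rerank a handful of retrieved chunks against the user query
docs = [Document(page_content=t) for t in ("工伤保险如何缴纳?", "工伤保险待遇包括哪些?")]
for doc in reranker.compress_documents(documents=docs, query="工伤保险如何办理?"):
    print(doc.page_content)
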
@@ -80,14 +85,15 @@ def compress_documents( doc_list = list(documents) _docs = [d.page_content for d in doc_list] sentence_pairs = [[query, _doc] for _doc in _docs] - results = self._model.predict(sentences=sentence_pairs, - batch_size=self.batch_size, - # show_progress_bar=self.show_progress_bar, - num_workers=self.num_workers, - # activation_fct=self.activation_fct, - # apply_softmax=self.apply_softmax, - convert_to_tensor=True - ) + results = self._model.predict( + sentences=sentence_pairs, + batch_size=self.batch_size, + # show_progress_bar=self.show_progress_bar, + num_workers=self.num_workers, + # activation_fct=self.activation_fct, + # apply_softmax=self.apply_softmax, + convert_to_tensor=True, + ) top_k = self.top_n if self.top_n < len(results) else len(results) values, indices = results.topk(top_k) @@ -100,21 +106,26 @@ def compress_documents( if __name__ == "__main__": - from chatchat.configs import (LLM_MODELS, - VECTOR_SEARCH_TOP_K, - SCORE_THRESHOLD, - TEMPERATURE, - USE_RERANKER, - RERANKER_MODEL, - RERANKER_MAX_LENGTH, - MODEL_PATH) + from chatchat.configs import ( + LLM_MODELS, + MODEL_PATH, + RERANKER_MAX_LENGTH, + RERANKER_MODEL, + SCORE_THRESHOLD, + TEMPERATURE, + USE_RERANKER, + VECTOR_SEARCH_TOP_K, + ) if USE_RERANKER: - reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large") + reranker_model_path = MODEL_PATH["reranker"].get( + RERANKER_MODEL, "BAAI/bge-reranker-large" + ) print("-----------------model path------------------") print(reranker_model_path) - reranker_model = LangchainReranker(top_n=3, - device="cpu", - max_length=RERANKER_MAX_LENGTH, - model_name_or_path=reranker_model_path - ) + reranker_model = LangchainReranker( + top_n=3, + device="cpu", + max_length=RERANKER_MAX_LENGTH, + model_name_or_path=reranker_model_path, + ) diff --git a/libs/chatchat-server/chatchat/server/utils.py b/libs/chatchat-server/chatchat/server/utils.py index 14c7b9c92..700a2fffc 100644 --- a/libs/chatchat-server/chatchat/server/utils.py +++ b/libs/chatchat-server/chatchat/server/utils.py @@ -1,36 +1,41 @@ -from fastapi import FastAPI -from pathlib import Path import asyncio +import logging +import multiprocessing as mp import os +import socket import sys -import multiprocessing as mp -from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed - -from langchain_core.embeddings import Embeddings -from langchain.tools import BaseTool -from langchain_openai.chat_models import ChatOpenAI -from langchain_openai.llms import OpenAI -import httpx -import openai +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed +from pathlib import Path from typing import ( - Optional, + Any, + Awaitable, Callable, - Generator, Dict, + Generator, List, - Any, - Awaitable, - Union, - Tuple, Literal, + Optional, + Tuple, + Union, ) -import socket -from chatchat.configs import (log_verbose, HTTPX_DEFAULT_TIMEOUT, - DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, TEMPERATURE, - MODEL_PLATFORMS) -from chatchat.server.pydantic_v2 import BaseModel, Field -import logging +import httpx +import openai +from fastapi import FastAPI +from langchain.tools import BaseTool +from langchain_core.embeddings import Embeddings +from langchain_openai.chat_models import ChatOpenAI +from langchain_openai.llms import OpenAI + +from chatchat.configs import ( + DEFAULT_EMBEDDING_MODEL, + DEFAULT_LLM_MODEL, + HTTPX_DEFAULT_TIMEOUT, + MODEL_PLATFORMS, + TEMPERATURE, + log_verbose, +) +from chatchat.server.pydantic_v2 import BaseModel, Field 
logger = logging.getLogger() @@ -42,8 +47,9 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event): except Exception as e: logging.exception(e) msg = f"Caught exception: {e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e if log_verbose else None + ) finally: # Signal the aiter to stop. event.set() @@ -59,11 +65,13 @@ def get_config_platforms() -> Dict[str, Dict]: def get_config_models( - model_name: str = None, - model_type: Literal["llm", "embed", "image", "reranking","speech2text","tts"] = None, - platform_name: str = None, + model_name: str = None, + model_type: Literal[ + "llm", "embed", "image", "reranking", "speech2text", "tts" + ] = None, + platform_name: str = None, ) -> Dict[str, Dict]: - ''' + """ 获取配置的模型列表,返回值为: {model_name: { "platform_name": xx, @@ -74,7 +82,7 @@ def get_config_models( "api_key": xx, "api_proxy": xx, }} - ''' + """ # import importlib # 不能支持重载 # from chatchat.configs import model_config @@ -95,7 +103,7 @@ def get_config_models( "reranking_models", "speech2text_models", "tts_models", - ] + ] else: model_types = [f"{model_type}_models"] @@ -114,11 +122,13 @@ def get_config_models( return result -def get_model_info(model_name: str = None, platform_name: str = None, multiple: bool = False) -> Dict: - ''' +def get_model_info( + model_name: str = None, platform_name: str = None, multiple: bool = False +) -> Dict: + """ 获取配置的模型信息,主要是 api_base_url, api_key 如果指定 multiple=True,则返回所有重名模型;否则仅返回第一个 - ''' + """ result = get_config_models(model_name=model_name, platform_name=platform_name) if len(result) > 0: if multiple: @@ -130,14 +140,14 @@ def get_model_info(model_name: str = None, platform_name: str = None, multiple: def get_ChatOpenAI( - model_name: str = DEFAULT_LLM_MODEL, - temperature: float = TEMPERATURE, - max_tokens: int = None, - streaming: bool = True, - callbacks: List[Callable] = [], - verbose: bool = True, - local_wrap: bool = False, # use local wrapped api - **kwargs: Any, + model_name: str = DEFAULT_LLM_MODEL, + temperature: float = TEMPERATURE, + max_tokens: int = None, + streaming: bool = True, + callbacks: List[Callable] = [], + verbose: bool = True, + local_wrap: bool = False, # use local wrapped api + **kwargs: Any, ) -> ChatOpenAI: model_info = get_model_info(model_name) params = dict( @@ -147,7 +157,7 @@ def get_ChatOpenAI( model_name=model_name, temperature=temperature, max_tokens=max_tokens, - **kwargs + **kwargs, ) try: if local_wrap: @@ -163,21 +173,23 @@ def get_ChatOpenAI( ) model = ChatOpenAI(**params) except Exception as e: - logger.error(f"failed to create ChatOpenAI for model: {model_name}.", exc_info=True) + logger.error( + f"failed to create ChatOpenAI for model: {model_name}.", exc_info=True + ) model = None return model def get_OpenAI( - model_name: str, - temperature: float, - max_tokens: int = None, - streaming: bool = True, - echo: bool = True, - callbacks: List[Callable] = [], - verbose: bool = True, - local_wrap: bool = False, # use local wrapped api - **kwargs: Any, + model_name: str, + temperature: float, + max_tokens: int = None, + streaming: bool = True, + echo: bool = True, + callbacks: List[Callable] = [], + verbose: bool = True, + local_wrap: bool = False, # use local wrapped api + **kwargs: Any, ) -> OpenAI: # TODO: 从API获取模型信息 model_info = get_model_info(model_name) @@ -189,7 +201,7 @@ def get_OpenAI( temperature=temperature, max_tokens=max_tokens, echo=echo, - **kwargs + **kwargs, ) try: if local_wrap: @@ -211,12 
+223,15 @@ def get_OpenAI( def get_Embeddings( - embed_model: str = DEFAULT_EMBEDDING_MODEL, - local_wrap: bool = False, # use local wrapped api + embed_model: str = DEFAULT_EMBEDDING_MODEL, + local_wrap: bool = False, # use local wrapped api ) -> Embeddings: - from langchain_openai import OpenAIEmbeddings from langchain_community.embeddings import OllamaEmbeddings - from chatchat.server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154 + from langchain_openai import OpenAIEmbeddings + + from chatchat.server.localai_embeddings import ( + LocalAIEmbeddings, + ) model_info = get_model_info(model_name=embed_model) params = dict(model=embed_model) @@ -235,37 +250,46 @@ def get_Embeddings( if model_info.get("platform_type") == "openai": return OpenAIEmbeddings(**params) elif model_info.get("platform_type") == "ollama": - return OllamaEmbeddings(base_url=model_info.get("api_base_url").replace('/v1', ''), - model=embed_model, - ) + return OllamaEmbeddings( + base_url=model_info.get("api_base_url").replace("/v1", ""), + model=embed_model, + ) else: return LocalAIEmbeddings(**params) except Exception as e: - logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True) + logger.error( + f"failed to create Embeddings for model: {embed_model}.", exc_info=True + ) -def check_embed_model(embed_model: str=DEFAULT_EMBEDDING_MODEL) -> bool: +def check_embed_model(embed_model: str = DEFAULT_EMBEDDING_MODEL) -> bool: embeddings = get_Embeddings(embed_model=embed_model) try: embeddings.embed_query("this is a test") return True except Exception as e: - logger.error(f"failed to access embed model '{embed_model}': {e}", exc_info=True) + logger.error( + f"failed to access embed model '{embed_model}': {e}", exc_info=True + ) return False def get_OpenAIClient( - platform_name: str = None, - model_name: str = None, - is_async: bool = True, + platform_name: str = None, + model_name: str = None, + is_async: bool = True, ) -> Union[openai.Client, openai.AsyncClient]: - ''' + """ construct an openai Client for specified platform or model - ''' + """ if platform_name is None: - platform_info = get_model_info(model_name=model_name, platform_name=platform_name) + platform_info = get_model_info( + model_name=model_name, platform_name=platform_name + ) if platform_info is None: - raise RuntimeError(f"cannot find configured platform for model: {model_name}") + raise RuntimeError( + f"cannot find configured platform for model: {model_name}" + ) platform_name = platform_info.get("platform_name") platform_info = get_config_platforms().get(platform_name) assert platform_info, f"cannot find configured platform: {platform_name}" @@ -337,11 +361,11 @@ class Config: "example": { "question": "工伤保险如何办理?", "response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n" - "2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n" - "3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n" - "4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n" - "5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n" - "6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。", + "2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n" + "3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n" + "4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n" + "5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n" + "6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。", "history": [ [ "工伤保险是什么?", @@ -360,9 +384,9 @@ class Config: def run_async(cor): - ''' + """ 在同步环境中运行异步代码. 
- ''' + """ try: loop = asyncio.get_event_loop() except: @@ -371,9 +395,9 @@ def run_async(cor): def iter_over_async(ait, loop=None): - ''' + """ 将异步生成器封装成同步生成器. - ''' + """ ait = ait.__aiter__() async def get_next(): @@ -397,11 +421,11 @@ async def get_next(): def MakeFastAPIOffline( - app: FastAPI, - static_dir=Path(__file__).parent / "api_server" / "static", - static_url="/static-offline-docs", - docs_url: Optional[str] = "/docs", - redoc_url: Optional[str] = "/redoc", + app: FastAPI, + static_dir=Path(__file__).parent / "api_server" / "static", + static_url="/static-offline-docs", + docs_url: Optional[str] = "/docs", + redoc_url: Optional[str] = "/redoc", ) -> None: """patch the FastAPI obj that doesn't rely on CDN for the documentation page""" from fastapi import Request @@ -417,9 +441,9 @@ def MakeFastAPIOffline( swagger_ui_oauth2_redirect_url = app.swagger_ui_oauth2_redirect_url def remove_route(url: str) -> None: - ''' + """ remove original route from app - ''' + """ index = None for i, r in enumerate(app.routes): if r.path.lower() == url.lower(): @@ -530,10 +554,10 @@ def webui_address() -> str: def get_prompt_template(type: str, name: str) -> Optional[str]: - ''' + """ 从prompt_config中加载模板内容 type: "llm_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。 - ''' + """ from chatchat.configs import PROMPT_TEMPLATES @@ -541,19 +565,20 @@ def get_prompt_template(type: str, name: str) -> Optional[str]: def set_httpx_config( - timeout: float = HTTPX_DEFAULT_TIMEOUT, - proxy: Union[str, Dict] = None, - unused_proxies: List[str] = [], + timeout: float = HTTPX_DEFAULT_TIMEOUT, + proxy: Union[str, Dict] = None, + unused_proxies: List[str] = [], ): - ''' + """ 设置httpx默认timeout。httpx默认timeout是5秒,在请求LLM回答时不够用。 将本项目相关服务加入无代理列表,避免fastchat的服务器请求错误。(windows下无效) 对于chatgpt等在线API,如要使用代理需要手动配置。搜索引擎的代理如何处置还需考虑。 - ''' + """ - import httpx import os + import httpx + httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout @@ -574,7 +599,9 @@ def set_httpx_config( os.environ[k] = v # set host to bypass proxy - no_proxy = [x.strip() for x in os.environ.get("no_proxy", "").split(",") if x.strip()] + no_proxy = [ + x.strip() for x in os.environ.get("no_proxy", "").split(",") if x.strip() + ] no_proxy += [ # do not use proxy for locahost "http://127.0.0.1", @@ -591,17 +618,18 @@ def _get_proxies(): return proxies import urllib.request + urllib.request.getproxies = _get_proxies def run_in_thread_pool( - func: Callable, - params: List[Dict] = [], + func: Callable, + params: List[Dict] = [], ) -> Generator: - ''' + """ 在线程池中批量运行任务,并将运行结果以生成器的形式返回。 请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。 - ''' + """ tasks = [] with ThreadPoolExecutor() as pool: for kwargs in params: @@ -615,17 +643,19 @@ def run_in_thread_pool( def run_in_process_pool( - func: Callable, - params: List[Dict] = [], + func: Callable, + params: List[Dict] = [], ) -> Generator: - ''' + """ 在线程池中批量运行任务,并将运行结果以生成器的形式返回。 请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。 - ''' + """ tasks = [] max_workers = None if sys.platform.startswith("win"): - max_workers = min(mp.cpu_count(), 60) # max_workers should not exceed 60 on windows + max_workers = min( + mp.cpu_count(), 60 + ) # max_workers should not exceed 60 on windows with ProcessPoolExecutor(max_workers=max_workers) as pool: for kwargs in params: tasks.append(pool.submit(func, **kwargs)) @@ -638,15 +668,15 @@ def run_in_process_pool( def get_httpx_client( - use_async: bool = False, - proxies: Union[str, 
Dict] = None, - timeout: float = HTTPX_DEFAULT_TIMEOUT, - unused_proxies: List[str] = [], - **kwargs, + use_async: bool = False, + proxies: Union[str, Dict] = None, + timeout: float = HTTPX_DEFAULT_TIMEOUT, + unused_proxies: List[str] = [], + **kwargs, ) -> Union[httpx.Client, httpx.AsyncClient]: - ''' + """ helper to get httpx client with default proxies that bypass local addesses. - ''' + """ default_proxies = { # do not use proxy for locahost "all://127.0.0.1": None, @@ -659,21 +689,34 @@ def get_httpx_client( # get proxies from system envionrent # proxy not str empty string, None, False, 0, [] or {} - default_proxies.update({ - "http://": (os.environ.get("http_proxy") - if os.environ.get("http_proxy") and len(os.environ.get("http_proxy").strip()) - else None), - "https://": (os.environ.get("https_proxy") - if os.environ.get("https_proxy") and len(os.environ.get("https_proxy").strip()) - else None), - "all://": (os.environ.get("all_proxy") - if os.environ.get("all_proxy") and len(os.environ.get("all_proxy").strip()) - else None), - }) + default_proxies.update( + { + "http://": ( + os.environ.get("http_proxy") + if os.environ.get("http_proxy") + and len(os.environ.get("http_proxy").strip()) + else None + ), + "https://": ( + os.environ.get("https_proxy") + if os.environ.get("https_proxy") + and len(os.environ.get("https_proxy").strip()) + else None + ), + "all://": ( + os.environ.get("all_proxy") + if os.environ.get("all_proxy") + and len(os.environ.get("all_proxy").strip()) + else None + ), + } + ) for host in os.environ.get("no_proxy", "").split(","): if host := host.strip(): # default_proxies.update({host: None}) # Origin code - default_proxies.update({'all://' + host: None}) # PR 1838 fix, if not add 'all://', httpx will raise error + default_proxies.update( + {"all://" + host: None} + ) # PR 1838 fix, if not add 'all://', httpx will raise error # merge default proxies with user provided proxies if isinstance(proxies, str): @@ -692,9 +735,9 @@ def get_httpx_client( def get_server_configs() -> Dict: - ''' + """ 获取configs中的原始配置项,供前端使用 - ''' + """ _custom = { "api_address": api_address(), } @@ -703,12 +746,13 @@ def get_server_configs() -> Dict: def get_temp_dir(id: str = None) -> Tuple[str, str]: - ''' + """ 创建一个临时目录,返回(路径,文件夹名称) - ''' - from chatchat.configs import BASE_TEMP_DIR + """ import uuid + from chatchat.configs import BASE_TEMP_DIR + if id is not None: # 如果指定的临时目录已存在,直接返回 path = os.path.join(BASE_TEMP_DIR, id) if os.path.isdir(path): @@ -723,26 +767,37 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]: # 动态更新知识库信息 def update_search_local_knowledgebase_tool(): import re + from chatchat.server.agent.tools_factory import tools_registry from chatchat.server.db.repository.knowledge_base_repository import list_kbs_from_db - kbs=list_kbs_from_db() + + kbs = list_kbs_from_db() template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]." 
- KB_info_str = '\n'.join([f"{kb.kb_name}: {kb.kb_info}" for kb in kbs]) - KB_name_info_str = '\n'.join([f"{kb.kb_name}" for kb in kbs]) + KB_info_str = "\n".join([f"{kb.kb_name}: {kb.kb_info}" for kb in kbs]) + KB_name_info_str = "\n".join([f"{kb.kb_name}" for kb in kbs]) template_knowledge = template.format(KB_info=KB_info_str, key=KB_name_info_str) - search_local_knowledgebase_tool=tools_registry._TOOLS_REGISTRY.get("search_local_knowledgebase") + search_local_knowledgebase_tool = tools_registry._TOOLS_REGISTRY.get( + "search_local_knowledgebase" + ) if search_local_knowledgebase_tool: - search_local_knowledgebase_tool.description = " ".join(re.split(r"\n+\s*", template_knowledge)) - search_local_knowledgebase_tool.args["database"]["choices"]=[kb.kb_name for kb in kbs] + search_local_knowledgebase_tool.description = " ".join( + re.split(r"\n+\s*", template_knowledge) + ) + search_local_knowledgebase_tool.args["database"]["choices"] = [ + kb.kb_name for kb in kbs + ] def get_tool(name: str = None) -> Union[BaseTool, Dict[str, BaseTool]]: import importlib + from chatchat.server.agent import tools_factory + importlib.reload(tools_factory) from chatchat.server.agent.tools_factory import tools_registry + update_search_local_knowledgebase_tool() if name is None: return tools_registry._TOOLS_REGISTRY @@ -752,17 +807,18 @@ def get_tool(name: str = None) -> Union[BaseTool, Dict[str, BaseTool]]: def get_tool_config(name: str = None) -> Dict: import importlib + # TODO 因为使用了变量更新,不支持重载 # from chatchat.configs import model_config # importlib.reload(model_config) from chatchat.configs import TOOL_CONFIG + if name is None: return TOOL_CONFIG else: return TOOL_CONFIG.get(name, {}) - def is_port_in_use(port): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - return sock.connect_ex(('localhost', port)) == 0 \ No newline at end of file + return sock.connect_ex(("localhost", port)) == 0 diff --git a/libs/chatchat-server/chatchat/server/webui_allinone_stale.py b/libs/chatchat-server/chatchat/server/webui_allinone_stale.py index 1aa57d07f..607c09f61 100644 --- a/libs/chatchat-server/chatchat/server/webui_allinone_stale.py +++ b/libs/chatchat-server/chatchat/server/webui_allinone_stale.py @@ -15,55 +15,71 @@ python webui_alline.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB """ +import os +import subprocess + import streamlit as st -from chatchat.webui_pages.utils import * from streamlit_option_menu import option_menu -from chatchat.webui_pages import * -import os -from chatchat.server.llm_api_stale import string_args,launch_all,controller_args,worker_args,server_args,LOG_PATH -from chatchat.server.api_allinone_stale import parser, api_args -import subprocess +from chatchat.server.api_allinone_stale import api_args, parser +from chatchat.server.llm_api_stale import ( + LOG_PATH, + controller_args, + launch_all, + server_args, + string_args, + worker_args, +) +from chatchat.webui_pages import * +from chatchat.webui_pages.utils import * -parser.add_argument("--use-remote-api",action="store_true") -parser.add_argument("--nohup",action="store_true") -parser.add_argument("--server.port",type=int,default=8501) -parser.add_argument("--theme.base",type=str,default='"light"') -parser.add_argument("--theme.primaryColor",type=str,default='"#165dff"') -parser.add_argument("--theme.secondaryBackgroundColor",type=str,default='"#f5f5f5"') -parser.add_argument("--theme.textColor",type=str,default='"#000000"') -web_args = 
["server.port","theme.base","theme.primaryColor","theme.secondaryBackgroundColor","theme.textColor"] +parser.add_argument("--use-remote-api", action="store_true") +parser.add_argument("--nohup", action="store_true") +parser.add_argument("--server.port", type=int, default=8501) +parser.add_argument("--theme.base", type=str, default='"light"') +parser.add_argument("--theme.primaryColor", type=str, default='"#165dff"') +parser.add_argument("--theme.secondaryBackgroundColor", type=str, default='"#f5f5f5"') +parser.add_argument("--theme.textColor", type=str, default='"#000000"') +web_args = [ + "server.port", + "theme.base", + "theme.primaryColor", + "theme.secondaryBackgroundColor", + "theme.textColor", +] -def launch_api(args,args_list=api_args,log_name=None): +def launch_api(args, args_list=api_args, log_name=None): print("Launching api ...") print("启动API服务...") if not log_name: log_name = f"{LOG_PATH}api_{args.api_host}_{args.api_port}" print(f"logs on api are written in {log_name}") print(f"API日志位于{log_name}下,如启动异常请查看日志") - args_str = string_args(args,args_list) + args_str = string_args(args, args_list) api_sh = "python server/{script} {args_str} >{log_name}.log 2>&1 &".format( - script="api.py",args_str=args_str,log_name=log_name) + script="api.py", args_str=args_str, log_name=log_name + ) subprocess.run(api_sh, shell=True, check=True) print("launch api done!") print("启动API服务完毕.") -def launch_webui(args,args_list=web_args,log_name=None): + +def launch_webui(args, args_list=web_args, log_name=None): print("Launching webui...") print("启动webui服务...") if not log_name: log_name = f"{LOG_PATH}webui" - args_str = string_args(args,args_list) + args_str = string_args(args, args_list) if args.nohup: print(f"logs on api are written in {log_name}") print(f"webui服务日志位于{log_name}下,如启动异常请查看日志") webui_sh = "streamlit run webui.py {args_str} >{log_name}.log 2>&1 &".format( - args_str=args_str,log_name=log_name) + args_str=args_str, log_name=log_name + ) else: - webui_sh = "streamlit run webui.py {args_str}".format( - args_str=args_str) + webui_sh = "streamlit run webui.py {args_str}".format(args_str=args_str) subprocess.run(webui_sh, shell=True, check=True) print("launch webui done!") print("启动webui服务完毕.") @@ -71,13 +87,20 @@ def launch_webui(args,args_list=web_args,log_name=None): if __name__ == "__main__": print("Starting webui_allineone.py, it would take a while, please be patient....") - print(f"开始启动webui_allinone,启动LLM服务需要约3-10分钟,请耐心等待,如长时间未启动,请到{LOG_PATH}下查看日志...") + print( + f"开始启动webui_allinone,启动LLM服务需要约3-10分钟,请耐心等待,如长时间未启动,请到{LOG_PATH}下查看日志..." 
+ ) args = parser.parse_args() - print("*"*80) + print("*" * 80) if not args.use_remote_api: - launch_all(args=args,controller_args=controller_args,worker_args=worker_args,server_args=server_args) - launch_api(args=args,args_list=api_args) - launch_webui(args=args,args_list=web_args) + launch_all( + args=args, + controller_args=controller_args, + worker_args=worker_args, + server_args=server_args, + ) + launch_api(args=args, args_list=api_args) + launch_webui(args=args, args_list=web_args) print("Start webui_allinone.py done!") - print("感谢耐心等待,启动webui_allinone完毕。") \ No newline at end of file + print("感谢耐心等待,启动webui_allinone完毕。") diff --git a/libs/chatchat-server/chatchat/startup.py b/libs/chatchat-server/chatchat/startup.py index 596e18893..d14a270ef 100644 --- a/libs/chatchat-server/chatchat/startup.py +++ b/libs/chatchat-server/chatchat/startup.py @@ -1,13 +1,13 @@ import asyncio +import logging +import logging.config import multiprocessing -from contextlib import asynccontextmanager import multiprocessing as mp import os - -import logging -import logging.config import sys +from contextlib import asynccontextmanager from multiprocessing import Process + logger = logging.getLogger() # 设置numexpr最大线程数,默认为CPU核心数 @@ -19,9 +19,10 @@ except: pass -from fastapi import FastAPI import argparse -from typing import List, Dict +from typing import Dict, List + +from fastapi import FastAPI def _set_app_event(app: FastAPI, started_event: mp.Event = None): @@ -35,15 +36,19 @@ async def lifespan(app: FastAPI): def run_init_server( - model_platforms_shard: Dict, - started_event: mp.Event = None, - model_providers_cfg_path: str = None, - provider_host: str = None, - provider_port: int = None): + model_platforms_shard: Dict, + started_event: mp.Event = None, + model_providers_cfg_path: str = None, + provider_host: str = None, + provider_port: int = None, +): + from chatchat.configs import ( + MODEL_PROVIDERS_CFG_HOST, + MODEL_PROVIDERS_CFG_PATH_CONFIG, + MODEL_PROVIDERS_CFG_PORT, + ) from chatchat.init_server import init_server - from chatchat.configs import (MODEL_PROVIDERS_CFG_PATH_CONFIG, - MODEL_PROVIDERS_CFG_HOST, - MODEL_PROVIDERS_CFG_PORT) + if model_providers_cfg_path is None: model_providers_cfg_path = MODEL_PROVIDERS_CFG_PATH_CONFIG if provider_host is None: @@ -51,28 +56,30 @@ def run_init_server( if provider_port is None: provider_port = MODEL_PROVIDERS_CFG_PORT - init_server(model_platforms_shard=model_platforms_shard, - started_event=started_event, - model_providers_cfg_path=model_providers_cfg_path, - provider_host=provider_host, - provider_port=provider_port) + init_server( + model_platforms_shard=model_platforms_shard, + started_event=started_event, + model_providers_cfg_path=model_providers_cfg_path, + provider_host=provider_host, + provider_port=provider_port, + ) -def run_api_server(model_platforms_shard: Dict, - started_event: mp.Event = None, - run_mode: str = None): - from chatchat.server.api_server.server_app import create_app +def run_api_server( + model_platforms_shard: Dict, started_event: mp.Event = None, run_mode: str = None +): import uvicorn - from chatchat.server.utils import set_httpx_config - from chatchat.configs import MODEL_PLATFORMS, API_SERVER from model_providers.core.utils.utils import ( get_config_dict, get_log_file, get_timestamp_ms, ) - from chatchat.configs import LOG_PATH - MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms']) + from chatchat.configs import API_SERVER, LOG_PATH, MODEL_PLATFORMS + from chatchat.server.api_server.server_app import 
create_app + from chatchat.server.utils import set_httpx_config + + MODEL_PLATFORMS.extend(model_platforms_shard["provider_platforms"]) logger.info(f"Api MODEL_PLATFORMS: {MODEL_PLATFORMS}") set_httpx_config() app = create_app(run_mode=run_mode) @@ -84,65 +91,99 @@ def run_api_server(model_platforms_shard: Dict, logging_conf = get_config_dict( "INFO", get_log_file(log_path=LOG_PATH, sub_dir=f"run_api_server_{get_timestamp_ms()}"), - - 1024*1024*1024*3, - 1024*1024*1024*3, + 1024 * 1024 * 1024 * 3, + 1024 * 1024 * 1024 * 3, ) logging.config.dictConfig(logging_conf) # type: ignore uvicorn.run(app, host=host, port=port) -def run_webui(model_platforms_shard: Dict, - started_event: mp.Event = None, run_mode: str = None): +def run_webui( + model_platforms_shard: Dict, started_event: mp.Event = None, run_mode: str = None +): import sys - from chatchat.server.utils import set_httpx_config - from chatchat.configs import MODEL_PLATFORMS, WEBUI_SERVER + from model_providers.core.utils.utils import ( get_config_dict, get_log_file, get_timestamp_ms, ) - from chatchat.configs import LOG_PATH - if model_platforms_shard.get('provider_platforms'): - MODEL_PLATFORMS.extend(model_platforms_shard.get('provider_platforms')) + from chatchat.configs import LOG_PATH, MODEL_PLATFORMS, WEBUI_SERVER + from chatchat.server.utils import set_httpx_config + + if model_platforms_shard.get("provider_platforms"): + MODEL_PLATFORMS.extend(model_platforms_shard.get("provider_platforms")) logger.info(f"Webui MODEL_PLATFORMS: {MODEL_PLATFORMS}") set_httpx_config() host = WEBUI_SERVER["host"] port = WEBUI_SERVER["port"] - script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'webui.py') - - flag_options = {'server_address': host, - 'server_port': port, - 'theme_base': 'light', - 'theme_primaryColor': '#165dff', - 'theme_secondaryBackgroundColor': '#f5f5f5', - 'theme_textColor': '#000000', - 'global_disableWatchdogWarning': None, - 'global_disableWidgetStateDuplicationWarning': None, - 'global_showWarningOnDirectExecution': None, - 'global_developmentMode': None, 'global_logLevel': None, 'global_unitTest': None, - 'global_suppressDeprecationWarnings': None, 'global_minCachedMessageSize': None, - 'global_maxCachedMessageAge': None, 'global_storeCachedForwardMessagesInMemory': None, - 'global_dataFrameSerialization': None, 'logger_level': None, 'logger_messageFormat': None, - 'logger_enableRich': None, 'client_caching': None, 'client_displayEnabled': None, - 'client_showErrorDetails': None, 'client_toolbarMode': None, 'client_showSidebarNavigation': None, - 'runner_magicEnabled': None, 'runner_installTracer': None, 'runner_fixMatplotlib': None, - 'runner_postScriptGC': None, 'runner_fastReruns': None, - 'runner_enforceSerializableSessionState': None, 'runner_enumCoercion': None, - 'server_folderWatchBlacklist': None, 'server_fileWatcherType': None, 'server_headless': None, - 'server_runOnSave': None, 'server_allowRunOnSave': None, 'server_scriptHealthCheckEnabled': None, - 'server_baseUrlPath': None, 'server_enableCORS': None, 'server_enableXsrfProtection': None, - 'server_maxUploadSize': None, 'server_maxMessageSize': None, 'server_enableArrowTruncation': None, - 'server_enableWebsocketCompression': None, 'server_enableStaticServing': None, - 'browser_serverAddress': None, 'browser_gatherUsageStats': None, 'browser_serverPort': None, - 'server_sslCertFile': None, 'server_sslKeyFile': None, 'ui_hideTopBar': None, - 'ui_hideSidebarNav': None, 'magic_displayRootDocString': None, - 
'magic_displayLastExprIfNoSemicolon': None, 'deprecation_showfileUploaderEncoding': None, - 'deprecation_showImageFormat': None, 'deprecation_showPyplotGlobalUse': None, - 'theme_backgroundColor': None, 'theme_font': None} + script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "webui.py") + + flag_options = { + "server_address": host, + "server_port": port, + "theme_base": "light", + "theme_primaryColor": "#165dff", + "theme_secondaryBackgroundColor": "#f5f5f5", + "theme_textColor": "#000000", + "global_disableWatchdogWarning": None, + "global_disableWidgetStateDuplicationWarning": None, + "global_showWarningOnDirectExecution": None, + "global_developmentMode": None, + "global_logLevel": None, + "global_unitTest": None, + "global_suppressDeprecationWarnings": None, + "global_minCachedMessageSize": None, + "global_maxCachedMessageAge": None, + "global_storeCachedForwardMessagesInMemory": None, + "global_dataFrameSerialization": None, + "logger_level": None, + "logger_messageFormat": None, + "logger_enableRich": None, + "client_caching": None, + "client_displayEnabled": None, + "client_showErrorDetails": None, + "client_toolbarMode": None, + "client_showSidebarNavigation": None, + "runner_magicEnabled": None, + "runner_installTracer": None, + "runner_fixMatplotlib": None, + "runner_postScriptGC": None, + "runner_fastReruns": None, + "runner_enforceSerializableSessionState": None, + "runner_enumCoercion": None, + "server_folderWatchBlacklist": None, + "server_fileWatcherType": None, + "server_headless": None, + "server_runOnSave": None, + "server_allowRunOnSave": None, + "server_scriptHealthCheckEnabled": None, + "server_baseUrlPath": None, + "server_enableCORS": None, + "server_enableXsrfProtection": None, + "server_maxUploadSize": None, + "server_maxMessageSize": None, + "server_enableArrowTruncation": None, + "server_enableWebsocketCompression": None, + "server_enableStaticServing": None, + "browser_serverAddress": None, + "browser_gatherUsageStats": None, + "browser_serverPort": None, + "server_sslCertFile": None, + "server_sslKeyFile": None, + "ui_hideTopBar": None, + "ui_hideSidebarNav": None, + "magic_displayRootDocString": None, + "magic_displayLastExprIfNoSemicolon": None, + "deprecation_showfileUploaderEncoding": None, + "deprecation_showImageFormat": None, + "deprecation_showPyplotGlobalUse": None, + "theme_backgroundColor": None, + "theme_font": None, + } args = [] if run_mode == "lite": @@ -157,14 +198,12 @@ def run_webui(model_platforms_shard: Dict, except ImportError: from streamlit import bootstrap - logging_conf = get_config_dict( "INFO", get_log_file(log_path=LOG_PATH, sub_dir=f"run_webui_{get_timestamp_ms()}"), - - 1024*1024*1024*3, - 1024*1024*1024*3, - ) + 1024 * 1024 * 1024 * 3, + 1024 * 1024 * 1024 * 3, + ) logging.config.dictConfig(logging_conf) # type: ignore bootstrap.load_config_options(flag_options=flag_options) bootstrap.run(script_dir, False, args, flag_options) @@ -214,9 +253,12 @@ def parse_args() -> argparse.ArgumentParser: def dump_server_info(after_start=False, args=None): import platform + import langchain + + from chatchat.configs import DEFAULT_EMBEDDING_MODEL, TEXT_SPLITTER_NAME, VERSION from chatchat.server.utils import api_address, webui_address - from chatchat.configs import VERSION, TEXT_SPLITTER_NAME, DEFAULT_EMBEDDING_MODEL + print("\n") print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) print(f"操作系统:{platform.platform()}.") @@ -236,7 +278,8 @@ def dump_server_info(after_start=False, args=None): print( f" Chatchat 
Model providers Server: model_providers_cfg_path_config:{MODEL_PROVIDERS_CFG_PATH_CONFIG}\n" f" provider_host:{MODEL_PROVIDERS_CFG_HOST}\n" - f" provider_host:{MODEL_PROVIDERS_CFG_HOST}\n") + f" provider_host:{MODEL_PROVIDERS_CFG_HOST}\n" + ) print(f" Chatchat Api Server: {api_address()}") if args.webui: @@ -246,23 +289,27 @@ def dump_server_info(after_start=False, args=None): async def start_main_server(): - import time import signal - from chatchat.configs import LOG_PATH + import time + from model_providers.core.utils.utils import ( get_config_dict, get_log_file, get_timestamp_ms, ) + from chatchat.configs import LOG_PATH + logging_conf = get_config_dict( "INFO", - get_log_file(log_path=LOG_PATH, sub_dir=f"start_main_server_{get_timestamp_ms()}"), - - 1024*1024*1024*3, - 1024*1024*1024*3, - ) + get_log_file( + log_path=LOG_PATH, sub_dir=f"start_main_server_{get_timestamp_ms()}" + ), + 1024 * 1024 * 1024 * 3, + 1024 * 1024 * 1024 * 3, + ) logging.config.dictConfig(logging_conf) # type: ignore + def handler(signalname): """ Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed. @@ -319,7 +366,10 @@ def process_count(): process = Process( target=run_init_server, name=f"Model providers Server", - kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=model_providers_started), + kwargs=dict( + model_platforms_shard=model_platforms_shard, + started_event=model_providers_started, + ), daemon=True, ) processes["model_providers"] = process @@ -328,7 +378,11 @@ def process_count(): process = Process( target=run_api_server, name=f"API Server", - kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=api_started, run_mode=run_mode), + kwargs=dict( + model_platforms_shard=model_platforms_shard, + started_event=api_started, + run_mode=run_mode, + ), daemon=False, ) processes["api"] = process @@ -338,7 +392,11 @@ def process_count(): process = Process( target=run_webui, name=f"WEBUI Server", - kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=webui_started, run_mode=run_mode), + kwargs=dict( + model_platforms_shard=model_platforms_shard, + started_event=webui_started, + run_mode=run_mode, + ), daemon=True, ) processes["webui"] = process @@ -375,7 +433,6 @@ def process_count(): logger.error(e) logger.warning("Caught KeyboardInterrupt! 
Setting stop event...") finally: - for p in processes.values(): logger.warning("Sending SIGKILL to %s", p) # Queues and other inter-process communication primitives can break when @@ -396,8 +453,9 @@ def main(): cwd = os.getcwd() sys.path.append(cwd) multiprocessing.freeze_support() - print("cwd:"+cwd) + print("cwd:" + cwd) from chatchat.server.knowledge_base.migrate import create_tables + create_tables() if sys.version_info < (3, 10): loop = asyncio.get_event_loop() diff --git a/libs/chatchat-server/chatchat/webui.py b/libs/chatchat-server/chatchat/webui.py index 54241044b..aa0fdc609 100644 --- a/libs/chatchat-server/chatchat/webui.py +++ b/libs/chatchat-server/chatchat/webui.py @@ -5,26 +5,25 @@ from chatchat.configs import VERSION from chatchat.server.utils import api_address -from chatchat.webui_pages.utils import * -from chatchat.webui_pages.dialogue.dialogue import dialogue_page, chat_box +from chatchat.webui_pages.dialogue.dialogue import chat_box, dialogue_page from chatchat.webui_pages.knowledge_base.knowledge_base import knowledge_base_page - +from chatchat.webui_pages.utils import * api = ApiRequest(base_url=api_address()) if __name__ == "__main__": - is_lite = "lite" in sys.argv # TODO: remove lite mode + is_lite = "lite" in sys.argv # TODO: remove lite mode st.set_page_config( "Langchain-Chatchat WebUI", get_img_base64("chatchat_icon_blue_square_v2.png"), initial_sidebar_state="expanded", menu_items={ - 'Get Help': 'https://github.com/chatchat-space/Langchain-Chatchat', - 'Report a bug': "https://github.com/chatchat-space/Langchain-Chatchat/issues", - 'About': f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}!""" + "Get Help": "https://github.com/chatchat-space/Langchain-Chatchat", + "Report a bug": "https://github.com/chatchat-space/Langchain-Chatchat/issues", + "About": f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}!""", }, - layout="centered" + layout="centered", ) # use the following code to set the app to wide mode and the html markdown to increase the sidebar width @@ -46,8 +45,7 @@ with st.sidebar: st.image( - get_img_base64('logo-long-chatchat-trans-v2.png'), - use_column_width=True + get_img_base64("logo-long-chatchat-trans-v2.png"), use_column_width=True ) st.caption( f"""

当前版本:{VERSION}
""", @@ -60,7 +58,7 @@ sac.MenuItem("知识库管理", icon="hdd-stack"), ], key="selected_page", - open_index=0 + open_index=0, ) sac.divider() diff --git a/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py b/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py index e12f08695..67d2adcfb 100644 --- a/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py +++ b/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py @@ -1,9 +1,9 @@ import base64 -from datetime import datetime import os -from urllib.parse import urlencode import uuid -from typing import List, Dict +from datetime import datetime +from typing import Dict, List +from urllib.parse import urlencode # from audio_recorder_streamlit import audio_recorder import openai @@ -12,27 +12,34 @@ from streamlit_chatbox import * from streamlit_extras.bottom_container import bottom -from chatchat.configs import (LLM_MODEL_CONFIG, TEMPERATURE, MODEL_PLATFORMS, DEFAULT_LLM_MODEL, - DEFAULT_EMBEDDING_MODEL) +from chatchat.configs import ( + DEFAULT_EMBEDDING_MODEL, + DEFAULT_LLM_MODEL, + LLM_MODEL_CONFIG, + MODEL_PLATFORMS, + TEMPERATURE, +) from chatchat.server.callback_handler.agent_callback_handler import AgentStatus from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId from chatchat.server.utils import MsgType, get_config_models -from chatchat.webui_pages.utils import * from chatchat.webui_pages.dialogue.utils import process_files +from chatchat.webui_pages.utils import * -chat_box = ChatBox( - assistant_avatar=get_img_base64("chatchat_icon_blue_square_v2.png") -) +chat_box = ChatBox(assistant_avatar=get_img_base64("chatchat_icon_blue_square_v2.png")) -def save_session(conv_name:str=None): +def save_session(conv_name: str = None): """save session state to chat context""" - chat_box.context_from_session(conv_name, exclude=["selected_page", "prompt", "cur_conv_name"]) + chat_box.context_from_session( + conv_name, exclude=["selected_page", "prompt", "cur_conv_name"] + ) -def restore_session(conv_name:str=None): +def restore_session(conv_name: str = None): """restore sesstion state from chat context""" - chat_box.context_to_session(conv_name, exclude=["selected_page", "prompt", "cur_conv_name"]) + chat_box.context_to_session( + conv_name, exclude=["selected_page", "prompt", "cur_conv_name"] + ) def rerun(): @@ -43,14 +50,18 @@ def rerun(): st.rerun() -def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]: +def get_messages_history( + history_len: int, content_in_expander: bool = False +) -> List[Dict]: """ 返回消息历史。 content_in_expander控制是否返回expander元素中的内容,一般导出的时候可以选上,传入LLM的history不需要 """ def filter(msg): - content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]] + content = [ + x for x in msg["elements"] if x._output_method in ["markdown", "text"] + ] if not content_in_expander: content = [x for x in content if not x._in_expander] content = [x.content for x in content] @@ -85,7 +96,12 @@ def add_conv(name: str = ""): break i += 1 if name in conv_names: - sac.alert("创建新会话出错", f"该会话名称 “{name}” 已存在", color="error", closable=True) + sac.alert( + "创建新会话出错", + f"该会话名称 “{name}” 已存在", + color="error", + closable=True, + ) else: chat_box.use_chat_name(name) st.session_state["cur_conv_name"] = name @@ -95,9 +111,13 @@ def del_conv(name: str = None): conv_names = chat_box.get_chat_names() name = name or chat_box.cur_chat_name if len(conv_names) == 1: - sac.alert("删除会话出错", f"这是最后一个会话,无法删除", color="error", closable=True) + sac.alert( + 
"删除会话出错", f"这是最后一个会话,无法删除", color="error", closable=True + ) elif not name or name not in conv_names: - sac.alert("删除会话出错", f"无效的会话名称:“{name}”", color="error", closable=True) + sac.alert( + "删除会话出错", f"无效的会话名称:“{name}”", color="error", closable=True + ) else: chat_box.del_chat_name(name) restore_session() @@ -114,8 +134,8 @@ def list_tools(_api: ApiRequest): def dialogue_page( - api: ApiRequest, - is_lite: bool = False, + api: ApiRequest, + is_lite: bool = False, ): ctx = chat_box.context ctx.setdefault("uid", uuid.uuid4().hex) @@ -140,7 +160,11 @@ def llm_model_setting(): cols = st.columns(3) platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS] platform = cols[0].selectbox("选择模型平台", platforms, key="platform") - llm_models = list(get_config_models(model_type="llm", platform_name=None if platform == "所有" else platform)) + llm_models = list( + get_config_models( + model_type="llm", platform_name=None if platform == "所有" else platform + ) + ) llm_model = cols[1].selectbox("选择LLM模型", llm_models, key="llm_model") temperature = cols[2].slider("Temperature", 0.0, 1.0, key="temperature") system_message = st.text_area("System Message:", key="system_message") @@ -160,23 +184,36 @@ def rename_conversation(): tab1, tab2 = st.tabs(["工具设置", "会话设置"]) with tab1: - use_agent = st.checkbox("启用Agent", help="请确保选择的模型具备Agent能力", key="use_agent") + use_agent = st.checkbox( + "启用Agent", help="请确保选择的模型具备Agent能力", key="use_agent" + ) # 选择工具 tools = list_tools(api) tool_names = ["None"] + list(tools) if use_agent: # selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", # check_all=True, key="selected_tools") - selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"], - key="selected_tools") + selected_tools = st.multiselect( + "选择工具", + list(tools), + format_func=lambda x: tools[x]["title"], + key="selected_tools", + ) else: # selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", # key="selected_tool") - selected_tool = st.selectbox("选择工具", tool_names, - format_func=lambda x: tools.get(x, {"title": "None"})["title"], - key="selected_tool") + selected_tool = st.selectbox( + "选择工具", + tool_names, + format_func=lambda x: tools.get(x, {"title": "None"})["title"], + key="selected_tool", + ) selected_tools = [selected_tool] - selected_tool_configs = {name: tool["config"] for name, tool in tools.items() if name in selected_tools} + selected_tool_configs = { + name: tool["config"] + for name, tool in tools.items() + if name in selected_tools + } if "None" in selected_tools: selected_tools.remove("None") @@ -190,11 +227,17 @@ def rename_conversation(): tool_input[k] = st.selectbox(v["title"], choices) else: if v["type"] == "integer": - tool_input[k] = st.slider(v["title"], value=v.get("default")) + tool_input[k] = st.slider( + v["title"], value=v.get("default") + ) elif v["type"] == "number": - tool_input[k] = st.slider(v["title"], value=v.get("default"), step=0.1) + tool_input[k] = st.slider( + v["title"], value=v.get("default"), step=0.1 + ) else: - tool_input[k] = st.text_input(v["title"], v.get("default")) + tool_input[k] = st.text_input( + v["title"], v.get("default") + ) # uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False) # files_upload = process_files(files=[uploaded_file]) if uploaded_file else None @@ -209,7 +252,13 @@ def on_conv_change(): print(conversation_name, st.session_state.cur_conv_name) save_session(conversation_name) 
restore_session(st.session_state.cur_conv_name) - conversation_name = sac.buttons(conv_names, label="当前会话:", key="cur_conv_name", on_change=on_conv_change, ) + + conversation_name = sac.buttons( + conv_names, + label="当前会话:", + key="cur_conv_name", + on_change=on_conv_change, + ) chat_box.use_chat_name(conversation_name) conversation_id = chat_box.context["uid"] if cols[0].button("新建", on_click=add_conv): @@ -249,7 +298,9 @@ def on_conv_change(): chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key] llm_model = ctx.get("llm_model") if llm_model is not None: - chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {}) + chat_model_config["llm_model"][llm_model] = LLM_MODEL_CONFIG["llm_model"].get( + llm_model, {} + ) # chat input with bottom(): @@ -263,21 +314,27 @@ def on_conv_change(): prompt = cols[2].chat_input(chat_input_placeholder, key="prompt") if prompt: history = get_messages_history( - chat_model_config["llm_model"].get(next(iter(chat_model_config["llm_model"])), {}).get("history_len", 1) + chat_model_config["llm_model"] + .get(next(iter(chat_model_config["llm_model"])), {}) + .get("history_len", 1) ) chat_box.user_say(prompt) if files_upload: if files_upload["images"]: - st.markdown(f'', - unsafe_allow_html=True) + st.markdown( + f'', + unsafe_allow_html=True, + ) elif files_upload["videos"]: st.markdown( f'', - unsafe_allow_html=True) + unsafe_allow_html=True, + ) elif files_upload["audios"]: st.markdown( f'', - unsafe_allow_html=True) + unsafe_allow_html=True, + ) chat_box.ai_say("正在思考...") text = "" @@ -302,12 +359,12 @@ def on_conv_change(): tool_input=tool_input, ) for d in client.chat.completions.create( - messages=messages, - model=llm_model, - stream=True, - tools=tools or openai.NOT_GIVEN, - tool_choice=tool_choice, - extra_body=extra_body, + messages=messages, + model=llm_model, + stream=True, + tools=tools or openai.NOT_GIVEN, + tool_choice=tool_choice, + extra_body=extra_body, ): # from pprint import pprint # pprint(d) @@ -328,10 +385,14 @@ def on_conv_change(): text = d.choices[0].delta.content or "" elif d.status == AgentStatus.llm_new_token: text += d.choices[0].delta.content or "" - chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata) + chat_box.update_msg( + text.replace("\n", "\n\n"), streaming=True, metadata=metadata + ) elif d.status == AgentStatus.llm_end: text += d.choices[0].delta.content or "" - chat_box.update_msg(text.replace("\n", "\n\n"), streaming=False, metadata=metadata) + chat_box.update_msg( + text.replace("\n", "\n\n"), streaming=False, metadata=metadata + ) # tool 的输出与 llm 输出重复了 # elif d.status == AgentStatus.tool_start: # formatted_data = { @@ -363,22 +424,42 @@ def on_conv_change(): for inum, doc in enumerate(docs): doc = DocumentWithVSId.parse_obj(doc) filename = doc.metadata.get("source") - parameters = urlencode({"knowledge_base_name": d.tool_output.get("knowledge_base"), "file_name": filename}) - url = f"{api.base_url}/knowledge_base/download_doc?" + parameters + parameters = urlencode( + { + "knowledge_base_name": d.tool_output.get( + "knowledge_base" + ), + "file_name": filename, + } + ) + url = ( + f"{api.base_url}/knowledge_base/download_doc?" 
+ parameters + ) ref = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n""" source_documents.append(ref) context = "\n".join(source_documents) - chat_box.insert_msg(Markdown(context, in_expander=True, state="complete", title="参考资料")) + chat_box.insert_msg( + Markdown( + context, + in_expander=True, + state="complete", + title="参考资料", + ) + ) chat_box.insert_msg("") else: text += d.choices[0].delta.content or "" - chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata) + chat_box.update_msg( + text.replace("\n", "\n\n"), streaming=True, metadata=metadata + ) chat_box.update_msg(text, streaming=False, metadata=metadata) if os.path.exists("tmp/image.jpg"): with open("tmp/image.jpg", "rb") as image_file: encoded_string = base64.b64encode(image_file.read()).decode() - img_tag = f'' + img_tag = ( + f'' + ) st.markdown(img_tag, unsafe_allow_html=True) os.remove("tmp/image.jpg") # chat_box.show_feedback(**feedback_kwargs, @@ -416,8 +497,8 @@ def on_conv_change(): cols = st.columns(2) export_btn = cols[0] if cols[1].button( - "清空对话", - use_container_width=True, + "清空对话", + use_container_width=True, ): chat_box.reset_history() rerun() diff --git a/libs/chatchat-server/chatchat/webui_pages/dialogue/utils.py b/libs/chatchat-server/chatchat/webui_pages/dialogue/utils.py index 3d17ba90f..114500db2 100644 --- a/libs/chatchat-server/chatchat/webui_pages/dialogue/utils.py +++ b/libs/chatchat-server/chatchat/webui_pages/dialogue/utils.py @@ -1,8 +1,9 @@ -import streamlit as st import base64 import os from io import BytesIO +import streamlit as st + def encode_file_to_base64(file): # 将文件内容转换为 Base64 编码 @@ -17,15 +18,15 @@ def process_files(files): file_extension = os.path.splitext(file.name)[1].lower() # 检测文件类型并进行相应的处理 - if file_extension in ['.mp4', '.avi']: + if file_extension in [".mp4", ".avi"]: # 视频文件处理 video_base64 = encode_file_to_base64(file) result["videos"].append(video_base64) - elif file_extension in ['.jpg', '.png', '.jpeg']: + elif file_extension in [".jpg", ".png", ".jpeg"]: # 图像文件处理 image_base64 = encode_file_to_base64(file) result["images"].append(image_base64) - elif file_extension in ['.mp3', '.wav', '.ogg', '.flac']: + elif file_extension in [".mp3", ".wav", ".ogg", ".flac"]: # 音频文件处理 audio_base64 = encode_file_to_base64(file) result["audios"].append(audio_base64) diff --git a/libs/chatchat-server/chatchat/webui_pages/knowledge_base/knowledge_base.py b/libs/chatchat-server/chatchat/webui_pages/knowledge_base/knowledge_base.py index 7095f091d..ee6fe2e1a 100644 --- a/libs/chatchat-server/chatchat/webui_pages/knowledge_base/knowledge_base.py +++ b/libs/chatchat-server/chatchat/webui_pages/knowledge_base/knowledge_base.py @@ -1,33 +1,44 @@ +import os +import time +from typing import Dict, Literal, Tuple + +import pandas as pd import streamlit as st +import streamlit_antd_components as sac +from st_aggrid import AgGrid, JsCode +from st_aggrid.grid_options_builder import GridOptionsBuilder from streamlit_antd_components.utils import ParseItems +from chatchat.configs import ( + CHUNK_SIZE, + DEFAULT_VS_TYPE, + OVERLAP_SIZE, + ZH_TITLE_ENHANCE, + kbs_config, +) +from chatchat.server.knowledge_base.kb_service.base import ( + get_kb_details, + get_kb_file_details, +) +from chatchat.server.knowledge_base.utils import LOADER_DICT, get_file_path +from chatchat.server.utils import get_config_models + # from chatchat.webui_pages.loom_view_client import build_providers_embedding_plugins_name, find_menu_items_by_index, \ # set_llm_select, set_embed_select, 
get_select_embed_endpoint from chatchat.webui_pages.utils import * -from st_aggrid import AgGrid, JsCode -from st_aggrid.grid_options_builder import GridOptionsBuilder -import pandas as pd -from chatchat.server.knowledge_base.utils import get_file_path, LOADER_DICT -from chatchat.server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details -from typing import Literal, Dict, Tuple -from chatchat.configs import (kbs_config, DEFAULT_VS_TYPE, - CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE) -from chatchat.server.utils import get_config_models - -import streamlit_antd_components as sac -import os -import time # SENTENCE_SIZE = 100 -cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""") +cell_renderer = JsCode( + """function(params) {if(params.value==true){return '✓'}else{return '×'}}""" +) def config_aggrid( - df: pd.DataFrame, - columns: Dict[Tuple[str, str], Dict] = {}, - selection_mode: Literal["single", "multiple", "disabled"] = "single", - use_checkbox: bool = False, + df: pd.DataFrame, + columns: Dict[Tuple[str, str], Dict] = {}, + selection_mode: Literal["single", "multiple", "disabled"] = "single", + use_checkbox: bool = False, ) -> GridOptionsBuilder: gb = GridOptionsBuilder.from_dataframe(df) gb.configure_column("No", width=40) @@ -39,9 +50,7 @@ def config_aggrid( pre_selected_rows=st.session_state.get("selected_rows", [0]), ) gb.configure_pagination( - enabled=True, - paginationAutoPageSize=False, - paginationPageSize=10 + enabled=True, paginationAutoPageSize=False, paginationPageSize=10 ) return gb @@ -64,11 +73,15 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): kb_list = {x["kb_name"]: x for x in get_kb_details()} except Exception as e: st.error( - "获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。") + "获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。" + ) st.stop() kb_names = list(kb_list.keys()) - if "selected_kb_name" in st.session_state and st.session_state["selected_kb_name"] in kb_names: + if ( + "selected_kb_name" in st.session_state + and st.session_state["selected_kb_name"] in kb_names + ): selected_kb_index = kb_names.index(st.session_state["selected_kb_name"]) else: selected_kb_index = 0 @@ -86,12 +99,11 @@ def format_selected_kb(kb_name: str) -> str: "请选择或新建知识库:", kb_names + ["新建知识库"], format_func=format_selected_kb, - index=selected_kb_index + index=selected_kb_index, ) if selected_kb == "新建知识库": with st.form("新建知识库"): - kb_name = st.text_input( "新建知识库名称", placeholder="新知识库名称,不支持中文命名", @@ -147,15 +159,23 @@ def format_selected_kb(kb_name: str) -> str: elif selected_kb: kb = selected_kb - st.session_state["selected_kb_info"] = kb_list[kb]['kb_info'] + st.session_state["selected_kb_info"] = kb_list[kb]["kb_info"] # 上传文件 - files = st.file_uploader("上传知识文件:", - [i for ls in LOADER_DICT.values() for i in ls], - accept_multiple_files=True, - ) - kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None, - key=None, - help=None, on_change=None, args=None, kwargs=None) + files = st.file_uploader( + "上传知识文件:", + [i for ls in LOADER_DICT.values() for i in ls], + accept_multiple_files=True, + ) + kb_info = st.text_area( + "请输入知识库介绍:", + value=st.session_state["selected_kb_info"], + max_chars=None, + key=None, + help=None, + on_change=None, + args=None, + kwargs=None, + ) if kb_info != st.session_state["selected_kb_info"]: st.session_state["selected_kb_info"] = kb_info @@ -163,27 +183,31 @@ def format_selected_kb(kb_name: 
str) -> str: # with st.sidebar: with st.expander( - "文件处理配置", - expanded=True, + "文件处理配置", + expanded=True, ): cols = st.columns(3) chunk_size = cols[0].number_input("单段文本最大长度:", 1, 1000, CHUNK_SIZE) - chunk_overlap = cols[1].number_input("相邻文本重合长度:", 0, chunk_size, OVERLAP_SIZE) + chunk_overlap = cols[1].number_input( + "相邻文本重合长度:", 0, chunk_size, OVERLAP_SIZE + ) cols[2].write("") cols[2].write("") zh_title_enhance = cols[2].checkbox("开启中文标题加强", ZH_TITLE_ENHANCE) if st.button( - "添加文件到知识库", - # use_container_width=True, - disabled=len(files) == 0, + "添加文件到知识库", + # use_container_width=True, + disabled=len(files) == 0, ): - ret = api.upload_kb_docs(files, - knowledge_base_name=kb, - override=True, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - zh_title_enhance=zh_title_enhance) + ret = api.upload_kb_docs( + files, + knowledge_base_name=kb, + override=True, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance, + ) if msg := check_success_msg(ret): st.toast(msg, icon="✔") elif msg := check_error_msg(ret): @@ -201,11 +225,23 @@ def format_selected_kb(kb_name: str) -> str: st.write(f"知识库 `{kb}` 中已有文件:") st.info("知识库中包含源文件与向量库,请从下表中选择文件后操作") doc_details.drop(columns=["kb_name"], inplace=True) - doc_details = doc_details[[ - "No", "file_name", "document_loader", "text_splitter", "docs_count", "in_folder", "in_db", - ]] - doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×") - doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×") + doc_details = doc_details[ + [ + "No", + "file_name", + "document_loader", + "text_splitter", + "docs_count", + "in_folder", + "in_db", + ] + ] + doc_details["in_folder"] = ( + doc_details["in_folder"].replace(True, "✓").replace(False, "×") + ) + doc_details["in_db"] = ( + doc_details["in_db"].replace(True, "✓").replace(False, "×") + ) gb = config_aggrid( doc_details, { @@ -232,7 +268,7 @@ def format_selected_kb(kb_name: str) -> str: "#gridToolBar": {"display": "none"}, }, allow_unsafe_jscode=True, - enable_enterprise_modules=False + enable_enterprise_modules=False, ) selected_rows = doc_grid.get("selected_rows") @@ -248,44 +284,49 @@ def format_selected_kb(kb_name: str) -> str: "下载选中文档", fp, file_name=file_name, - use_container_width=True, ) + use_container_width=True, + ) else: cols[0].download_button( "下载选中文档", "", disabled=True, - use_container_width=True, ) + use_container_width=True, + ) st.write() # 将文件分词并加载到向量库中 if cols[1].button( - "重新添加至向量库" if selected_rows and ( - pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库", - disabled=not file_exists(kb, selected_rows)[0], - use_container_width=True, + "重新添加至向量库" + if selected_rows and (pd.DataFrame(selected_rows)["in_db"]).any() + else "添加至向量库", + disabled=not file_exists(kb, selected_rows)[0], + use_container_width=True, ): file_names = [row["file_name"] for row in selected_rows] - api.update_kb_docs(kb, - file_names=file_names, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - zh_title_enhance=zh_title_enhance) + api.update_kb_docs( + kb, + file_names=file_names, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance, + ) st.rerun() # 将文件从向量库中删除,但不删除文件本身。 if cols[2].button( - "从向量库删除", - disabled=not (selected_rows and selected_rows[0]["in_db"]), - use_container_width=True, + "从向量库删除", + disabled=not (selected_rows and selected_rows[0]["in_db"]), + use_container_width=True, ): file_names = [row["file_name"] for row in selected_rows] 
api.delete_kb_docs(kb, file_names=file_names) st.rerun() if cols[3].button( - "从知识库中删除", - type="primary", - use_container_width=True, + "从知识库中删除", + type="primary", + use_container_width=True, ): file_names = [row["file_name"] for row in selected_rows] api.delete_kb_docs(kb, file_names=file_names, delete_content=True) @@ -296,18 +337,20 @@ def format_selected_kb(kb_name: str) -> str: cols = st.columns(3) if cols[0].button( - "依据源文件重建向量库", - help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。", - use_container_width=True, - type="primary", + "依据源文件重建向量库", + help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。", + use_container_width=True, + type="primary", ): with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"): empty = st.empty() empty.progress(0.0, "") - for d in api.recreate_vector_store(kb, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - zh_title_enhance=zh_title_enhance): + for d in api.recreate_vector_store( + kb, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance, + ): if msg := check_error_msg(d): st.toast(msg) else: @@ -315,8 +358,8 @@ def format_selected_kb(kb_name: str) -> str: st.rerun() if cols[2].button( - "删除知识库", - use_container_width=True, + "删除知识库", + use_container_width=True, ): ret = api.delete_knowledge_base(kb) st.toast(ret.get("msg", " ")) @@ -332,47 +375,81 @@ def format_selected_kb(kb_name: str) -> str: df = pd.DataFrame([], columns=["seq", "id", "content", "source"]) if selected_rows: file_name = selected_rows[0]["file_name"] - docs = api.search_kb_docs(knowledge_base_name=selected_kb, file_name=file_name) + docs = api.search_kb_docs( + knowledge_base_name=selected_kb, file_name=file_name + ) data = [ - {"seq": i + 1, "id": x["id"], "page_content": x["page_content"], "source": x["metadata"].get("source"), - "type": x["type"], - "metadata": json.dumps(x["metadata"], ensure_ascii=False), - "to_del": "", - } for i, x in enumerate(docs)] + { + "seq": i + 1, + "id": x["id"], + "page_content": x["page_content"], + "source": x["metadata"].get("source"), + "type": x["type"], + "metadata": json.dumps(x["metadata"], ensure_ascii=False), + "to_del": "", + } + for i, x in enumerate(docs) + ] df = pd.DataFrame(data) gb = GridOptionsBuilder.from_dataframe(df) gb.configure_columns(["id", "source", "type", "metadata"], hide=True) gb.configure_column("seq", "No.", width=50) - gb.configure_column("page_content", "内容", editable=True, autoHeight=True, wrapText=True, flex=1, - cellEditor="agLargeTextCellEditor", cellEditorPopup=True) - gb.configure_column("to_del", "删除", editable=True, width=50, wrapHeaderText=True, - cellEditor="agCheckboxCellEditor", cellRender="agCheckboxCellRenderer") + gb.configure_column( + "page_content", + "内容", + editable=True, + autoHeight=True, + wrapText=True, + flex=1, + cellEditor="agLargeTextCellEditor", + cellEditorPopup=True, + ) + gb.configure_column( + "to_del", + "删除", + editable=True, + width=50, + wrapHeaderText=True, + cellEditor="agCheckboxCellEditor", + cellRender="agCheckboxCellRenderer", + ) # 启用分页 - gb.configure_pagination(enabled=True, paginationAutoPageSize=False, paginationPageSize=10) + gb.configure_pagination( + enabled=True, paginationAutoPageSize=False, paginationPageSize=10 + ) gb.configure_selection() edit_docs = AgGrid(df, gb.build(), fit_columns_on_grid_load=True) if st.button("保存更改"): origin_docs = { - x["id"]: {"page_content": x["page_content"], "type": x["type"], "metadata": x["metadata"]} for x in - docs} + x["id"]: { + "page_content": x["page_content"], + "type": x["type"], + 
"metadata": x["metadata"], + } + for x in docs + } changed_docs = [] for index, row in edit_docs.data.iterrows(): origin_doc = origin_docs[row["id"]] if row["page_content"] != origin_doc["page_content"]: if row["to_del"] not in ["Y", "y", 1]: - changed_docs.append({ - "page_content": row["page_content"], - "type": row["type"], - "metadata": json.loads(row["metadata"]), - }) + changed_docs.append( + { + "page_content": row["page_content"], + "type": row["type"], + "metadata": json.loads(row["metadata"]), + } + ) if changed_docs: - if api.update_kb_docs(knowledge_base_name=selected_kb, - file_names=[file_name], - docs={file_name: changed_docs}): + if api.update_kb_docs( + knowledge_base_name=selected_kb, + file_names=[file_name], + docs={file_name: changed_docs}, + ): st.toast("更新文档成功") else: st.toast("更新文档失败") diff --git a/libs/chatchat-server/chatchat/webui_pages/model_config/__init__.py b/libs/chatchat-server/chatchat/webui_pages/model_config/__init__.py index 3cfc70111..53aeff8e3 100644 --- a/libs/chatchat-server/chatchat/webui_pages/model_config/__init__.py +++ b/libs/chatchat-server/chatchat/webui_pages/model_config/__init__.py @@ -1 +1 @@ -from .model_config import model_config_page \ No newline at end of file +from .model_config import model_config_page diff --git a/libs/chatchat-server/chatchat/webui_pages/model_config/model_config.py b/libs/chatchat-server/chatchat/webui_pages/model_config/model_config.py index 2b0e6a753..833bd1e89 100644 --- a/libs/chatchat-server/chatchat/webui_pages/model_config/model_config.py +++ b/libs/chatchat-server/chatchat/webui_pages/model_config/model_config.py @@ -1,5 +1,7 @@ import streamlit as st + from chatchat.webui_pages.utils import * + def model_config_page(api: ApiRequest): - pass \ No newline at end of file + pass diff --git a/libs/chatchat-server/chatchat/webui_pages/utils.py b/libs/chatchat-server/chatchat/webui_pages/utils.py index 4626d1f67..fea7d39e4 100644 --- a/libs/chatchat-server/chatchat/webui_pages/utils.py +++ b/libs/chatchat-server/chatchat/webui_pages/utils.py @@ -1,31 +1,31 @@ # 该文件封装了对api.py的请求,可以被不同的webui使用 # 通过ApiRequest和AsyncApiRequest支持同步/异步调用 -from typing import * +import base64 +import contextlib +import json +import logging +import os +from io import BytesIO from pathlib import Path +from typing import * + +import httpx + from chatchat.configs import ( + CHUNK_SIZE, DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, + HTTPX_DEFAULT_TIMEOUT, + IMG_DIR, LLM_MODEL_CONFIG, - SCORE_THRESHOLD, - CHUNK_SIZE, OVERLAP_SIZE, - ZH_TITLE_ENHANCE, + SCORE_THRESHOLD, VECTOR_SEARCH_TOP_K, - HTTPX_DEFAULT_TIMEOUT, + ZH_TITLE_ENHANCE, log_verbose, - IMG_DIR ) -import httpx -import contextlib -import json -import os -import base64 -from io import BytesIO -from chatchat.server.utils import set_httpx_config, api_address, get_httpx_client - - -import logging +from chatchat.server.utils import api_address, get_httpx_client, set_httpx_config logger = logging.getLogger() @@ -33,14 +33,14 @@ class ApiRequest: - ''' + """ api.py调用的封装(同步模式),简化api调用方式 - ''' + """ def __init__( - self, - base_url: str = api_address(), - timeout: float = HTTPX_DEFAULT_TIMEOUT, + self, + base_url: str = api_address(), + timeout: float = HTTPX_DEFAULT_TIMEOUT, ): self.base_url = base_url self.timeout = timeout @@ -50,18 +50,18 @@ def __init__( @property def client(self): if self._client is None or self._client.is_closed: - self._client = get_httpx_client(base_url=self.base_url, - use_async=self._use_async, - timeout=self.timeout) + self._client = get_httpx_client( + 
base_url=self.base_url, use_async=self._use_async, timeout=self.timeout + ) return self._client def get( - self, - url: str, - params: Union[Dict, List[Tuple], bytes] = None, - retry: int = 3, - stream: bool = False, - **kwargs: Any, + self, + url: str, + params: Union[Dict, List[Tuple], bytes] = None, + retry: int = 3, + stream: bool = False, + **kwargs: Any, ) -> Union[httpx.Response, Iterator[httpx.Response], None]: while retry > 0: try: @@ -71,66 +71,76 @@ def get( return self.client.get(url, params=params, **kwargs) except Exception as e: msg = f"error when get {url}: {e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", + exc_info=e if log_verbose else None, + ) retry -= 1 def post( - self, - url: str, - data: Dict = None, - json: Dict = None, - retry: int = 3, - stream: bool = False, - **kwargs: Any + self, + url: str, + data: Dict = None, + json: Dict = None, + retry: int = 3, + stream: bool = False, + **kwargs: Any, ) -> Union[httpx.Response, Iterator[httpx.Response], None]: while retry > 0: try: # print(kwargs) if stream: - return self.client.stream("POST", url, data=data, json=json, **kwargs) + return self.client.stream( + "POST", url, data=data, json=json, **kwargs + ) else: return self.client.post(url, data=data, json=json, **kwargs) except Exception as e: msg = f"error when post {url}: {e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", + exc_info=e if log_verbose else None, + ) retry -= 1 def delete( - self, - url: str, - data: Dict = None, - json: Dict = None, - retry: int = 3, - stream: bool = False, - **kwargs: Any + self, + url: str, + data: Dict = None, + json: Dict = None, + retry: int = 3, + stream: bool = False, + **kwargs: Any, ) -> Union[httpx.Response, Iterator[httpx.Response], None]: while retry > 0: try: if stream: - return self.client.stream("DELETE", url, data=data, json=json, **kwargs) + return self.client.stream( + "DELETE", url, data=data, json=json, **kwargs + ) else: return self.client.delete(url, data=data, json=json, **kwargs) except Exception as e: msg = f"error when delete {url}: {e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", + exc_info=e if log_verbose else None, + ) retry -= 1 def _httpx_stream2generator( - self, - response: contextlib._GeneratorContextManager, - as_json: bool = False, + self, + response: contextlib._GeneratorContextManager, + as_json: bool = False, ): - ''' + """ 将httpx.stream返回的GeneratorContextManager转化为普通生成器 - ''' + """ async def ret_async(response, as_json): try: async with response as r: - chunk_cache = '' + chunk_cache = "" async for chunk in r.aiter_text(None): if not chunk: # fastchat api yield empty bytes on start and end continue @@ -143,12 +153,14 @@ async def ret_async(response, as_json): else: data = json.loads(chunk_cache + chunk) - chunk_cache = '' + chunk_cache = "" yield data except Exception as e: msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", + exc_info=e if log_verbose else None, + ) if chunk.startswith("data: "): chunk_cache += chunk[6:-2] @@ -170,14 +182,16 @@ async def ret_async(response, as_json): yield {"code": 500, "msg": msg} except Exception as e: msg = f"API通信遇到错误:{e}" - 
logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", + exc_info=e if log_verbose else None, + ) yield {"code": 500, "msg": msg} def ret_sync(response, as_json): try: with response as r: - chunk_cache = '' + chunk_cache = "" for chunk in r.iter_text(None): if not chunk: # fastchat api yield empty bytes on start and end continue @@ -190,12 +204,14 @@ def ret_sync(response, as_json): else: data = json.loads(chunk_cache + chunk) - chunk_cache = '' + chunk_cache = "" yield data except Exception as e: msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", + exc_info=e if log_verbose else None, + ) if chunk.startswith("data: "): chunk_cache += chunk[6:-2] @@ -217,8 +233,10 @@ def ret_sync(response, as_json): yield {"code": 500, "msg": msg} except Exception as e: msg = f"API通信遇到错误:{e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", + exc_info=e if log_verbose else None, + ) yield {"code": 500, "msg": msg} if self._use_async: @@ -227,16 +245,16 @@ def ret_sync(response, as_json): return ret_sync(response, as_json) def _get_response_value( - self, - response: httpx.Response, - as_json: bool = False, - value_func: Callable = None, + self, + response: httpx.Response, + as_json: bool = False, + value_func: Callable = None, ): - ''' + """ 转换同步或异步请求返回的响应 `as_json`: 返回json `value_func`: 用户可以自定义返回值,该函数接受response或json - ''' + """ def to_json(r): try: @@ -244,12 +262,14 @@ def to_json(r): except Exception as e: msg = "API未能返回正确的JSON。" + str(e) if log_verbose: - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + logger.error( + f"{e.__class__.__name__}: {msg}", + exc_info=e if log_verbose else None, + ) return {"code": 500, "msg": msg, "data": None} if value_func is None: - value_func = (lambda r: r) + value_func = lambda r: r async def ret_async(response): if as_json: @@ -271,10 +291,10 @@ def get_server_configs(self, **kwargs) -> Dict: return self._get_response_value(response, as_json=True) def get_prompt_template( - self, - type: str = "llm_chat", - name: str = "default", - **kwargs, + self, + type: str = "llm_chat", + name: str = "default", + **kwargs, ) -> str: data = { "type": type, @@ -285,20 +305,20 @@ def get_prompt_template( # 对话相关操作 def chat_chat( - self, - query: str, - metadata: dict, - conversation_id: str = None, - history_len: int = -1, - history: List[Dict] = [], - stream: bool = True, - chat_model_config: Dict = None, - tool_config: Dict = None, - **kwargs, + self, + query: str, + metadata: dict, + conversation_id: str = None, + history_len: int = -1, + history: List[Dict] = [], + stream: bool = True, + chat_model_config: Dict = None, + tool_config: Dict = None, + **kwargs, ): - ''' + """ 对应api.py/chat/chat接口 - ''' + """ data = { "query": query, "metadata": metadata, @@ -317,16 +337,16 @@ def chat_chat( return self._httpx_stream2generator(response, as_json=True) def upload_temp_docs( - self, - files: List[Union[str, Path, bytes]], - knowledge_id: str = None, - chunk_size=CHUNK_SIZE, - chunk_overlap=OVERLAP_SIZE, - zh_title_enhance=ZH_TITLE_ENHANCE, + self, + files: List[Union[str, Path, bytes]], + knowledge_id: str = None, + chunk_size=CHUNK_SIZE, + chunk_overlap=OVERLAP_SIZE, + zh_title_enhance=ZH_TITLE_ENHANCE, ): - ''' + """ 
         对应api.py/knowledge_base/upload_tmep_docs接口
-        '''
+        """

         def convert_file(file, filename=None):
             if isinstance(file, bytes):  # raw bytes
@@ -354,21 +374,21 @@ def convert_file(file, filename=None):
         return self._get_response_value(response, as_json=True)

     def file_chat(
-            self,
-            query: str,
-            knowledge_id: str,
-            top_k: int = VECTOR_SEARCH_TOP_K,
-            score_threshold: float = SCORE_THRESHOLD,
-            history: List[Dict] = [],
-            stream: bool = True,
-            model: str = None,
-            temperature: float = 0.9,
-            max_tokens: int = None,
-            prompt_name: str = "default",
+        self,
+        query: str,
+        knowledge_id: str,
+        top_k: int = VECTOR_SEARCH_TOP_K,
+        score_threshold: float = SCORE_THRESHOLD,
+        history: List[Dict] = [],
+        stream: bool = True,
+        model: str = None,
+        temperature: float = 0.9,
+        max_tokens: int = None,
+        prompt_name: str = "default",
     ):
-        '''
+        """
         对应api.py/chat/file_chat接口
-        '''
+        """
         data = {
             "query": query,
             "knowledge_id": knowledge_id,
@@ -392,25 +412,25 @@ def file_chat(
     # 知识库相关操作
     def list_knowledge_bases(
-            self,
+        self,
     ):
-        '''
+        """
         对应api.py/knowledge_base/list_knowledge_bases接口
-        '''
+        """
         response = self.get("/knowledge_base/list_knowledge_bases")
-        return self._get_response_value(response,
-                                        as_json=True,
-                                        value_func=lambda r: r.get("data", []))
+        return self._get_response_value(
+            response, as_json=True, value_func=lambda r: r.get("data", [])
+        )

     def create_knowledge_base(
-            self,
-            knowledge_base_name: str,
-            vector_store_type: str = DEFAULT_VS_TYPE,
-            embed_model: str = DEFAULT_EMBEDDING_MODEL,
+        self,
+        knowledge_base_name: str,
+        vector_store_type: str = DEFAULT_VS_TYPE,
+        embed_model: str = DEFAULT_EMBEDDING_MODEL,
     ):
-        '''
+        """
         对应api.py/knowledge_base/create_knowledge_base接口
-        '''
+        """
         data = {
             "knowledge_base_name": knowledge_base_name,
             "vector_store_type": vector_store_type,
@@ -424,12 +444,12 @@ def create_knowledge_base(
         return self._get_response_value(response, as_json=True)

     def delete_knowledge_base(
-            self,
-            knowledge_base_name: str,
+        self,
+        knowledge_base_name: str,
     ):
-        '''
+        """
         对应api.py/knowledge_base/delete_knowledge_base接口
-        '''
+        """
         response = self.post(
             "/knowledge_base/delete_knowledge_base",
             json=f"{knowledge_base_name}",
@@ -437,32 +457,32 @@ def delete_knowledge_base(
         return self._get_response_value(response, as_json=True)

     def list_kb_docs(
-            self,
-            knowledge_base_name: str,
+        self,
+        knowledge_base_name: str,
     ):
-        '''
+        """
         对应api.py/knowledge_base/list_files接口
-        '''
+        """
         response = self.get(
             "/knowledge_base/list_files",
-            params={"knowledge_base_name": knowledge_base_name}
+            params={"knowledge_base_name": knowledge_base_name},
+        )
+        return self._get_response_value(
+            response, as_json=True, value_func=lambda r: r.get("data", [])
         )
-        return self._get_response_value(response,
-                                        as_json=True,
-                                        value_func=lambda r: r.get("data", []))

     def search_kb_docs(
-            self,
-            knowledge_base_name: str,
-            query: str = "",
-            top_k: int = VECTOR_SEARCH_TOP_K,
-            score_threshold: int = SCORE_THRESHOLD,
-            file_name: str = "",
-            metadata: dict = {},
+        self,
+        knowledge_base_name: str,
+        query: str = "",
+        top_k: int = VECTOR_SEARCH_TOP_K,
+        score_threshold: int = SCORE_THRESHOLD,
+        file_name: str = "",
+        metadata: dict = {},
     ) -> List:
-        '''
+        """
         对应api.py/knowledge_base/search_docs接口
-        '''
+        """
         data = {
             "query": query,
             "knowledge_base_name": knowledge_base_name,
@@ -479,20 +499,20 @@ def search_kb_docs(
         return self._get_response_value(response, as_json=True)

     def upload_kb_docs(
-            self,
-            files: List[Union[str, Path, bytes]],
-            knowledge_base_name: str,
-            override: bool = False,
-            to_vector_store: bool = True,
-            chunk_size=CHUNK_SIZE,
-            chunk_overlap=OVERLAP_SIZE,
-            zh_title_enhance=ZH_TITLE_ENHANCE,
-            docs: Dict = {},
-            not_refresh_vs_cache: bool = False,
+        self,
+        files: List[Union[str, Path, bytes]],
+        knowledge_base_name: str,
+        override: bool = False,
+        to_vector_store: bool = True,
+        chunk_size=CHUNK_SIZE,
+        chunk_overlap=OVERLAP_SIZE,
+        zh_title_enhance=ZH_TITLE_ENHANCE,
+        docs: Dict = {},
+        not_refresh_vs_cache: bool = False,
     ):
-        '''
+        """
         对应api.py/knowledge_base/upload_docs接口
-        '''
+        """

         def convert_file(file, filename=None):
             if isinstance(file, bytes):  # raw bytes
@@ -526,15 +546,15 @@ def convert_file(file, filename=None):
         return self._get_response_value(response, as_json=True)

     def delete_kb_docs(
-            self,
-            knowledge_base_name: str,
-            file_names: List[str],
-            delete_content: bool = False,
-            not_refresh_vs_cache: bool = False,
+        self,
+        knowledge_base_name: str,
+        file_names: List[str],
+        delete_content: bool = False,
+        not_refresh_vs_cache: bool = False,
     ):
-        '''
+        """
         对应api.py/knowledge_base/delete_docs接口
-        '''
+        """
         data = {
             "knowledge_base_name": knowledge_base_name,
             "file_names": file_names,
@@ -549,9 +569,9 @@ def delete_kb_docs(
         return self._get_response_value(response, as_json=True)

     def update_kb_info(self, knowledge_base_name, kb_info):
-        '''
+        """
         对应api.py/knowledge_base/update_info接口
-        '''
+        """
         data = {
             "knowledge_base_name": knowledge_base_name,
             "kb_info": kb_info,
@@ -564,19 +584,19 @@ def update_kb_info(self, knowledge_base_name, kb_info):
         return self._get_response_value(response, as_json=True)

     def update_kb_docs(
-            self,
-            knowledge_base_name: str,
-            file_names: List[str],
-            override_custom_docs: bool = False,
-            chunk_size=CHUNK_SIZE,
-            chunk_overlap=OVERLAP_SIZE,
-            zh_title_enhance=ZH_TITLE_ENHANCE,
-            docs: Dict = {},
-            not_refresh_vs_cache: bool = False,
+        self,
+        knowledge_base_name: str,
+        file_names: List[str],
+        override_custom_docs: bool = False,
+        chunk_size=CHUNK_SIZE,
+        chunk_overlap=OVERLAP_SIZE,
+        zh_title_enhance=ZH_TITLE_ENHANCE,
+        docs: Dict = {},
+        not_refresh_vs_cache: bool = False,
     ):
-        '''
+        """
         对应api.py/knowledge_base/update_docs接口
-        '''
+        """
         data = {
             "knowledge_base_name": knowledge_base_name,
             "file_names": file_names,
@@ -598,18 +618,18 @@ def update_kb_docs(
         return self._get_response_value(response, as_json=True)

     def recreate_vector_store(
-            self,
-            knowledge_base_name: str,
-            allow_empty_kb: bool = True,
-            vs_type: str = DEFAULT_VS_TYPE,
-            embed_model: str = DEFAULT_EMBEDDING_MODEL,
-            chunk_size=CHUNK_SIZE,
-            chunk_overlap=OVERLAP_SIZE,
-            zh_title_enhance=ZH_TITLE_ENHANCE,
+        self,
+        knowledge_base_name: str,
+        allow_empty_kb: bool = True,
+        vs_type: str = DEFAULT_VS_TYPE,
+        embed_model: str = DEFAULT_EMBEDDING_MODEL,
+        chunk_size=CHUNK_SIZE,
+        chunk_overlap=OVERLAP_SIZE,
+        zh_title_enhance=ZH_TITLE_ENHANCE,
     ):
-        '''
+        """
         对应api.py/knowledge_base/recreate_vector_store接口
-        '''
+        """
         data = {
             "knowledge_base_name": knowledge_base_name,
             "allow_empty_kb": allow_empty_kb,
@@ -629,14 +649,14 @@ def recreate_vector_store(
         return self._httpx_stream2generator(response, as_json=True)

     def embed_texts(
-            self,
-            texts: List[str],
-            embed_model: str = DEFAULT_EMBEDDING_MODEL,
-            to_query: bool = False,
+        self,
+        texts: List[str],
+        embed_model: str = DEFAULT_EMBEDDING_MODEL,
+        to_query: bool = False,
     ) -> List[List[float]]:
-        '''
+        """
         对文本进行向量化,可选模型包括本地 embed_models 和支持 embeddings 的在线模型
-        '''
+        """
         data = {
             "texts": texts,
             "embed_model": embed_model,
@@ -646,17 +666,19 @@ def embed_texts(
             "/other/embed_texts",
             json=data,
         )
-        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
+        return self._get_response_value(
+            resp, as_json=True, value_func=lambda r: r.get("data")
+        )

     def chat_feedback(
-            self,
-            message_id: str,
-            score: int,
-            reason: str = "",
+        self,
+        message_id: str,
+        score: int,
+        reason: str = "",
     ) -> int:
-        '''
+        """
         反馈对话评价
-        '''
+        """
         data = {
             "message_id": message_id,
             "score": score,
@@ -666,37 +688,44 @@ def chat_feedback(
         return self._get_response_value(resp)

     def list_tools(self) -> Dict:
-        '''
+        """
         列出所有工具
-        '''
+        """
         resp = self.get("/tools")
-        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data", {}))
+        return self._get_response_value(
+            resp, as_json=True, value_func=lambda r: r.get("data", {})
+        )

     def call_tool(
-            self,
-            name: str,
-            tool_input: Dict = {},
+        self,
+        name: str,
+        tool_input: Dict = {},
     ):
-        '''
+        """
         调用工具
-        '''
+        """
         data = {
             "name": name,
             "tool_input": tool_input,
         }
         resp = self.post("/tools/call", json=data)
-        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
+        return self._get_response_value(
+            resp, as_json=True, value_func=lambda r: r.get("data")
+        )
+

 class AsyncApiRequest(ApiRequest):
-    def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
+    def __init__(
+        self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT
+    ):
         super().__init__(base_url, timeout)
         self._use_async = True


 def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
-    '''
+    """
     return error message if error occured when requests API
-    '''
+    """
     if isinstance(data, dict):
         if key in data:
             return data[key]
@@ -706,22 +735,24 @@ def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:


 def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
-    '''
+    """
     return error message if error occured when requests API
-    '''
-    if (isinstance(data, dict)
-            and key in data
-            and "code" in data
-            and data["code"] == 200):
+    """
+    if (
+        isinstance(data, dict)
+        and key in data
+        and "code" in data
+        and data["code"] == 200
+    ):
         return data[key]
     return ""


 def get_img_base64(file_name: str) -> str:
-    '''
+    """
     get_img_base64 used in streamlit.
     absolute local path not working on windows.
-    '''
+    """
     image = f"{IMG_DIR}/{file_name}"

     # 读取图片
     with open(image, "rb") as f:
diff --git a/libs/chatchat-server/langchain_chatchat/__init__.py b/libs/chatchat-server/langchain_chatchat/__init__.py
index 2ca5ec0c3..e87fe0f65 100644
--- a/libs/chatchat-server/langchain_chatchat/__init__.py
+++ b/libs/chatchat-server/langchain_chatchat/__init__.py
@@ -3,32 +3,30 @@
 import types

 # 动态导入 a_chatchat 模块
-chatchat = importlib.import_module('chatchat')
+chatchat = importlib.import_module("chatchat")

 # 创建新的模块对象
-module = types.ModuleType('langchain_chatchat')
-sys.modules['langchain_chatchat'] = module
+module = types.ModuleType("langchain_chatchat")
+sys.modules["langchain_chatchat"] = module

 # 把 a_chatchat 的所有属性复制到 langchain_chatchat
 for attr in dir(chatchat):
-    if not attr.startswith('_'):
+    if not attr.startswith("_"):
         setattr(module, attr, getattr(chatchat, attr))


 # 动态导入子模块
 def import_submodule(name):
-    full_name = f'chatchat.{name}'
+    full_name = f"chatchat.{name}"
     submodule = importlib.import_module(full_name)
-    sys.modules[f'langchain_chatchat.{name}'] = submodule
+    sys.modules[f"langchain_chatchat.{name}"] = submodule
     for attr in dir(submodule):
-        if not attr.startswith('_'):
+        if not attr.startswith("_"):
             setattr(module, attr, getattr(submodule, attr))


 # 需要的子模块列表,自己添加
-submodules = ['configs', 'server',
-              'startup', 'webui_pages'
-              ]
+submodules = ["configs", "server", "startup", "webui_pages"]

 # 导入所有子模块
 for submodule in submodules:
diff --git a/libs/chatchat-server/tests/api/test_kb_api.py b/libs/chatchat-server/tests/api/test_kb_api.py
index 52f7a2438..466e01485 100644
--- a/libs/chatchat-server/tests/api/test_kb_api.py
+++ b/libs/chatchat-server/tests/api/test_kb_api.py
@@ -1,16 +1,16 @@
-import requests
 import json
 import sys
 from pathlib import Path

+import requests
+
 root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
-from chatchat.server.utils import api_address
-from chatchat.configs import VECTOR_SEARCH_TOP_K
-from chatchat.server.knowledge_base.utils import get_kb_path, get_file_path
-
 from pprint import pprint

+from chatchat.configs import VECTOR_SEARCH_TOP_K
+from chatchat.server.knowledge_base.utils import get_file_path, get_kb_path
+from chatchat.server.utils import api_address

 api_base_url = api_address()

@@ -145,11 +145,14 @@ def test_update_info(api="/knowledge_base/update_info"):
     pprint(data)
     assert data["code"] == 200

+
 def test_update_docs(api="/knowledge_base/update_docs"):
     url = api_base_url + api
     print(f"\n更新知识文件")
-    r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
+    r = requests.post(
+        url, json={"knowledge_base_name": kb, "file_names": list(test_files)}
+    )
     data = r.json()
     pprint(data)
     assert data["code"] == 200
@@ -160,7 +163,9 @@ def test_delete_docs(api="/knowledge_base/delete_docs"):
     url = api_base_url + api
     print(f"\n删除知识文件")
-    r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
+    r = requests.post(
+        url, json={"knowledge_base_name": kb, "file_names": list(test_files)}
+    )
     data = r.json()
     pprint(data)
     assert data["code"] == 200
diff --git a/libs/chatchat-server/tests/api/test_kb_api_request.py b/libs/chatchat-server/tests/api/test_kb_api_request.py
index 4fb01746e..cb84109a4 100644
--- a/libs/chatchat-server/tests/api/test_kb_api_request.py
+++ b/libs/chatchat-server/tests/api/test_kb_api_request.py
@@ -1,17 +1,17 @@
-import requests
 import json
 import sys
 from pathlib import Path

+import requests
+
 root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path)) -from chatchat.server.utils import api_address -from chatchat.configs import VECTOR_SEARCH_TOP_K -from chatchat.server.knowledge_base.utils import get_kb_path, get_file_path -from chatchat.webui_pages.utils import ApiRequest - from pprint import pprint +from chatchat.configs import VECTOR_SEARCH_TOP_K +from chatchat.server.knowledge_base.utils import get_file_path, get_kb_path +from chatchat.server.utils import api_address +from chatchat.webui_pages.utils import ApiRequest api_base_url = api_address() api: ApiRequest = ApiRequest(api_base_url) diff --git a/libs/chatchat-server/tests/api/test_kb_summary_api.py b/libs/chatchat-server/tests/api/test_kb_summary_api.py index 88bb51532..d5daa74fa 100644 --- a/libs/chatchat-server/tests/api/test_kb_summary_api.py +++ b/libs/chatchat-server/tests/api/test_kb_summary_api.py @@ -1,8 +1,9 @@ -import requests import json import sys from pathlib import Path +import requests + root_path = Path(__file__).parent.parent.parent sys.path.append(str(root_path)) from chatchat.server.utils import api_address @@ -14,16 +15,18 @@ doc_ids = [ "357d580f-fdf7-495c-b58b-595a398284e8", "c7338773-2e83-4671-b237-1ad20335b0f0", - "6da613d1-327d-466f-8c1a-b32e6f461f47" + "6da613d1-327d-466f-8c1a-b32e6f461f47", ] -def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summary_file_to_vector_store"): +def test_summary_file_to_vector_store( + api="/knowledge_base/kb_summary_api/summary_file_to_vector_store", +): url = api_base_url + api print("\n文件摘要:") - r = requests.post(url, json={"knowledge_base_name": kb, - "file_name": file_name - }, stream=True) + r = requests.post( + url, json={"knowledge_base_name": kb, "file_name": file_name}, stream=True + ) for chunk in r.iter_content(None): data = json.loads(chunk[6:]) assert isinstance(data, dict) @@ -31,12 +34,14 @@ def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summar print(data["msg"]) -def test_summary_doc_ids_to_vector_store(api="/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store"): +def test_summary_doc_ids_to_vector_store( + api="/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store", +): url = api_base_url + api print("\n文件摘要:") - r = requests.post(url, json={"knowledge_base_name": kb, - "doc_ids": doc_ids - }, stream=True) + r = requests.post( + url, json={"knowledge_base_name": kb, "doc_ids": doc_ids}, stream=True + ) for chunk in r.iter_content(None): data = json.loads(chunk[6:]) assert isinstance(data, dict) diff --git a/libs/chatchat-server/tests/api/test_openai_wrap.py b/libs/chatchat-server/tests/api/test_openai_wrap.py index 96d2bb4f6..ce627c4a0 100644 --- a/libs/chatchat-server/tests/api/test_openai_wrap.py +++ b/libs/chatchat-server/tests/api/test_openai_wrap.py @@ -1,21 +1,21 @@ import sys from pathlib import Path -sys.path.append(str(Path(__file__).parent.parent.parent)) -import requests +sys.path.append(str(Path(__file__).parent.parent.parent)) import openai +import requests -from chatchat.configs import DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL +from chatchat.configs import DEFAULT_EMBEDDING_MODEL, DEFAULT_LLM_MODEL from chatchat.server.utils import api_address - api_base_url = f"{api_address()}/v1" client = openai.Client( api_key="EMPTY", base_url=api_base_url, ) + def test_chat(): resp = client.chat.completions.create( messages=[{"role": "user", "content": "你是谁"}], diff --git a/libs/chatchat-server/tests/api/test_server_state_api.py b/libs/chatchat-server/tests/api/test_server_state_api.py index 
c4492d2a5..026d69819 100644 --- a/libs/chatchat-server/tests/api/test_server_state_api.py +++ b/libs/chatchat-server/tests/api/test_server_state_api.py @@ -1,17 +1,17 @@ import sys from pathlib import Path + root_path = Path(__file__).parent.parent.parent sys.path.append(str(root_path)) -from chatchat.webui_pages.utils import ApiRequest - -import pytest from pprint import pprint from typing import List +import pytest -api = ApiRequest() +from chatchat.webui_pages.utils import ApiRequest +api = ApiRequest() @pytest.mark.parametrize("type", ["llm_chat"]) diff --git a/libs/chatchat-server/tests/api/test_stream_chat_api.py b/libs/chatchat-server/tests/api/test_stream_chat_api.py index 3a5a9ebfe..439a900be 100644 --- a/libs/chatchat-server/tests/api/test_stream_chat_api.py +++ b/libs/chatchat-server/tests/api/test_stream_chat_api.py @@ -1,47 +1,41 @@ -import requests import json import sys from pathlib import Path -sys.path.append(str(Path(__file__).parent.parent.parent)) -from chatchat.configs import BING_SUBSCRIPTION_KEY -from chatchat.server.utils import api_address +import requests +sys.path.append(str(Path(__file__).parent.parent.parent)) from pprint import pprint +from chatchat.configs import BING_SUBSCRIPTION_KEY +from chatchat.server.utils import api_address api_base_url = api_address() def dump_input(d, title): print("\n") - print("=" * 30 + title + " input " + "="*30) + print("=" * 30 + title + " input " + "=" * 30) pprint(d) def dump_output(r, title): print("\n") - print("=" * 30 + title + " output" + "="*30) + print("=" * 30 + title + " output" + "=" * 30) for line in r.iter_content(None, decode_unicode=True): print(line, end="", flush=True) headers = { - 'accept': 'application/json', - 'Content-Type': 'application/json', + "accept": "application/json", + "Content-Type": "application/json", } data = { "query": "请用100字左右的文字介绍自己", "history": [ - { - "role": "user", - "content": "你好" - }, - { - "role": "assistant", - "content": "你好,我是人工智能大模型" - } + {"role": "user", "content": "你好"}, + {"role": "assistant", "content": "你好,我是人工智能大模型"}, ], "stream": True, "temperature": 0.7, @@ -62,21 +56,15 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"): "query": "如何提问以获得高质量答案", "knowledge_base_name": "samples", "history": [ - { - "role": "user", - "content": "你好" - }, - { - "role": "assistant", - "content": "你好,我是 ChatGLM" - } + {"role": "user", "content": "你好"}, + {"role": "assistant", "content": "你好,我是 ChatGLM"}, ], - "stream": True + "stream": True, } dump_input(data, api) response = requests.post(url, headers=headers, json=data, stream=True) print("\n") - print("=" * 30 + api + " output" + "="*30) + print("=" * 30 + api + " output" + "=" * 30) for line in response.iter_content(None, decode_unicode=True): data = json.loads(line[6:]) if "answer" in data: @@ -84,4 +72,3 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"): pprint(data) assert "docs" in data and len(data["docs"]) > 0 assert response.status_code == 200 - diff --git a/libs/chatchat-server/tests/api/test_stream_chat_api_thread.py b/libs/chatchat-server/tests/api/test_stream_chat_api_thread.py index 999a34350..75e11a449 100644 --- a/libs/chatchat-server/tests/api/test_stream_chat_api_thread.py +++ b/libs/chatchat-server/tests/api/test_stream_chat_api_thread.py @@ -1,36 +1,36 @@ -import requests import json import sys from pathlib import Path -sys.path.append(str(Path(__file__).parent.parent.parent)) -from chatchat.configs import BING_SUBSCRIPTION_KEY -from chatchat.server.utils import api_address +import requests -from 
pprint import pprint -from concurrent.futures import ThreadPoolExecutor, as_completed +sys.path.append(str(Path(__file__).parent.parent.parent)) import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from pprint import pprint +from chatchat.configs import BING_SUBSCRIPTION_KEY +from chatchat.server.utils import api_address api_base_url = api_address() def dump_input(d, title): print("\n") - print("=" * 30 + title + " input " + "="*30) + print("=" * 30 + title + " input " + "=" * 30) pprint(d) def dump_output(r, title): print("\n") - print("=" * 30 + title + " output" + "="*30) + print("=" * 30 + title + " output" + "=" * 30) for line in r.iter_content(None, decode_unicode=True): print(line, end="", flush=True) headers = { - 'accept': 'application/json', - 'Content-Type': 'application/json', + "accept": "application/json", + "Content-Type": "application/json", } @@ -40,16 +40,10 @@ def knowledge_chat(api="/chat/knowledge_base_chat"): "query": "如何提问以获得高质量答案", "knowledge_base_name": "samples", "history": [ - { - "role": "user", - "content": "你好" - }, - { - "role": "assistant", - "content": "你好,我是 ChatGLM" - } + {"role": "user", "content": "你好"}, + {"role": "assistant", "content": "你好,我是 ChatGLM"}, ], - "stream": True + "stream": True, } result = [] response = requests.post(url, headers=headers, json=data, stream=True) @@ -57,7 +51,7 @@ def knowledge_chat(api="/chat/knowledge_base_chat"): for line in response.iter_content(None, decode_unicode=True): data = json.loads(line[6:]) result.append(data) - + return result @@ -69,7 +63,7 @@ def test_thread(): for i in range(10): t = pool.submit(knowledge_chat) threads.append(t) - + for r in as_completed(threads): end = time.time() times.append(end - start) diff --git a/libs/chatchat-server/tests/api/test_tools.py b/libs/chatchat-server/tests/api/test_tools.py index 1bb736733..230c2dd74 100644 --- a/libs/chatchat-server/tests/api/test_tools.py +++ b/libs/chatchat-server/tests/api/test_tools.py @@ -1,15 +1,17 @@ import sys from pathlib import Path + sys.path.append(str(Path(__file__).parent.parent.parent)) from pprint import pprint + import requests from chatchat.server.utils import api_address - api_base_url = f"{api_address()}/tools" + def test_tool_list(): resp = requests.get(api_base_url) assert resp.status_code == 200 @@ -21,7 +23,7 @@ def test_tool_list(): def test_tool_call(): data = { "name": "calculate", - "kwargs": {"a":1,"b":2,"operator":"+"}, + "kwargs": {"a": 1, "b": 2, "operator": "+"}, } resp = requests.post(f"{api_base_url}/call", json=data) assert resp.status_code == 200 diff --git a/libs/chatchat-server/tests/conftest.py b/libs/chatchat-server/tests/conftest.py index 92c5e9464..24da0c152 100644 --- a/libs/chatchat-server/tests/conftest.py +++ b/libs/chatchat-server/tests/conftest.py @@ -4,12 +4,12 @@ from typing import Dict, List, Sequence import pytest -from pytest import Config, Function, Parser from model_providers.core.utils.utils import ( get_config_dict, get_log_file, get_timestamp_ms, ) +from pytest import Config, Function, Parser def pytest_addoption(parser: Parser) -> None: @@ -100,7 +100,7 @@ def logging_conf() -> dict: get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"), 1024 * 1024 * 1024 * 3, 1024 * 1024 * 1024 * 3, - ) + ) @pytest.fixture diff --git a/libs/chatchat-server/tests/custom_splitter/test_different_splitter.py b/libs/chatchat-server/tests/custom_splitter/test_different_splitter.py index 8b0df217c..50ed3eec3 100644 --- 
a/libs/chatchat-server/tests/custom_splitter/test_different_splitter.py +++ b/libs/chatchat-server/tests/custom_splitter/test_different_splitter.py @@ -1,16 +1,13 @@ import os +import sys from transformers import AutoTokenizer -import sys sys.path.append("../..") -from chatchat.configs import ( - CHUNK_SIZE, - OVERLAP_SIZE -) - +from chatchat.configs import CHUNK_SIZE, OVERLAP_SIZE from chatchat.server.knowledge_base.utils import make_text_splitter + def text(splitter_name): from langchain import document_loaders @@ -31,23 +28,26 @@ def text(splitter_name): return docs - - import pytest from langchain.docstore.document import Document -@pytest.mark.parametrize("splitter_name", - [ - "ChineseRecursiveTextSplitter", - "SpacyTextSplitter", - "RecursiveCharacterTextSplitter", - "MarkdownHeaderTextSplitter" - ]) + +@pytest.mark.parametrize( + "splitter_name", + [ + "ChineseRecursiveTextSplitter", + "SpacyTextSplitter", + "RecursiveCharacterTextSplitter", + "MarkdownHeaderTextSplitter", + ], +) def test_different_splitter(splitter_name): try: docs = text(splitter_name) assert isinstance(docs, list) - if len(docs)>0: + if len(docs) > 0: assert isinstance(docs[0], Document) except Exception as e: - pytest.fail(f"test_different_splitter failed with {splitter_name}, error: {str(e)}") + pytest.fail( + f"test_different_splitter failed with {splitter_name}, error: {str(e)}" + ) diff --git a/libs/chatchat-server/tests/document_loader/test_imgloader.py b/libs/chatchat-server/tests/document_loader/test_imgloader.py index 92460cb4e..cb85319b1 100644 --- a/libs/chatchat-server/tests/document_loader/test_imgloader.py +++ b/libs/chatchat-server/tests/document_loader/test_imgloader.py @@ -9,6 +9,7 @@ "ocr_test.jpg": str(root_path / "tests" / "samples" / "ocr_test.jpg"), } + def test_rapidocrloader(): img_path = test_files["ocr_test.jpg"] from document_loaders import RapidOCRLoader @@ -16,6 +17,8 @@ def test_rapidocrloader(): loader = RapidOCRLoader(img_path) docs = loader.load() pprint(docs) - assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str) - - + assert ( + isinstance(docs, list) + and len(docs) > 0 + and isinstance(docs[0].page_content, str) + ) diff --git a/libs/chatchat-server/tests/document_loader/test_pdfloader.py b/libs/chatchat-server/tests/document_loader/test_pdfloader.py index 8bba7da99..6a330e522 100644 --- a/libs/chatchat-server/tests/document_loader/test_pdfloader.py +++ b/libs/chatchat-server/tests/document_loader/test_pdfloader.py @@ -9,6 +9,7 @@ "ocr_test.pdf": str(root_path / "tests" / "samples" / "ocr_test.pdf"), } + def test_rapidocrpdfloader(): pdf_path = test_files["ocr_test.pdf"] from document_loaders import RapidOCRPDFLoader @@ -16,6 +17,8 @@ def test_rapidocrpdfloader(): loader = RapidOCRPDFLoader(pdf_path) docs = loader.load() pprint(docs) - assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str) - - + assert ( + isinstance(docs, list) + and len(docs) > 0 + and isinstance(docs[0].page_content, str) + ) diff --git a/libs/chatchat-server/tests/integration_tests/unit_server/test_init_server.py b/libs/chatchat-server/tests/integration_tests/unit_server/test_init_server.py index 1bc2c1ead..3bfac6064 100644 --- a/libs/chatchat-server/tests/integration_tests/unit_server/test_init_server.py +++ b/libs/chatchat-server/tests/integration_tests/unit_server/test_init_server.py @@ -1,10 +1,10 @@ -from chatchat.init_server import init_server +import logging +import logging.config import multiprocessing as mp +from 
chatchat.init_server import init_server from chatchat.server.utils import is_port_in_use from chatchat.startup import run_init_server -import logging -import logging.config logger = logging.getLogger(__name__) @@ -19,7 +19,10 @@ def test_init_server(logging_conf, providers_file): process = mp.Process( target=run_init_server, name=f"Model providers Server", - kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=model_providers_started), + kwargs=dict( + model_platforms_shard=model_platforms_shard, + started_event=model_providers_started, + ), daemon=True, ) diff --git a/libs/chatchat-server/tests/kb_vector_db/test_faiss_kb.py b/libs/chatchat-server/tests/kb_vector_db/test_faiss_kb.py index fae178bb7..29e7ad2ab 100644 --- a/libs/chatchat-server/tests/kb_vector_db/test_faiss_kb.py +++ b/libs/chatchat-server/tests/kb_vector_db/test_faiss_kb.py @@ -2,7 +2,6 @@ from chatchat.server.knowledge_base.migrate import create_tables from chatchat.server.knowledge_base.utils import KnowledgeFile - kbService = FaissKBService("test") test_kb_name = "test" test_file_name = "README.md" diff --git a/libs/chatchat-server/tests/kb_vector_db/test_milvus_db.py b/libs/chatchat-server/tests/kb_vector_db/test_milvus_db.py index 22941d1e3..230007475 100644 --- a/libs/chatchat-server/tests/kb_vector_db/test_milvus_db.py +++ b/libs/chatchat-server/tests/kb_vector_db/test_milvus_db.py @@ -11,6 +11,7 @@ testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name) search_content = "如何启动api服务" + def test_init(): create_tables() @@ -26,6 +27,7 @@ def test_add_doc(): def test_search_db(): result = kbService.search_docs(search_content) assert len(result) > 0 + + def test_delete_doc(): assert kbService.delete_doc(testKnowledgeFile) - diff --git a/libs/chatchat-server/tests/kb_vector_db/test_pg_db.py b/libs/chatchat-server/tests/kb_vector_db/test_pg_db.py index abafac893..2a383df80 100644 --- a/libs/chatchat-server/tests/kb_vector_db/test_pg_db.py +++ b/libs/chatchat-server/tests/kb_vector_db/test_pg_db.py @@ -26,6 +26,7 @@ def test_add_doc(): def test_search_db(): result = kbService.search_docs(search_content) assert len(result) > 0 + + def test_delete_doc(): assert kbService.delete_doc(testKnowledgeFile) - diff --git a/libs/chatchat-server/tests/test_migrate.py b/libs/chatchat-server/tests/test_migrate.py index e020ef6c3..1b5e2ecfa 100644 --- a/libs/chatchat-server/tests/test_migrate.py +++ b/libs/chatchat-server/tests/test_migrate.py @@ -1,15 +1,23 @@ -from pathlib import Path -from pprint import pprint import os import shutil import sys +from pathlib import Path +from pprint import pprint + root_path = Path(__file__).parent.parent sys.path.append(str(root_path)) from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory -from chatchat.server.knowledge_base.utils import get_kb_path, get_doc_path, KnowledgeFile -from chatchat.server.knowledge_base.migrate import folder2db, prune_db_docs, prune_folder_files - +from chatchat.server.knowledge_base.migrate import ( + folder2db, + prune_db_docs, + prune_folder_files, +) +from chatchat.server.knowledge_base.utils import ( + KnowledgeFile, + get_doc_path, + get_kb_path, +) # setup test knowledge base kb_name = "test_kb_for_migrate" diff --git a/libs/chatchat-server/tests/test_qwen_agent.py b/libs/chatchat-server/tests/test_qwen_agent.py index 75190d91c..d72abf15e 100644 --- a/libs/chatchat-server/tests/test_qwen_agent.py +++ b/libs/chatchat-server/tests/test_qwen_agent.py @@ -1,16 +1,23 @@ import sys from pathlib import Path + 
sys.path.append(str(Path(__file__).parent.parent)) import asyncio import json from pprint import pprint + +from langchain import globals from langchain.agents import AgentExecutor + +from chatchat.server.agent.agent_factory.qwen_agent import ( + create_structured_qwen_chat_agent, +) from chatchat.server.agent.tools_factory.tools_registry import all_tools -from chatchat.server.agent.agent_factory.qwen_agent import create_structured_qwen_chat_agent -from chatchat.server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler +from chatchat.server.callback_handler.agent_callback_handler import ( + AgentExecutorAsyncIteratorCallbackHandler, +) from chatchat.server.utils import get_ChatOpenAI -from langchain import globals # globals.set_debug(True) # globals.set_verbose(True) @@ -19,9 +26,9 @@ async def test1(): callback = AgentExecutorAsyncIteratorCallbackHandler() qwen_model = get_ChatOpenAI("qwen", 0.01, streaming=False, callbacks=[callback]) - executor = create_structured_qwen_chat_agent(llm=qwen_model, - tools=all_tools, - callbacks=[callback]) + executor = create_structured_qwen_chat_agent( + llm=qwen_model, tools=all_tools, callbacks=[callback] + ) # ret = executor.invoke({"input": "苏州今天冷吗"}) ret = asyncio.create_task(executor.ainvoke({"input": "苏州今天冷吗"})) async for chunk in callback.aiter(): @@ -34,98 +41,117 @@ async def test1(): async def test_server_chat(): from chatchat.server.chat.chat import chat - mc={'preprocess_model': { - 'qwen': { - 'temperature': 0.4, - 'max_tokens': 2048, - 'history_len': 100, - 'prompt_name': 'default', - 'callbacks': False} + mc = { + "preprocess_model": { + "qwen": { + "temperature": 0.4, + "max_tokens": 2048, + "history_len": 100, + "prompt_name": "default", + "callbacks": False, + } + }, + "llm_model": { + "qwen": { + "temperature": 0.9, + "max_tokens": 4096, + "history_len": 3, + "prompt_name": "default", + "callbacks": True, + } + }, + "action_model": { + "qwen": { + "temperature": 0.01, + "max_tokens": 4096, + "prompt_name": "qwen", + "callbacks": True, + } }, - 'llm_model': { - 'qwen': { - 'temperature': 0.9, - 'max_tokens': 4096, - 'history_len': 3, - 'prompt_name': 'default', - 'callbacks': True} - }, - 'action_model': { - 'qwen': { - 'temperature': 0.01, - 'max_tokens': 4096, - 'prompt_name': 'qwen', - 'callbacks': True} - }, - 'postprocess_model': { - 'qwen': { - 'temperature': 0.01, - 'max_tokens': 4096, - 'prompt_name': 'default', - 'callbacks': True} + "postprocess_model": { + "qwen": { + "temperature": 0.01, + "max_tokens": 4096, + "prompt_name": "default", + "callbacks": True, } - } + }, + } - tc={'weather_check': {'use': False, 'api-key': 'your key'}} + tc = {"weather_check": {"use": False, "api-key": "your key"}} - async for x in (await chat("苏州天气如何",{}, - model_config=mc, - tool_config=tc, - conversation_id=None, - history_len=-1, - history=[], - stream=True)).body_iterator: + async for x in ( + await chat( + "苏州天气如何", + {}, + model_config=mc, + tool_config=tc, + conversation_id=None, + history_len=-1, + history=[], + stream=True, + ) + ).body_iterator: pprint(x) async def test_text2image(): from chatchat.server.chat.chat import chat - mc={'preprocess_model': { - 'qwen-api': { - 'temperature': 0.4, - 'max_tokens': 2048, - 'history_len': 100, - 'prompt_name': 'default', - 'callbacks': False} + mc = { + "preprocess_model": { + "qwen-api": { + "temperature": 0.4, + "max_tokens": 2048, + "history_len": 100, + "prompt_name": "default", + "callbacks": False, + } + }, + "llm_model": { + "qwen-api": { + 
"temperature": 0.9, + "max_tokens": 4096, + "history_len": 3, + "prompt_name": "default", + "callbacks": True, + } }, - 'llm_model': { - 'qwen-api': { - 'temperature': 0.9, - 'max_tokens': 4096, - 'history_len': 3, - 'prompt_name': 'default', - 'callbacks': True} - }, - 'action_model': { - 'qwen-api': { - 'temperature': 0.01, - 'max_tokens': 4096, - 'prompt_name': 'qwen', - 'callbacks': True} - }, - 'postprocess_model': { - 'qwen-api': { - 'temperature': 0.01, - 'max_tokens': 4096, - 'prompt_name': 'default', - 'callbacks': True} - }, - 'image_model': { - 'sd-turbo': {} - } + "action_model": { + "qwen-api": { + "temperature": 0.01, + "max_tokens": 4096, + "prompt_name": "qwen", + "callbacks": True, + } + }, + "postprocess_model": { + "qwen-api": { + "temperature": 0.01, + "max_tokens": 4096, + "prompt_name": "default", + "callbacks": True, + } + }, + "image_model": {"sd-turbo": {}}, } - tc={'text2images': {'use': True}} + tc = {"text2images": {"use": True}} - async for x in (await chat("draw a house",{}, - model_config=mc, - tool_config=tc, - conversation_id=None, - history_len=-1, - history=[], - stream=False)).body_iterator: + async for x in ( + await chat( + "draw a house", + {}, + model_config=mc, + tool_config=tc, + conversation_id=None, + history_len=-1, + history=[], + stream=False, + ) + ).body_iterator: x = json.loads(x) pprint(x) + asyncio.run(test1()) diff --git a/libs/chatchat-server/tests/unit_tests/config/test_config.py b/libs/chatchat-server/tests/unit_tests/config/test_config.py index 106575e47..7a973b61c 100644 --- a/libs/chatchat-server/tests/unit_tests/config/test_config.py +++ b/libs/chatchat-server/tests/unit_tests/config/test_config.py @@ -1,21 +1,20 @@ +import os from pathlib import Path from chatchat.configs import ( - ConfigBasicFactory, ConfigBasic, + ConfigBasicFactory, ConfigBasicWorkSpace, - ConfigModelWorkSpace, + ConfigKb, + ConfigKbWorkSpace, ConfigModel, - ConfigServerWorkSpace, + ConfigModelWorkSpace, ConfigServer, - ConfigKbWorkSpace, - ConfigKb, + ConfigServerWorkSpace, ) -import os class TestWorkSpace: - def test_config_basic_workspace_clear(self): config_basic_workspace: ConfigBasicWorkSpace = ConfigBasicWorkSpace() @@ -44,7 +43,16 @@ def test_config_basic_workspace(self): config_basic_workspace.clear() def test_workspace_default(self): - from chatchat.configs import (log_verbose, DATA_PATH, IMG_DIR, NLTK_DATA_PATH, LOG_FORMAT, LOG_PATH, MEDIA_PATH) + from chatchat.configs import ( + DATA_PATH, + IMG_DIR, + LOG_FORMAT, + LOG_PATH, + MEDIA_PATH, + NLTK_DATA_PATH, + log_verbose, + ) + assert log_verbose is False assert DATA_PATH is not None assert IMG_DIR is not None @@ -64,11 +72,18 @@ def test_config_model_workspace(self): config_model_workspace.set_history_len(history_len=1) config_model_workspace.set_max_tokens(max_tokens=1000) config_model_workspace.set_temperature(temperature=0.1) - config_model_workspace.set_support_agent_models(support_agent_models=["glm4-chat"]) + config_model_workspace.set_support_agent_models( + support_agent_models=["glm4-chat"] + ) config_model_workspace.set_model_providers_cfg_path_config( - model_providers_cfg_path_config="model_providers.yaml") - config_model_workspace.set_model_providers_cfg_host(model_providers_cfg_host="127.0.0.1") - config_model_workspace.set_model_providers_cfg_port(model_providers_cfg_port=8000) + model_providers_cfg_path_config="model_providers.yaml" + ) + config_model_workspace.set_model_providers_cfg_host( + model_providers_cfg_host="127.0.0.1" + ) + 
config_model_workspace.set_model_providers_cfg_port( + model_providers_cfg_port=8000 + ) config: ConfigModel = config_model_workspace.get_config() @@ -86,10 +101,21 @@ def test_config_model_workspace(self): def test_model_config(self): from chatchat.configs import ( - DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, Agent_MODEL, HISTORY_LEN, MAX_TOKENS, TEMPERATURE, - SUPPORT_AGENT_MODELS, MODEL_PROVIDERS_CFG_PATH_CONFIG, MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT, - TOOL_CONFIG, MODEL_PLATFORMS, LLM_MODEL_CONFIG + DEFAULT_EMBEDDING_MODEL, + DEFAULT_LLM_MODEL, + HISTORY_LEN, + LLM_MODEL_CONFIG, + MAX_TOKENS, + MODEL_PLATFORMS, + MODEL_PROVIDERS_CFG_HOST, + MODEL_PROVIDERS_CFG_PATH_CONFIG, + MODEL_PROVIDERS_CFG_PORT, + SUPPORT_AGENT_MODELS, + TEMPERATURE, + TOOL_CONFIG, + Agent_MODEL, ) + assert DEFAULT_LLM_MODEL is not None assert DEFAULT_EMBEDDING_MODEL is not None assert Agent_MODEL is None @@ -126,9 +152,13 @@ def test_config_server_workspace(self): def test_server_config(self): from chatchat.configs import ( - HTTPX_DEFAULT_TIMEOUT, OPEN_CROSS_DOMAIN, DEFAULT_BIND_HOST, - WEBUI_SERVER, API_SERVER + API_SERVER, + DEFAULT_BIND_HOST, + HTTPX_DEFAULT_TIMEOUT, + OPEN_CROSS_DOMAIN, + WEBUI_SERVER, ) + assert HTTPX_DEFAULT_TIMEOUT is not None assert OPEN_CROSS_DOMAIN is not None assert DEFAULT_BIND_HOST is not None @@ -152,9 +182,11 @@ def test_config_kb_workspace(self): config_kb_workspace.set_search_engine_top_k(search_engine_top_k=10) config_kb_workspace.set_zh_title_enhance(zh_title_enhance=True) config_kb_workspace.set_pdf_ocr_threshold(pdf_ocr_threshold=(0.1, 0.2)) - config_kb_workspace.set_kb_info(kb_info={ - "samples": "关于本项目issue的解答", - }) + config_kb_workspace.set_kb_info( + kb_info={ + "samples": "关于本项目issue的解答", + } + ) config_kb_workspace.set_kb_root_path(kb_root_path="test") config_kb_workspace.set_db_root_path(db_root_path="test") config_kb_workspace.set_sqlalchemy_database_uri(sqlalchemy_database_uri="test") diff --git a/libs/chatchat-server/tests/unit_tests/test_sdk_import.py b/libs/chatchat-server/tests/unit_tests/test_sdk_import.py index dac6c28fa..4922bebd1 100644 --- a/libs/chatchat-server/tests/unit_tests/test_sdk_import.py +++ b/libs/chatchat-server/tests/unit_tests/test_sdk_import.py @@ -1,10 +1,4 @@ - - def test_sdk_import_unit(): from langchain_chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS - from langchain_chatchat.server.utils import is_port_in_use from langchain_chatchat.startup import run_init_server - - -
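
Note for readers skimming this formatting-only diff (quote style, import order, line wrapping): none of the changes above alter the call surface of the ApiRequest client that the tests exercise. The following is a minimal usage sketch, illustrative only and not part of the diff; it assumes a running chatchat API server reachable via api_address(), an existing "samples" knowledge base, and a local README.md to upload.

    from pprint import pprint

    from chatchat.server.utils import api_address
    from chatchat.webui_pages.utils import ApiRequest

    # same construction as in tests/api/test_kb_api_request.py
    api = ApiRequest(base_url=api_address())

    # knowledge-base management via the methods reformatted above
    pprint(api.list_knowledge_bases())
    api.upload_kb_docs(files=["README.md"], knowledge_base_name="samples", override=True)
    pprint(api.search_kb_docs(knowledge_base_name="samples", query="如何启动api服务"))

    # server-side tool registry exposed under /tools; tool name and arguments
    # are taken from tests/api/test_tools.py and may differ in your deployment
    pprint(api.list_tools())
    pprint(api.call_tool(name="calculate", tool_input={"a": 1, "b": 2, "operator": "+"}))

This mirrors the flow exercised by tests/api/test_kb_api_request.py and tests/api/test_tools.py, just routed through the client wrapper instead of raw requests calls.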