diff --git a/.github/actions/custom-build-and-push/action.yml b/.github/actions/custom-build-and-push/action.yml new file mode 100644 index 00000000000..48344237059 --- /dev/null +++ b/.github/actions/custom-build-and-push/action.yml @@ -0,0 +1,76 @@ +name: 'Build and Push Docker Image with Retry' +description: 'Attempts to build and push a Docker image, with a retry on failure' +inputs: + context: + description: 'Build context' + required: true + file: + description: 'Dockerfile location' + required: true + platforms: + description: 'Target platforms' + required: true + pull: + description: 'Always attempt to pull a newer version of the image' + required: false + default: 'true' + push: + description: 'Push the image to registry' + required: false + default: 'true' + load: + description: 'Load the image into Docker daemon' + required: false + default: 'true' + tags: + description: 'Image tags' + required: true + cache-from: + description: 'Cache sources' + required: false + cache-to: + description: 'Cache destinations' + required: false + retry-wait-time: + description: 'Time to wait before retry in seconds' + required: false + default: '5' + +runs: + using: "composite" + steps: + - name: Build and push Docker image (First Attempt) + id: buildx1 + uses: docker/build-push-action@v5 + continue-on-error: true + with: + context: ${{ inputs.context }} + file: ${{ inputs.file }} + platforms: ${{ inputs.platforms }} + pull: ${{ inputs.pull }} + push: ${{ inputs.push }} + load: ${{ inputs.load }} + tags: ${{ inputs.tags }} + cache-from: ${{ inputs.cache-from }} + cache-to: ${{ inputs.cache-to }} + + - name: Wait to retry + if: steps.buildx1.outcome != 'success' + run: | + echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..." + sleep ${{ inputs.retry-wait-time }} + shell: bash + + - name: Build and push Docker image (Retry Attempt) + if: steps.buildx1.outcome != 'success' + uses: docker/build-push-action@v5 + with: + context: ${{ inputs.context }} + file: ${{ inputs.file }} + platforms: ${{ inputs.platforms }} + pull: ${{ inputs.pull }} + push: ${{ inputs.push }} + load: ${{ inputs.load }} + tags: ${{ inputs.tags }} + cache-from: ${{ inputs.cache-from }} + cache-to: ${{ inputs.cache-to }} diff --git a/.github/workflows/helm-build-push.yml b/.github/workflows/helm-build-push.yml index 8f5436c7ae6..149ef1c2d32 100644 --- a/.github/workflows/helm-build-push.yml +++ b/.github/workflows/helm-build-push.yml @@ -4,6 +4,10 @@ on: push: workflow_dispatch: +env: + CHART_YAML_PATH: deployment/helm/charts/danswer/Chart.yaml + VALUES_YAML_PATH: deployment/helm/charts/danswer/values.yaml + jobs: helm_chart_version_check: runs-on: ubuntu-latest @@ -20,11 +24,11 @@ jobs: # on main or a stable tag on a dev branch. - name: Fail on semver pre-release chart version - run: yq .version deployment/helm/Chart.yaml | grep -v '[a-zA-Z-]' + run: yq .version ${{ env.CHART_YAML_PATH }} | grep -v '[a-zA-Z-]' if: ${{ github.ref_name == 'main' }} - name: Fail on stable semver chart version - run: yq .version deployment/helm/Chart.yaml | grep '[a-zA-Z-]' + run: yq .version ${{ env.CHART_YAML_PATH }} | grep '[a-zA-Z-]' if: ${{ github.ref_name != 'main' }} # To reduce resource usage images are built only on tag. 
@@ -37,19 +41,19 @@ jobs: curl -H "Authorization: Bearer $(echo ${{ secrets.GITHUB_TOKEN }} | base64)" https://ghcr.io/v2/stackhpc/danswer/danswer-backend/tags/list | jq .tags - | grep $( yq .appVersion deployment/helm/Chart.yaml )-$( yq .tagSuffix deployment/helm/values.yaml ) + | grep $( yq .appVersion ${{ env.CHART_YAML_PATH }} )-$( yq .tagSuffix ${{ env.VALUES_YAML_PATH }} ) && curl -H "Authorization: Bearer $(echo ${{ secrets.GITHUB_TOKEN }} | base64)" https://ghcr.io/v2/stackhpc/danswer/danswer-web-server/tags/list | jq .tags - | grep $( yq .appVersion deployment/helm/Chart.yaml )-$( yq .tagSuffix deployment/helm/values.yaml ) + | grep $( yq .appVersion ${{ env.CHART_YAML_PATH }} )-$( yq .tagSuffix ${{ env.VALUES_YAML_PATH }} ) # Check if current chart version exists in releases already - name: Check for Helm chart version bump id: version_check run: | set -xe - chart_version=$(yq .version deployment/helm/Chart.yaml) + chart_version=$(yq .version ${{ env.CHART_YAML_PATH }}) if [[ $(curl https://api.github.com/repos/stackhpc/danswer/releases | jq '.[].tag_name' | grep danswer-helm-$chart_version) ]]; then echo chart_version_changed=false >> $GITHUB_OUTPUT else @@ -84,12 +88,12 @@ jobs: run: | helm repo add bitnami https://charts.bitnami.com/bitnami helm repo add vespa https://unoplat.github.io/vespa-helm-charts - helm dependency build deployment/helm + helm dependency build deployment/helm/charts/danswer - name: Run chart-releaser uses: helm/chart-releaser-action@v1.6.0 with: - charts_dir: deployment + charts_dir: deployment/helm/charts pages_branch: helm-publish mark_as_latest: ${{ github.ref_name == 'main' }} env: diff --git a/.github/workflows/pr-helm-chart-testing.yml.disabled.txt b/.github/workflows/pr-helm-chart-testing.yml.disabled.txt new file mode 100644 index 00000000000..7c4903a07f7 --- /dev/null +++ b/.github/workflows/pr-helm-chart-testing.yml.disabled.txt @@ -0,0 +1,67 @@ +# This workflow is intentionally disabled while we're still working on it +# It's close to ready, but a race condition needs to be fixed with +# API server and Vespa startup, and it needs to have a way to build/test against +# local containers + +name: Helm - Lint and Test Charts + +on: + merge_group: + pull_request: + branches: [ main ] + +jobs: + lint-test: + runs-on: Amd64 + + # fetch-depth 0 is required for helm/chart-testing-action + steps: + - name: Checkout code + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Helm + uses: azure/setup-helm@v4.2.0 + with: + version: v3.14.4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + cache: 'pip' + cache-dependency-path: | + backend/requirements/default.txt + backend/requirements/dev.txt + backend/requirements/model_server.txt + - run: | + python -m pip install --upgrade pip + pip install -r backend/requirements/default.txt + pip install -r backend/requirements/dev.txt + pip install -r backend/requirements/model_server.txt + + - name: Set up chart-testing + uses: helm/chart-testing-action@v2.6.1 + + - name: Run chart-testing (list-changed) + id: list-changed + run: | + changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }}) + if [[ -n "$changed" ]]; then + echo "changed=true" >> "$GITHUB_OUTPUT" + fi + + - name: Run chart-testing (lint) +# if: steps.list-changed.outputs.changed == 'true' + run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }} + + - name: Create kind cluster +# if: steps.list-changed.outputs.changed 
== 'true' + uses: helm/kind-action@v1.10.0 + + - name: Run chart-testing (install) +# if: steps.list-changed.outputs.changed == 'true' + run: ct install --all --config ct.yaml +# run: ct install --target-branch ${{ github.event.repository.default_branch }} + \ No newline at end of file diff --git a/.github/workflows/pr-python-connector-tests.yml b/.github/workflows/pr-python-connector-tests.yml new file mode 100644 index 00000000000..00b92c9b003 --- /dev/null +++ b/.github/workflows/pr-python-connector-tests.yml @@ -0,0 +1,57 @@ +name: Connector Tests + +on: + pull_request: + branches: [main] + schedule: + # This cron expression runs the job daily at 16:00 UTC (9am PT) + - cron: "0 16 * * *" + +env: + # Confluence + CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }} + CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }} + CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }} + CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }} + CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }} + CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }} + +jobs: + connectors-check: + runs-on: ubuntu-latest + + env: + PYTHONPATH: ./backend + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + cache: "pip" + cache-dependency-path: | + backend/requirements/default.txt + backend/requirements/dev.txt + + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + pip install -r backend/requirements/default.txt + pip install -r backend/requirements/dev.txt + + - name: Run Tests + shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}" + run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors + + - name: Alert on Failure + if: failure() && github.event_name == 'schedule' + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} + run: | + curl -X POST \ + -H 'Content-type: application/json' \ + --data '{"text":"Scheduled Connector Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \ + $SLACK_WEBHOOK diff --git a/.github/workflows/run-it.yml b/.github/workflows/run-it.yml index 7c0c1814c3b..0ca0031c64c 100644 --- a/.github/workflows/run-it.yml +++ b/.github/workflows/run-it.yml @@ -28,30 +28,20 @@ jobs: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - - name: Build Web Docker image - uses: docker/build-push-action@v5 - with: - context: ./web - file: ./web/Dockerfile - platforms: linux/arm64 - pull: true - push: true - load: true - tags: danswer/danswer-web-server:it - cache-from: type=registry,ref=danswer/danswer-web-server:it - cache-to: | - type=registry,ref=danswer/danswer-web-server:it,mode=max - type=inline + # NOTE: we don't need to build the Web Docker image since it's not used + # during the IT for now. 
We have a separate action to verify it builds + successfully + - name: Pull Web Docker image + run: | + docker pull danswer/danswer-web-server:latest + docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:it - name: Build Backend Docker image - uses: docker/build-push-action@v5 + uses: ./.github/actions/custom-build-and-push with: context: ./backend file: ./backend/Dockerfile platforms: linux/arm64 - pull: true - push: true - load: true tags: danswer/danswer-backend:it cache-from: type=registry,ref=danswer/danswer-backend:it cache-to: | @@ -59,14 +49,11 @@ jobs: type=inline - name: Build Model Server Docker image - uses: docker/build-push-action@v5 + uses: ./.github/actions/custom-build-and-push with: context: ./backend file: ./backend/Dockerfile.model_server platforms: linux/arm64 - pull: true - push: true - load: true tags: danswer/danswer-model-server:it cache-from: type=registry,ref=danswer/danswer-model-server:it cache-to: | @@ -74,14 +61,11 @@ jobs: type=inline - name: Build integration test Docker image - uses: docker/build-push-action@v5 + uses: ./.github/actions/custom-build-and-push with: context: ./backend file: ./backend/tests/integration/Dockerfile platforms: linux/arm64 - pull: true - push: true - load: true tags: danswer/integration-test-runner:it cache-from: type=registry,ref=danswer/integration-test-runner:it cache-to: | @@ -92,8 +76,11 @@ jobs: run: | cd deployment/docker_compose ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \ + AUTH_TYPE=basic \ + REQUIRE_EMAIL_VERIFICATION=false \ + DISABLE_TELEMETRY=true \ IMAGE_TAG=it \ - docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build + docker compose -f docker-compose.dev.yml -p danswer-stack up -d id: start_docker - name: Wait for service to be ready @@ -137,6 +124,7 @@ jobs: -e POSTGRES_PASSWORD=password \ -e POSTGRES_DB=postgres \ -e VESPA_HOST=index \ + -e REDIS_HOST=cache \ -e API_SERVER_HOST=api_server \ -e OPENAI_API_KEY=${OPENAI_API_KEY} \ danswer/integration-test-runner:it diff --git a/.gitignore b/.gitignore index d9d7727b2f0..aedf5ed007b 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ .mypy_cache .idea /deployment/data/nginx/app.conf -.vscode/launch.json +.vscode/ *.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml +**.tgz diff --git a/.vscode/env_template.txt b/.vscode/env_template.txt index b3fae8cee73..89faca0abf0 100644 --- a/.vscode/env_template.txt +++ b/.vscode/env_template.txt @@ -1,5 +1,5 @@ -# Copy this file to .env at the base of the repo and fill in the values -# This will help with development iteration speed and reduce repeat tasks for dev +# Copy this file to .env in the .vscode folder +# Fill in the values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI # Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes # For local dev, often user Authentication is not needed @@ -15,7 +15,7 @@ LOG_LEVEL=debug # This passes top N results to LLM an additional time for reranking prior to answer generation # This step is quite heavy on token usage so we disable it for dev generally -DISABLE_LLM_DOC_RELEVANCE=True +DISABLE_LLM_DOC_RELEVANCE=False # Useful if you want to toggle auth on/off (google_oauth/OIDC specifically) @@ -27,9 +27,9 @@ REQUIRE_EMAIL_VERIFICATION=False # Set these so if you wipe the DB, you don't end up having to go through the UI every time GEN_AI_API_KEY= -# If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper -GEN_AI_MODEL_VERSION=gpt-3.5-turbo -FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo +# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper +GEN_AI_MODEL_VERSION=gpt-4o +FAST_GEN_AI_MODEL_VERSION=gpt-4o # For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time # Only needed if using DanswerBot @@ -38,7 +38,7 @@ FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo # Python stuff -PYTHONPATH=./backend +PYTHONPATH=../backend PYTHONUNBUFFERED=1 @@ -49,4 +49,3 @@ BING_API_KEY= # Enable the full set of Danswer Enterprise Edition features # NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development) ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False - diff --git a/.vscode/launch.template.jsonc b/.vscode/launch.template.jsonc index 9aaadb32acf..c733800981c 100644 --- a/.vscode/launch.template.jsonc +++ b/.vscode/launch.template.jsonc @@ -1,15 +1,23 @@ -/* - - Copy this file into '.vscode/launch.json' or merge its - contents into your existing configurations. - -*/ +/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */ { // Use IntelliSense to learn about possible attributes. // Hover to view descriptions of existing attributes. 
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", + "compounds": [ + { + "name": "Run All Danswer Services", + "configurations": [ + "Web Server", + "Model Server", + "API Server", + "Indexing", + "Background Jobs", + "Slack Bot" + ] + } + ], "configurations": [ { "name": "Web Server", @@ -17,7 +25,7 @@ "request": "launch", "cwd": "${workspaceRoot}/web", "runtimeExecutable": "npm", - "envFile": "${workspaceFolder}/.env", + "envFile": "${workspaceFolder}/.vscode/.env", "runtimeArgs": [ "run", "dev" ], @@ -25,11 +33,12 @@ }, { "name": "Model Server", - "type": "python", + "consoleName": "Model Server", + "type": "debugpy", "request": "launch", "module": "uvicorn", "cwd": "${workspaceFolder}/backend", - "envFile": "${workspaceFolder}/.env", + "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1" @@ -39,16 +48,16 @@ "--reload", "--port", "9000" - ], - "consoleTitle": "Model Server" + ] }, { "name": "API Server", - "type": "python", + "consoleName": "API Server", + "type": "debugpy", "request": "launch", "module": "uvicorn", "cwd": "${workspaceFolder}/backend", - "envFile": "${workspaceFolder}/.env", + "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_DANSWER_MODEL_INTERACTIONS": "True", "LOG_LEVEL": "DEBUG", @@ -59,32 +68,32 @@ "--reload", "--port", "8080" - ], - "consoleTitle": "API Server" + ] }, { "name": "Indexing", - "type": "python", + "consoleName": "Indexing", + "type": "debugpy", "request": "launch", "program": "danswer/background/update.py", "cwd": "${workspaceFolder}/backend", - "envFile": "${workspaceFolder}/.env", + "envFile": "${workspaceFolder}/.vscode/.env", "env": { "ENABLE_MULTIPASS_INDEXING": "false", "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." 
- }, - "consoleTitle": "Indexing" + } }, // Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev { "name": "Background Jobs", - "type": "python", + "consoleName": "Background Jobs", + "type": "debugpy", "request": "launch", "program": "scripts/dev_run_background_jobs.py", "cwd": "${workspaceFolder}/backend", - "envFile": "${workspaceFolder}/.env", + "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_DANSWER_MODEL_INTERACTIONS": "True", "LOG_LEVEL": "DEBUG", @@ -93,18 +102,18 @@ }, "args": [ "--no-indexing" - ], - "consoleTitle": "Background Jobs" + ] }, // For the listener to access the Slack API, // DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project { "name": "Slack Bot", - "type": "python", + "consoleName": "Slack Bot", + "type": "debugpy", "request": "launch", "program": "danswer/danswerbot/slack/listener.py", "cwd": "${workspaceFolder}/backend", - "envFile": "${workspaceFolder}/.env", + "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", @@ -113,11 +122,12 @@ }, { "name": "Pytest", - "type": "python", + "consoleName": "Pytest", + "type": "debugpy", "request": "launch", "module": "pytest", "cwd": "${workspaceFolder}/backend", - "envFile": "${workspaceFolder}/.env", + "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", @@ -128,18 +138,16 @@ // Specify a specific module/test to run or provide nothing to run all tests //"tests/unit/danswer/llm/answering/test_prune_and_merge.py" ] - } - ], - "compounds": [ + }, { - "name": "Run Danswer", - "configurations": [ - "Web Server", - "Model Server", - "API Server", - "Indexing", - "Background Jobs", - ] + "name": "Clear and Restart External Volumes and Containers", + "type": "node", + "request": "launch", + "runtimeExecutable": "bash", + "runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"], + "cwd": "${workspaceFolder}", + "console": "integratedTerminal", + "stopOnEntry": true } ] } diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 116e78b6f19..a23118e52ad 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -48,20 +48,24 @@ We would love to see you there! ## Get Started 🚀 -Danswer being a fully functional app, relies on some external pieces of software, specifically: +Danswer, being a fully functional app, relies on some external software, specifically: - [Postgres](https://www.postgresql.org/) (Relational DB) - [Vespa](https://vespa.ai/) (Vector DB/Search Engine) +- [Redis](https://redis.io/) (Cache) +- [Nginx](https://nginx.org/) (Not needed for development flows generally) -This guide provides instructions to set up the Danswer specific services outside of Docker because it's easier for -development purposes but also feel free to just use the containers and update with local changes by providing the -`--build` flag. + +> **Note:** +> This guide provides instructions to set up the Danswer specific services outside of Docker because it's easier for +> development purposes. However, you can also use the containers and update with local changes by providing the +> `--build` flag. ### Local Set Up -It is recommended to use Python version 3.11 +Be sure to use Python version 3.11. If using a lower version, modifications will have to be made to the code. -If using a higher version, the version of Tensorflow we use may not be available for your platform.
+If using a higher version, some libraries may not be available (e.g. we had problems with Tensorflow in the past with higher versions of Python). #### Installing Requirements @@ -73,8 +77,9 @@ python -m venv .venv source .venv/bin/activate ``` ---> Note that this virtual environment MUST NOT be set up WITHIN the danswer -directory +> **Note:** +> This virtual environment MUST NOT be set up WITHIN the danswer directory if you plan on using mypy within certain IDEs. +> For simplicity, we recommend setting up the virtual environment outside of the danswer directory. _For Windows, activate the virtual environment using Command Prompt:_ ```bash @@ -89,19 +94,22 @@ Install the required python dependencies: ```bash pip install -r danswer/backend/requirements/default.txt pip install -r danswer/backend/requirements/dev.txt +pip install -r danswer/backend/requirements/ee.txt pip install -r danswer/backend/requirements/model_server.txt ``` + Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend. Once the above is done, navigate to `danswer/web` and run: ```bash npm i ``` -Install Playwright (required by the Web Connector) +Install Playwright (headless browser required by the Web Connector) -> Note: If you have just done the pip install, open a new terminal and source the python virtual-env again. -This will update the path to include playwright +> **Note:** +> If you have just run the pip install, open a new terminal and source the python virtual-env again. +> This will pick up the updated PATH, which includes Playwright Then install Playwright by running: ```bash @@ -110,11 +118,14 @@ playwright install #### Dependent Docker Containers -First navigate to `danswer/deployment/docker_compose`, then start up Vespa and Postgres with: +You will need Docker installed to run these containers. + +First navigate to `danswer/deployment/docker_compose`, then start up Postgres/Vespa/Redis with: ```bash -docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db +docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db cache ``` -(index refers to Vespa and relational_db refers to Postgres) +(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis) + #### Running Danswer To start the frontend, navigate to `danswer/web` and run: @@ -127,11 +138,10 @@ Navigate to `danswer/backend` and run: ```bash uvicorn model_server.main:app --reload --port 9000 ``` + _For Windows (for compatibility with both PowerShell and Command Prompt):_ ```bash -powershell -Command " - uvicorn model_server.main:app --reload --port 9000 -" +powershell -Command "uvicorn model_server.main:app --reload --port 9000" ``` The first time running Danswer, you will need to run the DB migrations for Postgres. @@ -154,6 +164,7 @@ To run the backend API server, navigate back to `danswer/backend` and run: ```bash AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080 ``` + _For Windows (for compatibility with both PowerShell and Command Prompt):_ ```bash powershell -Command " @@ -162,20 +173,28 @@ powershell -Command " " ``` -Note: if you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services. +> **Note:** +> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services. + ### Formatting and Linting #### Backend For the backend, you'll need to set up pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions [here](https://pre-commit.com/#installation). + +With the virtual environment active, install the pre-commit library with: +```bash +pip install pre-commit +``` + Then, from the `danswer/backend` directory, run: ```bash pre-commit install ``` Additionally, we use `mypy` for static type checking. -Danswer is fully type-annotated, and we would like to keep it that way! +Danswer is fully type-annotated, and we want to keep it that way! To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory. @@ -186,6 +205,7 @@ Please double check that prettier passes before creating a pull request. ### Release Process -Danswer follows the semver versioning standard. +Danswer loosely follows the SemVer versioning standard. +Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes. A set of Docker containers will be pushed automatically to DockerHub with every tag. You can see the containers [here](https://hub.docker.com/search?q=danswer%2F). diff --git a/CONTRIBUTING_MACOS.md b/CONTRIBUTING_MACOS.md new file mode 100644 index 00000000000..519eccffd51 --- /dev/null +++ b/CONTRIBUTING_MACOS.md @@ -0,0 +1,31 @@ +## Some additional notes for Mac Users +The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md). + +### Setting up Python +Ensure [Homebrew](https://brew.sh/) is already set up. + +Then install python 3.11. +```bash +brew install python@3.11 +``` + +Add python 3.11 to your path: add the following line to ~/.zshrc +``` +export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH" +``` + +> **Note:** +> You will need to open a new terminal for the path change above to take effect. + + +### Setting up Docker +On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and +ensure it is running before continuing with the docker commands. + + +### Formatting and Linting +MacOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly. +After installing pre-commit, run the following command: +```bash +sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit +``` \ No newline at end of file diff --git a/backend/Dockerfile b/backend/Dockerfile index d8c388801d7..de9d9472b11 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -74,9 +74,9 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')" # Pre-downloading NLTK for setups with limited egress RUN python -c "import nltk; \ - nltk.download('stopwords', quiet=True); \ - nltk.download('wordnet', quiet=True); \ - nltk.download('punkt', quiet=True);" +nltk.download('stopwords', quiet=True); \ +nltk.download('punkt', quiet=True);" +# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed # Set up application files WORKDIR /app diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 8c028202bfc..154d6ff3d66 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -16,7 +16,9 @@ # Interpret the config file for Python logging. # This line sets up loggers basically. 
-if config.config_file_name is not None: +if config.config_file_name is not None and config.attributes.get( + "configure_logger", True +): fileConfig(config.config_file_name) # add your model's MetaData object here diff --git a/backend/alembic/versions/a3795dce87be_migration_confluence_to_be_explicit.py b/backend/alembic/versions/a3795dce87be_migration_confluence_to_be_explicit.py new file mode 100644 index 00000000000..20e33d0e227 --- /dev/null +++ b/backend/alembic/versions/a3795dce87be_migration_confluence_to_be_explicit.py @@ -0,0 +1,158 @@ +"""migration confluence to be explicit + +Revision ID: a3795dce87be +Revises: 1f60f60c3401 +Create Date: 2024-09-01 13:52:12.006740 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy.sql import table, column + +revision = "a3795dce87be" +down_revision = "1f60f60c3401" +branch_labels: None = None +depends_on: None = None + + +def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]: + from urllib.parse import urlparse + + def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]: + parsed_url = urlparse(wiki_url) + wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split('/spaces')[0]}" + path_parts = parsed_url.path.split("/") + space = path_parts[3] + page_id = path_parts[5] if len(path_parts) > 5 else "" + return wiki_base, space, page_id + + def _extract_confluence_keys_from_datacenter_url( + wiki_url: str, + ) -> tuple[str, str, str]: + DISPLAY = "/display/" + PAGE = "/pages/" + parsed_url = urlparse(wiki_url) + wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split(DISPLAY)[0]}" + space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0] + page_id = "" + if (content := parsed_url.path.split(PAGE)) and len(content) > 1: + page_id = content[1] + return wiki_base, space, page_id + + is_confluence_cloud = ( + ".atlassian.net/wiki/spaces/" in wiki_url + or ".jira.com/wiki/spaces/" in wiki_url + ) + + if is_confluence_cloud: + wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(wiki_url) + else: + wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url( + wiki_url + ) + + return wiki_base, space, page_id, is_confluence_cloud + + +def reconstruct_confluence_url( + wiki_base: str, space: str, page_id: str, is_cloud: bool +) -> str: + if is_cloud: + url = f"{wiki_base}/spaces/{space}" + if page_id: + url += f"/pages/{page_id}" + else: + url = f"{wiki_base}/display/{space}" + if page_id: + url += f"/pages/{page_id}" + return url + + +def upgrade() -> None: + connector = table( + "connector", + column("id", sa.Integer), + column("source", sa.String()), + column("input_type", sa.String()), + column("connector_specific_config", postgresql.JSONB), + ) + + # Fetch all Confluence connectors + connection = op.get_bind() + confluence_connectors = connection.execute( + sa.select(connector).where( + sa.and_( + connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL" + ) + ) + ).fetchall() + + for row in confluence_connectors: + config = row.connector_specific_config + wiki_page_url = config["wiki_page_url"] + wiki_base, space, page_id, is_cloud = extract_confluence_keys_from_url( + wiki_page_url + ) + + new_config = { + "wiki_base": wiki_base, + "space": space, + "page_id": page_id, + "is_cloud": is_cloud, + } + + for key, value in config.items(): + if key not in ["wiki_page_url"]: + new_config[key] = value + + op.execute( + connector.update() + 
.where(connector.c.id == row.id) + .values(connector_specific_config=new_config) + ) + + +def downgrade() -> None: + connector = table( + "connector", + column("id", sa.Integer), + column("source", sa.String()), + column("input_type", sa.String()), + column("connector_specific_config", postgresql.JSONB), + ) + + confluence_connectors = ( + op.get_bind() + .execute( + sa.select(connector).where( + connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL" + ) + ) + .fetchall() + ) + + for row in confluence_connectors: + config = row.connector_specific_config + if all(key in config for key in ["wiki_base", "space", "is_cloud"]): + wiki_page_url = reconstruct_confluence_url( + config["wiki_base"], + config["space"], + config.get("page_id", ""), + config["is_cloud"], + ) + + new_config = {"wiki_page_url": wiki_page_url} + new_config.update( + { + k: v + for k, v in config.items() + if k not in ["wiki_base", "space", "page_id", "is_cloud"] + } + ) + + op.execute( + connector.update() + .where(connector.c.id == row.id) + .values(connector_specific_config=new_config) + ) diff --git a/backend/alembic/versions/ba98eba0f66a_add_support_for_litellm_proxy_in_.py b/backend/alembic/versions/ba98eba0f66a_add_support_for_litellm_proxy_in_.py new file mode 100644 index 00000000000..2d45a15f2c6 --- /dev/null +++ b/backend/alembic/versions/ba98eba0f66a_add_support_for_litellm_proxy_in_.py @@ -0,0 +1,26 @@ +"""add support for litellm proxy in reranking + +Revision ID: ba98eba0f66a +Revises: bceb1e139447 +Create Date: 2024-09-06 10:36:04.507332 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "ba98eba0f66a" +down_revision = "bceb1e139447" +branch_labels: None = None +depends_on: None = None + + +def upgrade() -> None: + op.add_column( + "search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True) + ) + + +def downgrade() -> None: + op.drop_column("search_settings", "rerank_api_url") diff --git a/backend/alembic/versions/bceb1e139447_add_base_url_to_cloudembeddingprovider.py b/backend/alembic/versions/bceb1e139447_add_base_url_to_cloudembeddingprovider.py new file mode 100644 index 00000000000..968500e6aaf --- /dev/null +++ b/backend/alembic/versions/bceb1e139447_add_base_url_to_cloudembeddingprovider.py @@ -0,0 +1,26 @@ +"""Add base_url to CloudEmbeddingProvider + +Revision ID: bceb1e139447 +Revises: a3795dce87be +Create Date: 2024-08-28 17:00:52.554580 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "bceb1e139447" +down_revision = "a3795dce87be" +branch_labels: None = None +depends_on: None = None + + +def upgrade() -> None: + op.add_column( + "embedding_provider", sa.Column("api_url", sa.String(), nullable=True) + ) + + +def downgrade() -> None: + op.drop_column("embedding_provider", "api_url") diff --git a/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py b/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py new file mode 100644 index 00000000000..2d8e7402e48 --- /dev/null +++ b/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py @@ -0,0 +1,26 @@ +"""add has_web_login column to user + +Revision ID: f7e58d357687 +Revises: bceb1e139447 +Create Date: 2024-09-07 20:20:54.522620 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = "f7e58d357687" +down_revision = "ba98eba0f66a" +branch_labels: None = None +depends_on: None = None + + +def upgrade() -> None: + op.add_column( + "user", + sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"), + ) + + +def downgrade() -> None: + op.drop_column("user", "has_web_login") diff --git a/backend/danswer/auth/schemas.py b/backend/danswer/auth/schemas.py index 9e0553991cc..db8a97ceb04 100644 --- a/backend/danswer/auth/schemas.py +++ b/backend/danswer/auth/schemas.py @@ -33,7 +33,9 @@ class UserRead(schemas.BaseUser[uuid.UUID]): class UserCreate(schemas.BaseUserCreate): role: UserRole = UserRole.BASIC + has_web_login: bool | None = True class UserUpdate(schemas.BaseUserUpdate): role: UserRole + has_web_login: bool | None = True diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index eec1db412e0..89ade310338 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -16,7 +16,9 @@ from fastapi import Request from fastapi import Response from fastapi import status +from fastapi.security import OAuth2PasswordRequestForm from fastapi_users import BaseUserManager +from fastapi_users import exceptions from fastapi_users import FastAPIUsers from fastapi_users import models from fastapi_users import schemas @@ -33,6 +35,7 @@ from danswer.auth.invited_users import get_invited_users from danswer.auth.schemas import UserCreate from danswer.auth.schemas import UserRole +from danswer.auth.schemas import UserUpdate from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import DISABLE_AUTH from danswer.configs.app_configs import EMAIL_FROM @@ -67,23 +70,6 @@ logger = setup_logger() -def validate_curator_request(groups: list | None, is_public: bool) -> None: - if is_public: - detail = "Curators cannot create public objects" - logger.error(detail) - raise HTTPException( - status_code=401, - detail=detail, - ) - if not groups: - detail = "Curators must specify 1+ groups" - logger.error(detail) - raise HTTPException( - status_code=401, - detail=detail, - ) - - def is_user_admin(user: User | None) -> bool: if AUTH_TYPE == AuthType.DISABLED: return True @@ -201,16 +187,36 @@ async def create( user_create: schemas.UC | UserCreate, safe: bool = False, request: Optional[Request] = None, - ) -> models.UP: + ) -> User: verify_email_is_invited(user_create.email) verify_email_domain(user_create.email) # if hasattr(user_create, "role"): - # user_count = await get_user_count() - # if user_count == 0 or user_create.email in get_default_admin_user_emails(): - # user_create.role = UserRole.ADMIN - # else: - # user_create.role = UserRole.BASIC - return await super().create(user_create, safe=safe, request=request) # type: ignore + # user_count = await get_user_count() + # if user_count == 0 or user_create.email in get_default_admin_user_emails(): + # user_create.role = UserRole.ADMIN + # else: + # user_create.role = UserRole.BASIC + user = None + try: + user = await super().create(user_create, safe=safe, request=request) # type: ignore + except exceptions.UserAlreadyExists: + user = await self.get_by_email(user_create.email) + # Handle case where user has used product outside of web and is now creating an account through web + if ( + not user.has_web_login + and hasattr(user_create, "has_web_login") + and user_create.has_web_login + ): + user_update = UserUpdate( + password=user_create.password, + has_web_login=True, + role=user_create.role, + is_verified=user_create.is_verified, + ) + user = await 
self.update(user_update, user) + else: + raise exceptions.UserAlreadyExists() + return user async def oauth_callback( self: "BaseUserManager[models.UOAP, models.ID]", @@ -251,6 +257,17 @@ async def oauth_callback( if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY: await self.user_db.update(user, update_dict={"oidc_expiry": None}) + # Handle case where user has used product outside of web and is now creating an account through web + if not user.has_web_login: + await self.user_db.update( + user, + update_dict={ + "is_verified": is_verified_by_default, + "has_web_login": True, + }, + ) + user.is_verified = is_verified_by_default + user.has_web_login = True return user async def on_after_register( @@ -279,6 +296,22 @@ async def on_after_request_verify( send_user_verification_email(user.email, token) + async def authenticate( + self, credentials: OAuth2PasswordRequestForm + ) -> Optional[User]: + user = await super().authenticate(credentials) + if user is None: + try: + user = await self.get_by_email(credentials.username) + if not user.has_web_login: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD", + ) + except exceptions.UserNotExists: + pass + return user + async def get_user_manager( user_db: SQLAlchemyUserDatabase = Depends(get_user_db), diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py index ffd805c2986..c401dde83ca 100644 --- a/backend/danswer/background/celery/celery_app.py +++ b/backend/danswer/background/celery/celery_app.py @@ -6,6 +6,7 @@ from celery import Celery # type: ignore from celery.contrib.abortable import AbortableTask # type: ignore from celery.exceptions import TaskRevokedError +from sqlalchemy import inspect from sqlalchemy import text from sqlalchemy.orm import Session @@ -20,7 +21,10 @@ from danswer.background.task_utils import name_cc_prune_task from danswer.background.task_utils import name_document_set_sync_task from danswer.configs.app_configs import JOB_TIMEOUT -from danswer.configs.constants import POSTGRES_CELERY_APP_NAME +from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY +from danswer.configs.app_configs import REDIS_HOST +from danswer.configs.app_configs import REDIS_PASSWORD +from danswer.configs.app_configs import REDIS_PORT from danswer.configs.constants import PostgresAdvisoryLocks from danswer.connectors.factory import instantiate_connector from danswer.connectors.models import InputType @@ -35,9 +39,7 @@ from danswer.db.document_set import fetch_documents_for_document_set_paginated from danswer.db.document_set import get_document_set_by_id from danswer.db.document_set import mark_document_set_as_synced -from danswer.db.engine import build_connection_string from danswer.db.engine import get_sqlalchemy_engine -from danswer.db.engine import SYNC_DB_API from danswer.db.models import DocumentSet from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index @@ -46,11 +48,17 @@ logger = setup_logger() -connection_string = build_connection_string( - db_api=SYNC_DB_API, app_name=POSTGRES_CELERY_APP_NAME +CELERY_PASSWORD_PART = "" +if REDIS_PASSWORD: + CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@" + +# example celery_broker_url: "redis://:password@localhost:6379/15" +celery_broker_url = ( + f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}" +) +celery_backend_url = ( + 
f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}" ) -celery_broker_url = f"sqla+{connection_string}" -celery_backend_url = f"db+{connection_string}" celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url) @@ -360,6 +368,15 @@ def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool: bool: Returns True if there are more rows to process, False if not. """ + inspector = inspect(db_session.bind) + if not inspector: + return False + + # With the move to redis as celery's broker and backend, kombu tables may not even exist. + # We can fail silently. + if not inspector.has_table("kombu_message"): + return False + query = text( """ SELECT id, timestamp, payload diff --git a/backend/danswer/background/connector_deletion.py b/backend/danswer/background/connector_deletion.py index 90883564910..c904c804d06 100644 --- a/backend/danswer/background/connector_deletion.py +++ b/backend/danswer/background/connector_deletion.py @@ -151,8 +151,7 @@ def delete_connector_credential_pair( # index attempts delete_index_attempts( db_session=db_session, - connector_id=connector_id, - credential_id=credential_id, + cc_pair_id=cc_pair.id, ) # document sets diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index a98f4e1f5ad..d65b4b0c907 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -118,19 +118,19 @@ def _run_indexing( db_cc_pair = index_attempt.connector_credential_pair db_connector = index_attempt.connector_credential_pair.connector db_credential = index_attempt.connector_credential_pair.credential + earliest_index_time = ( + db_connector.indexing_start.timestamp() if db_connector.indexing_start else 0 + ) last_successful_index_time = ( - db_connector.indexing_start.timestamp() - if index_attempt.from_beginning and db_connector.indexing_start is not None - else ( - 0.0 - if index_attempt.from_beginning - else get_last_successful_attempt_time( - connector_id=db_connector.id, - credential_id=db_credential.id, - search_settings=index_attempt.search_settings, - db_session=db_session, - ) + earliest_index_time + if index_attempt.from_beginning + else get_last_successful_attempt_time( + connector_id=db_connector.id, + credential_id=db_credential.id, + earliest_index=earliest_index_time, + search_settings=index_attempt.search_settings, + db_session=db_session, ) ) diff --git a/backend/danswer/background/task_utils.py b/backend/danswer/background/task_utils.py index 6e122678813..e746e43abae 100644 --- a/backend/danswer/background/task_utils.py +++ b/backend/danswer/background/task_utils.py @@ -93,9 +93,16 @@ def wrapped_fn( kwargs_for_build_name = kwargs or {} task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name) with Session(get_sqlalchemy_engine()) as db_session: - # mark the task as started + # register_task must come before fn = apply_async or else the task + # might run mark_task_start (and crash) before the task row exists + db_task = register_task(task_name, db_session) + task = fn(args, kwargs, *other_args, **other_kwargs) - register_task(task.id, task_name, db_session) + + # we update the celery task id for diagnostic purposes + # but it isn't currently used by any code + db_task.task_id = task.id + db_session.commit() return task diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 28abb481143..10fa36a1d8c 100755 --- 
a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -17,6 +17,7 @@ from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP from danswer.configs.app_configs import NUM_INDEXING_WORKERS from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS +from danswer.configs.constants import DocumentSource from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME from danswer.db.connector import fetch_connectors from danswer.db.connector_credential_pair import fetch_connector_credential_pairs @@ -46,7 +47,6 @@ from shared_configs.configs import LOG_LEVEL from shared_configs.configs import MODEL_SERVER_PORT - logger = setup_logger() # If the indexing dies, it's most likely due to resource constraints, @@ -67,6 +67,10 @@ def _should_create_new_indexing( ) -> bool: connector = cc_pair.connector + # don't kick off indexing for `NOT_APPLICABLE` sources + if connector.source == DocumentSource.NOT_APPLICABLE: + return False + # User can still manually create single indexing attempts via the UI for the # currently in use index if DISABLE_INDEX_UPDATE_ON_SWAP: diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index 6d12d68df08..1f1f15ea700 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -1,5 +1,6 @@ from collections.abc import Iterator from datetime import datetime +from enum import Enum from typing import Any from pydantic import BaseModel @@ -44,6 +45,20 @@ def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: return initial_dict +class StreamStopReason(Enum): + CONTEXT_LENGTH = "context_length" + CANCELLED = "cancelled" + + +class StreamStopInfo(BaseModel): + stop_reason: StreamStopReason + + def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore + data = super().model_dump(mode="json", *args, **kwargs) # type: ignore + data["stop_reason"] = self.stop_reason.name + return data + + class LLMRelevanceFilterResponse(BaseModel): relevant_chunk_indices: list[int] @@ -144,6 +159,7 @@ class CustomToolResponse(BaseModel): | ImageGenerationDisplay | CustomToolResponse | StreamingError + | StreamStopInfo ) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index f6b218c5f56..d7733fdc0ab 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -149,6 +149,16 @@ except ValueError: POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT +REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost" +REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379)) +REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or "" + +# Used for general redis things +REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0)) + +# Used by celery as broker and backend +REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) + ##### # Connector Configs ##### diff --git a/backend/danswer/configs/chat_configs.py b/backend/danswer/configs/chat_configs.py index 2b6b0990e1d..6939ce9bd89 100644 --- a/backend/danswer/configs/chat_configs.py +++ b/backend/danswer/configs/chat_configs.py @@ -88,3 +88,6 @@ # Internet Search BING_API_KEY = os.environ.get("BING_API_KEY") or None + +# Enable in-house model for detecting connector-based filtering in queries +ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False) diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 64c162d7bef..eff8ee30a63 100644 --- 
a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -57,6 +57,7 @@ KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time" KV_SETTINGS_KEY = "danswer_settings" KV_CUSTOMER_UUID_KEY = "customer_uuid" +KV_INSTANCE_DOMAIN_KEY = "instance_domain" KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings" KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__" diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py index e5fa5e74a28..b03ebc712fc 100644 --- a/backend/danswer/configs/model_configs.py +++ b/backend/danswer/configs/model_configs.py @@ -51,16 +51,8 @@ # Generative AI Model Configs ##### -# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default, -# be sure to use one that is LiteLLM compatible: -# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables -# The provider is the prefix before / in the model argument - -# Additionally Danswer supports GPT4All and custom request library based models -# Set GEN_AI_MODEL_PROVIDER to "custom" to use the custom requests approach -# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally -GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai" -# If using Azure, it's the engine name, for example: Danswer +# NOTE: the 3 below should only be used for dev. +GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY") GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") # The fallback display name to use for default model when using a custom model provider GEN_AI_DISPLAY_NAME = os.environ.get("GEN_AI_DISPLAY_NAME") or "Custom LLM" @@ -69,17 +61,6 @@ # as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper FAST_GEN_AI_MODEL_VERSION = os.environ.get("FAST_GEN_AI_MODEL_VERSION") -# If the Generative AI model requires an API key for access, otherwise can leave blank -GEN_AI_API_KEY = ( - os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None -) - -# API Base, such as (for Azure): https://danswer.openai.azure.com/ -GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None -# API Version, such as (for Azure): 2023-09-15-preview -GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None -# LiteLLM custom_llm_provider -GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None # Override the auto-detection of LLM max context length GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None diff --git a/backend/danswer/connectors/README.md b/backend/danswer/connectors/README.md index b50232fa256..ef6c63d2697 100644 --- a/backend/danswer/connectors/README.md +++ b/backend/danswer/connectors/README.md @@ -59,6 +59,8 @@ if __name__ == "__main__": latest_docs = test_connector.poll_source(one_day_ago, current) ``` +> Note: Be sure to set PYTHONPATH to danswer/backend before running the above main. + ### Additional Required Changes: #### Backend Changes @@ -68,17 +70,16 @@ if __name__ == "__main__": [here](https://github.com/danswer-ai/danswer/blob/main/backend/danswer/connectors/factory.py#L33) #### Frontend Changes -- Create the new connector directory and admin page under `danswer/web/src/app/admin/connectors/` -- Create the new icon, type, source, and filter changes -(refer to existing [PR](https://github.com/danswer-ai/danswer/pull/139)) +- Add the new Connector definition to the `SOURCE_METADATA_MAP` [here](https://github.com/danswer-ai/danswer/blob/main/web/src/lib/sources.ts#L59). 
+- Add the definition for the new Form to the `connectorConfigs` object [here](https://github.com/danswer-ai/danswer/blob/main/web/src/lib/connectors/connectors.ts#L79). #### Docs Changes Create the new connector page (with guiding images!) with how to get the connector credentials and how to set up the -connector in Danswer. Then create a Pull Request in https://github.com/danswer-ai/danswer-docs - +connector in Danswer. Then create a Pull Request in https://github.com/danswer-ai/danswer-docs. ### Before opening PR 1. Be sure to fully test changes end to end with setting up the connector and updating the index with new docs from the -new connector. -2. Be sure to run the linting/formatting, refer to the formatting and linting section in +new connector. To make it easier to review, please attach a video showing the successful creation of the connector via the UI (starting from the `Add Connector` page). +2. Add a folder + tests under the `backend/tests/daily/connectors` directory. For an example, check out the [test for Confluence](https://github.com/danswer-ai/danswer/blob/main/backend/tests/daily/connectors/confluence/test_confluence_basic.py). In the PR description, include a guide on how to set up the new source to pass the test. Before merging, we will re-create the environment and make sure the test(s) pass. +3. Be sure to run the linting/formatting; refer to the formatting and linting section in [CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md#formatting-and-linting) diff --git a/backend/danswer/connectors/confluence/connector.py b/backend/danswer/connectors/confluence/connector.py index b8dc967a3d9..78efce4ab98 100644 --- a/backend/danswer/connectors/confluence/connector.py +++ b/backend/danswer/connectors/confluence/connector.py @@ -7,7 +7,6 @@ from functools import lru_cache from typing import Any from typing import cast -from urllib.parse import urlparse import bs4 from atlassian import Confluence # type:ignore @@ -53,79 +52,6 @@ ) -def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]: - """Sample - URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview - URL w/o page: https://danswer.atlassian.net/wiki/spaces/ASAM/overview - - wiki_base is https://danswer.atlassian.net/wiki - space is 1234abcd - page_id is 5678efgh - """ - parsed_url = urlparse(wiki_url) - wiki_base = ( - parsed_url.scheme - + "://" - + parsed_url.netloc - + parsed_url.path.split("/spaces")[0] - ) - - path_parts = parsed_url.path.split("/") - space = path_parts[3] - - page_id = path_parts[5] if len(path_parts) > 5 else "" - return wiki_base, space, page_id - - -def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str, str]: - """Sample - URL w/ page https://danswer.ai/confluence/display/1234abcd/pages/5678efgh/overview - URL w/o page https://danswer.ai/confluence/display/1234abcd/overview - wiki_base is https://danswer.ai/confluence - space is 1234abcd - page_id is 5678efgh - """ - # /display/ is always right before the space and at the end of the base print() - DISPLAY = "/display/" - PAGE = "/pages/" - - parsed_url = urlparse(wiki_url) - wiki_base = ( - parsed_url.scheme - + "://" - + parsed_url.netloc - + parsed_url.path.split(DISPLAY)[0] - ) - space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0] - page_id = "" - if (content := parsed_url.path.split(PAGE)) and len(content) > 1: - page_id = content[1] - return wiki_base, space, page_id - - -def 
extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]: - is_confluence_cloud = ( - ".atlassian.net/wiki/spaces/" in wiki_url - or ".jira.com/wiki/spaces/" in wiki_url - ) - - try: - if is_confluence_cloud: - wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url( - wiki_url - ) - else: - wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url( - wiki_url - ) - except Exception as e: - error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base, space, and page id. Exception: {e}" - logger.error(error_msg) - raise ValueError(error_msg) - - return wiki_base, space, page_id, is_confluence_cloud - - @lru_cache() def _get_user(user_id: str, confluence_client: Confluence) -> str: """Get Confluence Display Name based on the account-id or userkey value @@ -372,7 +298,10 @@ def _fetch_single_depth_child_pages( class ConfluenceConnector(LoadConnector, PollConnector): def __init__( self, - wiki_page_url: str, + wiki_base: str, + space: str, + is_cloud: bool, + page_id: str = "", index_recursively: bool = True, batch_size: int = INDEX_BATCH_SIZE, continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE, @@ -386,15 +315,15 @@ def __init__( self.labels_to_skip = set(labels_to_skip) self.recursive_indexer: RecursiveIndexer | None = None self.index_recursively = index_recursively - ( - self.wiki_base, - self.space, - self.page_id, - self.is_cloud, - ) = extract_confluence_keys_from_url(wiki_page_url) - self.space_level_scan = False + # Remove trailing slash from wiki_base if present + self.wiki_base = wiki_base.rstrip("/") + self.space = space + self.page_id = page_id + self.is_cloud = is_cloud + + self.space_level_scan = False self.confluence_client: Confluence | None = None if self.page_id is None or self.page_id == "": @@ -414,7 +343,6 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None username=username if self.is_cloud else None, password=access_token if self.is_cloud else None, token=access_token if not self.is_cloud else None, - cloud=self.is_cloud, ) return None @@ -866,7 +794,13 @@ def poll_source( if __name__ == "__main__": - connector = ConfluenceConnector(os.environ["CONFLUENCE_TEST_SPACE_URL"]) + connector = ConfluenceConnector( + wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"], + space=os.environ["CONFLUENCE_TEST_SPACE"], + is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true", + page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""), + index_recursively=True, + ) connector.load_credentials( { "confluence_username": os.environ["CONFLUENCE_USER_NAME"], diff --git a/backend/danswer/connectors/confluence/rate_limit_handler.py b/backend/danswer/connectors/confluence/rate_limit_handler.py index 8755b78f3f4..822badb9b99 100644 --- a/backend/danswer/connectors/confluence/rate_limit_handler.py +++ b/backend/danswer/connectors/confluence/rate_limit_handler.py @@ -23,7 +23,7 @@ class ConfluenceRateLimitError(Exception): def make_confluence_call_handle_rate_limit(confluence_call: F) -> F: def wrapped_call(*args: list[Any], **kwargs: Any) -> Any: - max_retries = 10 + max_retries = 5 starting_delay = 5 backoff = 2 max_delay = 600 @@ -32,17 +32,24 @@ def wrapped_call(*args: list[Any], **kwargs: Any) -> Any: try: return confluence_call(*args, **kwargs) except HTTPError as e: + # Check if the response or headers are None to avoid potential AttributeError + if e.response is None or e.response.headers is None: + logger.warning("HTTPError with `None` as response or as headers") + raise e + 
+ retry_after_header = e.response.headers.get("Retry-After") if ( e.response.status_code == 429 or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower() ): retry_after = None - try: - retry_after = int(e.response.headers.get("Retry-After")) - except (ValueError, TypeError): - pass + if retry_after_header is not None: + try: + retry_after = int(retry_after_header) + except ValueError: + pass - if retry_after: + if retry_after is not None: logger.warning( f"Rate limit hit. Retrying after {retry_after} seconds..." ) diff --git a/backend/danswer/connectors/danswer_jira/connector.py b/backend/danswer/connectors/danswer_jira/connector.py index 9a8fbb31501..e3562f3a45c 100644 --- a/backend/danswer/connectors/danswer_jira/connector.py +++ b/backend/danswer/connectors/danswer_jira/connector.py @@ -45,10 +45,15 @@ def extract_jira_project(url: str) -> tuple[str, str]: return jira_base, jira_project -def extract_text_from_content(content: dict) -> str: +def extract_text_from_adf(adf: dict | None) -> str: + """Extracts plain text from Atlassian Document Format: + https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/ + + WARNING: This function is incomplete and will e.g. skip lists! + """ texts = [] - if "content" in content: - for block in content["content"]: + if adf is not None and "content" in adf: + for block in adf["content"]: if "content" in block: for item in block["content"]: if item["type"] == "text": @@ -72,18 +77,15 @@ def _get_comment_strs( comment_strs = [] for comment in jira.fields.comment.comments: try: - if hasattr(comment, "body"): - body_text = extract_text_from_content(comment.raw["body"]) - elif hasattr(comment, "raw"): - body = comment.raw.get("body", "No body content available") - body_text = ( - extract_text_from_content(body) if isinstance(body, dict) else body - ) - else: - body_text = "No body attribute found" + body_text = ( + comment.body + if JIRA_API_VERSION == "2" + else extract_text_from_adf(comment.raw["body"]) + ) if ( hasattr(comment, "author") + and hasattr(comment.author, "emailAddress") and comment.author.emailAddress in comment_email_blacklist ): continue # Skip adding comment if author's email is in blacklist @@ -126,11 +128,14 @@ def fetch_jira_issues_batch( ) continue + description = ( + jira.fields.description + if JIRA_API_VERSION == "2" + else extract_text_from_adf(jira.raw["fields"]["description"]) + ) comments = _get_comment_strs(jira, comment_email_blacklist) - semantic_rep = ( - f"{jira.fields.description}\n" - if jira.fields.description - else "" + "\n".join([f"Comment: {comment}" for comment in comments]) + semantic_rep = f"{description}\n" + "\n".join( + [f"Comment: {comment}" for comment in comments if comment] ) page_url = f"{jira_client.client_info()}/browse/{jira.key}" diff --git a/backend/danswer/connectors/file/connector.py b/backend/danswer/connectors/file/connector.py index 6c5501734b0..83d0af2c12e 100644 --- a/backend/danswer/connectors/file/connector.py +++ b/backend/danswer/connectors/file/connector.py @@ -23,7 +23,7 @@ from danswer.file_processing.extract_file_text import get_file_ext from danswer.file_processing.extract_file_text import is_text_file_extension from danswer.file_processing.extract_file_text import load_files_from_zip -from danswer.file_processing.extract_file_text import pdf_to_text +from danswer.file_processing.extract_file_text import read_pdf_file from danswer.file_processing.extract_file_text import read_text_file from danswer.file_store.file_store import get_default_file_store from 
danswer.utils.logger import setup_logger @@ -75,7 +75,7 @@ def _process_file( # Using the PDF reader function directly to pass in password cleanly elif extension == ".pdf": - file_content_raw = pdf_to_text(file=file, pdf_pass=pdf_pass) + file_content_raw, file_metadata = read_pdf_file(file=file, pdf_pass=pdf_pass) else: file_content_raw = extract_file_text( diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 40a9b73432f..80674b5a37d 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -41,8 +41,8 @@ from danswer.connectors.models import Document from danswer.connectors.models import Section from danswer.file_processing.extract_file_text import docx_to_text -from danswer.file_processing.extract_file_text import pdf_to_text from danswer.file_processing.extract_file_text import pptx_to_text +from danswer.file_processing.extract_file_text import read_pdf_file from danswer.utils.batching import batch_generator from danswer.utils.logger import setup_logger @@ -62,6 +62,8 @@ class GDriveMimeType(str, Enum): POWERPOINT = ( "application/vnd.openxmlformats-officedocument.presentationml.presentation" ) + PLAIN_TEXT = "text/plain" + MARKDOWN = "text/markdown" GoogleDriveFileType = dict[str, Any] @@ -316,25 +318,29 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str: GDriveMimeType.PPT.value, GDriveMimeType.SPREADSHEET.value, ]: - export_mime_type = "text/plain" - if mime_type == GDriveMimeType.SPREADSHEET.value: - export_mime_type = "text/csv" - elif mime_type == GDriveMimeType.PPT.value: - export_mime_type = "text/plain" - - response = ( + export_mime_type = ( + "text/plain" + if mime_type != GDriveMimeType.SPREADSHEET.value + else "text/csv" + ) + return ( service.files() .export(fileId=file["id"], mimeType=export_mime_type) .execute() + .decode("utf-8") ) - return response.decode("utf-8") - + elif mime_type in [ + GDriveMimeType.PLAIN_TEXT.value, + GDriveMimeType.MARKDOWN.value, + ]: + return service.files().get_media(fileId=file["id"]).execute().decode("utf-8") elif mime_type == GDriveMimeType.WORD_DOC.value: response = service.files().get_media(fileId=file["id"]).execute() return docx_to_text(file=io.BytesIO(response)) elif mime_type == GDriveMimeType.PDF.value: response = service.files().get_media(fileId=file["id"]).execute() - return pdf_to_text(file=io.BytesIO(response)) + text, _ = read_pdf_file(file=io.BytesIO(response)) + return text elif mime_type == GDriveMimeType.POWERPOINT.value: response = service.files().get_media(fileId=file["id"]).execute() return pptx_to_text(file=io.BytesIO(response)) diff --git a/backend/danswer/connectors/notion/connector.py b/backend/danswer/connectors/notion/connector.py index fd607e4f97a..7878434da04 100644 --- a/backend/danswer/connectors/notion/connector.py +++ b/backend/danswer/connectors/notion/connector.py @@ -237,6 +237,14 @@ def _read_blocks( ) continue + if result_type == "external_object_instance_page": + logger.warning( + f"Skipping 'external_object_instance_page' ('{result_block_id}') for base block '{base_block_id}': " + f"Notion API does not currently support reading external blocks (as of 24/07/03) " + f"(discussion: https://github.com/danswer-ai/danswer/issues/1761)" + ) + continue + cur_result_text_arr = [] if "rich_text" in result_obj: for rich_text in result_obj["rich_text"]: diff --git a/backend/danswer/connectors/productboard/connector.py 
b/backend/danswer/connectors/productboard/connector.py index 9ef301aa76d..c7a2d45cae8 100644 --- a/backend/danswer/connectors/productboard/connector.py +++ b/backend/danswer/connectors/productboard/connector.py @@ -98,6 +98,15 @@ def _get_features(self) -> Generator[Document, None, None]: owner = self._get_owner_email(feature) experts = [BasicExpertInfo(email=owner)] if owner else None + metadata: dict[str, str | list[str]] = {} + entity_type = feature.get("type", "feature") + if entity_type: + metadata["entity_type"] = str(entity_type) + + status = feature.get("status", {}).get("name") + if status: + metadata["status"] = str(status) + yield Document( id=feature["id"], sections=[ @@ -110,10 +119,7 @@ def _get_features(self) -> Generator[Document, None, None]: source=DocumentSource.PRODUCTBOARD, doc_updated_at=time_str_to_utc(feature["updatedAt"]), primary_owners=experts, - metadata={ - "entity_type": feature["type"], - "status": feature["status"]["name"], - }, + metadata=metadata, ) def _get_components(self) -> Generator[Document, None, None]: @@ -174,6 +180,12 @@ def _get_objectives(self) -> Generator[Document, None, None]: owner = self._get_owner_email(objective) experts = [BasicExpertInfo(email=owner)] if owner else None + metadata: dict[str, str | list[str]] = { + "entity_type": "objective", + } + if objective.get("state"): + metadata["state"] = str(objective["state"]) + yield Document( id=objective["id"], sections=[ @@ -186,10 +198,7 @@ def _get_objectives(self) -> Generator[Document, None, None]: source=DocumentSource.PRODUCTBOARD, doc_updated_at=time_str_to_utc(objective["updatedAt"]), primary_owners=experts, - metadata={ - "entity_type": "release", - "state": objective["state"], - }, + metadata=metadata, ) def _is_updated_at_out_of_time_range( diff --git a/backend/danswer/connectors/sharepoint/connector.py b/backend/danswer/connectors/sharepoint/connector.py index b66c010d77f..e74dcbf7edd 100644 --- a/backend/danswer/connectors/sharepoint/connector.py +++ b/backend/danswer/connectors/sharepoint/connector.py @@ -25,7 +25,6 @@ from danswer.file_processing.extract_file_text import extract_file_text from danswer.utils.logger import setup_logger - logger = setup_logger() @@ -137,7 +136,7 @@ def _populate_sitedata_sites(self) -> None: .execute_query() ] else: - sites = self.graph_client.sites.get().execute_query() + sites = self.graph_client.sites.get_all().execute_query() self.site_data = [ SiteData(url=None, folder=None, sites=sites, driveitems=[]) ] diff --git a/backend/danswer/connectors/slack/connector.py b/backend/danswer/connectors/slack/connector.py index 6c451389932..975653f5f61 100644 --- a/backend/danswer/connectors/slack/connector.py +++ b/backend/danswer/connectors/slack/connector.py @@ -29,6 +29,7 @@ from danswer.connectors.slack.utils import SlackTextCleaner from danswer.utils.logger import setup_logger + logger = setup_logger() diff --git a/backend/danswer/connectors/web/connector.py b/backend/danswer/connectors/web/connector.py index 6e76e404acd..bb1f64efdfe 100644 --- a/backend/danswer/connectors/web/connector.py +++ b/backend/danswer/connectors/web/connector.py @@ -1,6 +1,8 @@ import io import ipaddress import socket +from datetime import datetime +from datetime import timezone from enum import Enum from typing import Any from typing import cast @@ -27,7 +29,7 @@ from danswer.connectors.interfaces import LoadConnector from danswer.connectors.models import Document from danswer.connectors.models import Section -from danswer.file_processing.extract_file_text import 
pdf_to_text +from danswer.file_processing.extract_file_text import read_pdf_file from danswer.file_processing.html_utils import web_html_cleanup from danswer.utils.logger import setup_logger from danswer.utils.sitemap import list_pages_for_site @@ -85,7 +87,8 @@ def check_internet_connection(url: str) -> None: response = requests.get(url, timeout=3) response.raise_for_status() except requests.exceptions.HTTPError as e: - status_code = e.response.status_code + # Extract status code from the response, defaulting to -1 if response is None + status_code = e.response.status_code if e.response is not None else -1 error_msg = { 400: "Bad Request", 401: "Unauthorized", @@ -202,6 +205,15 @@ def _read_urls_file(location: str) -> list[str]: return urls +def _get_datetime_from_last_modified_header(last_modified: str) -> datetime | None: + try: + return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace( + tzinfo=timezone.utc + ) + except (ValueError, TypeError): + return None + + class WebConnector(LoadConnector): def __init__( self, @@ -284,7 +296,10 @@ def load_from_state(self) -> GenerateDocumentsOutput: if current_url.split(".")[-1] == "pdf": # PDF files are not checked for links response = requests.get(current_url) - page_text = pdf_to_text(file=io.BytesIO(response.content)) + page_text, metadata = read_pdf_file( + file=io.BytesIO(response.content) + ) + last_modified = response.headers.get("Last-Modified") doc_batch.append( Document( @@ -292,13 +307,23 @@ def load_from_state(self) -> GenerateDocumentsOutput: sections=[Section(link=current_url, text=page_text)], source=DocumentSource.WEB, semantic_identifier=current_url.split("/")[-1], - metadata={}, + metadata=metadata, + doc_updated_at=_get_datetime_from_last_modified_header( + last_modified + ) + if last_modified + else None, ) ) continue page = context.new_page() page_response = page.goto(current_url) + last_modified = ( + page_response.header_value("Last-Modified") + if page_response + else None + ) final_page = page.url if final_page != current_url: logger.info(f"Redirected to {final_page}") @@ -334,6 +359,11 @@ def load_from_state(self) -> GenerateDocumentsOutput: source=DocumentSource.WEB, semantic_identifier=parsed_html.title or current_url, metadata={}, + doc_updated_at=_get_datetime_from_last_modified_header( + last_modified + ) + if last_modified + else None, ) ) diff --git a/backend/danswer/connectors/zendesk/connector.py b/backend/danswer/connectors/zendesk/connector.py index b6d4220b9ce..f85f2efff57 100644 --- a/backend/danswer/connectors/zendesk/connector.py +++ b/backend/danswer/connectors/zendesk/connector.py @@ -3,6 +3,7 @@ import requests from retry import retry from zenpy import Zenpy # type: ignore +from zenpy.lib.api_objects import Ticket # type: ignore from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore from danswer.configs.app_configs import INDEX_BATCH_SIZE @@ -59,10 +60,15 @@ def __init__(self) -> None: class ZendeskConnector(LoadConnector, PollConnector): - def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None: + def __init__( + self, + batch_size: int = INDEX_BATCH_SIZE, + content_type: str = "articles", + ) -> None: self.batch_size = batch_size self.zendesk_client: Zenpy | None = None self.content_tags: dict[str, str] = {} + self.content_type = content_type @retry(tries=3, delay=2, backoff=2) def _set_content_tags( @@ -122,16 +128,86 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None def load_from_state(self) -> 
GenerateDocumentsOutput: return self.poll_source(None, None) + def _ticket_to_document(self, ticket: Ticket) -> Document: + if self.zendesk_client is None: + raise ZendeskClientNotSetUpError() + + owner = None + if ticket.requester and ticket.requester.name and ticket.requester.email: + owner = [ + BasicExpertInfo( + display_name=ticket.requester.name, email=ticket.requester.email + ) + ] + update_time = time_str_to_utc(ticket.updated_at) if ticket.updated_at else None + + metadata: dict[str, str | list[str]] = {} + if ticket.status is not None: + metadata["status"] = ticket.status + if ticket.priority is not None: + metadata["priority"] = ticket.priority + if ticket.tags: + metadata["tags"] = ticket.tags + if ticket.type is not None: + metadata["ticket_type"] = ticket.type + + # Fetch comments for the ticket + comments = self.zendesk_client.tickets.comments(ticket=ticket) + + # Combine all comments into a single text + comments_text = "\n\n".join( + [ + f"Comment{f' by {comment.author.name}' if comment.author and comment.author.name else ''}" + f"{f' at {comment.created_at}' if comment.created_at else ''}:\n{comment.body}" + for comment in comments + if comment.body + ] + ) + + # Combine ticket description and comments + description = ( + ticket.description + if hasattr(ticket, "description") and ticket.description + else "" + ) + full_text = f"Ticket Description:\n{description}\n\nComments:\n{comments_text}" + + # Extract subdomain from ticket.url + subdomain = ticket.url.split("//")[1].split(".zendesk.com")[0] + + # Build the html url for the ticket + ticket_url = f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.id}" + + return Document( + id=f"zendesk_ticket_{ticket.id}", + sections=[Section(link=ticket_url, text=full_text)], + source=DocumentSource.ZENDESK, + semantic_identifier=f"Ticket #{ticket.id}: {ticket.subject or 'No Subject'}", + doc_updated_at=update_time, + primary_owners=owner, + metadata=metadata, + ) + def poll_source( self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None ) -> GenerateDocumentsOutput: if self.zendesk_client is None: raise ZendeskClientNotSetUpError() + if self.content_type == "articles": + yield from self._poll_articles(start) + elif self.content_type == "tickets": + yield from self._poll_tickets(start) + else: + raise ValueError(f"Unsupported content_type: {self.content_type}") + + def _poll_articles( + self, start: SecondsSinceUnixEpoch | None + ) -> GenerateDocumentsOutput: articles = ( - self.zendesk_client.help_center.articles(cursor_pagination=True) + self.zendesk_client.help_center.articles(cursor_pagination=True) # type: ignore if start is None - else self.zendesk_client.help_center.articles.incremental( + else self.zendesk_client.help_center.articles.incremental( # type: ignore start_time=int(start) ) ) @@ -155,9 +231,43 @@ def poll_source( if doc_batch: yield doc_batch + def _poll_tickets( + self, start: SecondsSinceUnixEpoch | None + ) -> GenerateDocumentsOutput: + if self.zendesk_client is None: + raise ZendeskClientNotSetUpError() + + ticket_generator = self.zendesk_client.tickets.incremental(start_time=start) + + while True: + doc_batch = [] + for _ in range(self.batch_size): + try: + ticket = next(ticket_generator) + + # Check if the ticket status is deleted and skip it if so + if ticket.status == "deleted": + continue + + doc_batch.append(self._ticket_to_document(ticket)) + + if len(doc_batch) >= self.batch_size: + yield doc_batch + doc_batch.clear() + + except StopIteration: + # No more tickets to process + 
if doc_batch: + yield doc_batch + return + + if doc_batch: + yield doc_batch + if __name__ == "__main__": import os + import time connector = ZendeskConnector() diff --git a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py index 732be8df9db..9e1c171ee4f 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py @@ -11,6 +11,7 @@ from danswer.configs.constants import MessageType from danswer.configs.constants import SearchFeedbackType from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI +from danswer.connectors.slack.utils import expert_info_from_slack_id from danswer.connectors.slack.utils import make_slack_api_rate_limited from danswer.danswerbot.slack.blocks import build_follow_up_resolved_blocks from danswer.danswerbot.slack.blocks import get_document_feedback_blocks @@ -87,6 +88,8 @@ def handle_generate_answer_button( message_ts = req.payload["message"]["ts"] thread_ts = req.payload["container"]["thread_ts"] user_id = req.payload["user"]["id"] + expert_info = expert_info_from_slack_id(user_id, client.web_client, user_cache={}) + email = expert_info.email if expert_info else None if not thread_ts: raise ValueError("Missing thread_ts in the payload") @@ -125,6 +128,7 @@ def handle_generate_answer_button( msg_to_respond=cast(str, message_ts or thread_ts), thread_to_respond=cast(str, thread_ts or message_ts), sender=user_id or None, + email=email or None, bypass_filters=True, is_bot_msg=False, is_bot_dm=False, diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index 2edbd973553..cce45331ee7 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -21,6 +21,7 @@ from danswer.danswerbot.slack.utils import update_emote_react from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import SlackBotConfig +from danswer.db.users import add_non_web_user_if_not_exists from danswer.utils.logger import setup_logger from shared_configs.configs import SLACK_CHANNEL_ID @@ -209,6 +210,9 @@ def handle_message( logger.error(f"Was not able to react to user message due to: {e}") with Session(get_sqlalchemy_engine()) as db_session: + if message_info.email: + add_non_web_user_if_not_exists(message_info.email, db_session) + # first check if we need to respond with a standard answer used_standard_answer = handle_standard_answers( message_info=message_info, diff --git a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py index e3a78917a76..09ea4e05332 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py @@ -38,6 +38,7 @@ from danswer.db.models import SlackBotResponseType from danswer.db.persona import fetch_persona_by_id from danswer.db.search_settings import get_current_search_settings +from danswer.db.users import get_user_by_email from danswer.llm.answering.prompts.citations_prompt import ( compute_max_document_tokens_for_persona, ) @@ -99,6 +100,12 @@ def handle_regular_answer( messages = message_info.thread_messages message_ts_to_respond_to = message_info.msg_to_respond is_bot_msg = message_info.is_bot_msg + user = None + if message_info.is_bot_dm: + if message_info.email: + 
engine = get_sqlalchemy_engine() + with Session(engine) as db_session: + user = get_user_by_email(message_info.email, db_session) document_set_names: list[str] | None = None persona = slack_bot_config.persona if slack_bot_config else None @@ -185,7 +192,7 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non # This also handles creating the query event in postgres answer = get_search_answer( query_req=new_message_request, - user=None, + user=user, max_document_tokens=max_document_tokens, max_history_tokens=max_history_tokens, db_session=db_session, diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py index c59f4caf1aa..63f8bcfcd9c 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -13,6 +13,7 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER +from danswer.connectors.slack.utils import expert_info_from_slack_id from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID @@ -256,6 +257,11 @@ def build_request_details( tagged = event.get("type") == "app_mention" message_ts = event.get("ts") thread_ts = event.get("thread_ts") + sender = event.get("user") or None + expert_info = expert_info_from_slack_id( + sender, client.web_client, user_cache={} + ) + email = expert_info.email if expert_info else None msg = remove_danswer_bot_tag(msg, client=client.web_client) @@ -286,7 +292,8 @@ def build_request_details( channel_to_respond=channel, msg_to_respond=cast(str, message_ts or thread_ts), thread_to_respond=cast(str, thread_ts or message_ts), - sender=event.get("user") or None, + sender=sender, + email=email, bypass_filters=tagged, is_bot_msg=False, is_bot_dm=event.get("channel_type") == "im", @@ -296,6 +303,10 @@ def build_request_details( channel = req.payload["channel_id"] msg = req.payload["text"] sender = req.payload["user_id"] + expert_info = expert_info_from_slack_id( + sender, client.web_client, user_cache={} + ) + email = expert_info.email if expert_info else None single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER) @@ -305,6 +316,7 @@ def build_request_details( msg_to_respond=None, thread_to_respond=None, sender=sender, + email=email, bypass_filters=True, is_bot_msg=True, is_bot_dm=False, diff --git a/backend/danswer/danswerbot/slack/models.py b/backend/danswer/danswerbot/slack/models.py index e4521a759a7..6394eab562d 100644 --- a/backend/danswer/danswerbot/slack/models.py +++ b/backend/danswer/danswerbot/slack/models.py @@ -9,6 +9,7 @@ class SlackMessageInfo(BaseModel): msg_to_respond: str | None thread_to_respond: str | None sender: str | None + email: str | None bypass_filters: bool # User has tagged @DanswerBot is_bot_msg: bool # User is using /DanswerBot is_bot_dm: bool # User is direct messaging to DanswerBot diff --git a/backend/danswer/db/auth.py b/backend/danswer/db/auth.py index 7710232d01f..9b54e82cc1f 100644 --- a/backend/danswer/db/auth.py +++ b/backend/danswer/db/auth.py @@ -28,7 +28,7 @@ def get_default_admin_user_emails() -> list[str]: get_default_admin_user_emails_fn: Callable[ [], list[str] ] = fetch_versioned_implementation_with_fallback( - 
"danswer.auth.users", "get_default_admin_user_emails_", lambda: [] + "danswer.auth.users", "get_default_admin_user_emails_", lambda: list[str]() ) return get_default_admin_user_emails_fn() diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 3cb991dd43b..8485bb4f0ae 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -3,7 +3,6 @@ from datetime import timedelta from uuid import UUID -from sqlalchemy import and_ from sqlalchemy import delete from sqlalchemy import desc from sqlalchemy import func @@ -87,29 +86,57 @@ def get_chat_sessions_by_slack_thread_id( return db_session.scalars(stmt).all() -def get_first_messages_for_chat_sessions( - chat_session_ids: list[int], db_session: Session +def get_valid_messages_from_query_sessions( + chat_session_ids: list[int], + db_session: Session, ) -> dict[int, str]: - subquery = ( - select(ChatMessage.chat_session_id, func.min(ChatMessage.id).label("min_id")) + user_message_subquery = ( + select( + ChatMessage.chat_session_id, func.min(ChatMessage.id).label("user_msg_id") + ) .where( - and_( - ChatMessage.chat_session_id.in_(chat_session_ids), - ChatMessage.message_type == MessageType.USER, # Select USER messages - ) + ChatMessage.chat_session_id.in_(chat_session_ids), + ChatMessage.message_type == MessageType.USER, + ) + .group_by(ChatMessage.chat_session_id) + .subquery() + ) + + assistant_message_subquery = ( + select( + ChatMessage.chat_session_id, + func.min(ChatMessage.id).label("assistant_msg_id"), + ) + .where( + ChatMessage.chat_session_id.in_(chat_session_ids), + ChatMessage.message_type == MessageType.ASSISTANT, ) .group_by(ChatMessage.chat_session_id) .subquery() ) - query = select(ChatMessage.chat_session_id, ChatMessage.message).join( - subquery, - (ChatMessage.chat_session_id == subquery.c.chat_session_id) - & (ChatMessage.id == subquery.c.min_id), + query = ( + select(ChatMessage.chat_session_id, ChatMessage.message) + .join( + user_message_subquery, + ChatMessage.chat_session_id == user_message_subquery.c.chat_session_id, + ) + .join( + assistant_message_subquery, + ChatMessage.chat_session_id == assistant_message_subquery.c.chat_session_id, + ) + .join( + ChatMessage__SearchDoc, + ChatMessage__SearchDoc.chat_message_id + == assistant_message_subquery.c.assistant_msg_id, + ) + .where(ChatMessage.id == user_message_subquery.c.user_msg_id) ) first_messages = db_session.execute(query).all() - return dict([(row.chat_session_id, row.message) for row in first_messages]) + logger.info(f"Retrieved {len(first_messages)} first messages with documents") + + return {row.chat_session_id: row.message for row in first_messages} def get_chat_sessions_by_user( @@ -253,6 +280,13 @@ def delete_chat_session( db_session: Session, hard_delete: bool = HARD_DELETE_CHATS, ) -> None: + chat_session = get_chat_session_by_id( + chat_session_id=chat_session_id, user_id=user_id, db_session=db_session + ) + + if chat_session.deleted: + raise ValueError("Cannot delete an already deleted chat session") + if hard_delete: delete_messages_and_files_from_chat_session(chat_session_id, db_session) db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id)) diff --git a/backend/danswer/db/connector_credential_pair.py b/backend/danswer/db/connector_credential_pair.py index a6848232caf..f35aed9186c 100644 --- a/backend/danswer/db/connector_credential_pair.py +++ b/backend/danswer/db/connector_credential_pair.py @@ -159,6 +159,7 @@ def get_connector_credential_pair_from_id( def get_last_successful_attempt_time( 
connector_id: int, credential_id: int, + earliest_index: float, search_settings: SearchSettings, db_session: Session, ) -> float: @@ -172,7 +173,7 @@ def get_last_successful_attempt_time( connector_credential_pair is None or connector_credential_pair.last_successful_index_time is None ): - return 0.0 + return earliest_index return connector_credential_pair.last_successful_index_time.timestamp() @@ -192,11 +193,9 @@ def get_last_successful_attempt_time( .order_by(IndexAttempt.time_started.desc()) .first() ) + if not attempt or not attempt.time_started: - connector = fetch_connector_by_id(connector_id, db_session) - if connector and connector.indexing_start: - return connector.indexing_start.timestamp() - return 0.0 + return earliest_index return attempt.time_started.timestamp() @@ -335,9 +334,13 @@ def add_credential_to_connector( raise HTTPException(status_code=404, detail="Connector does not exist") if credential is None: + error_msg = ( + f"Credential {credential_id} does not exist or does not belong to user" + ) + logger.error(error_msg) raise HTTPException( status_code=401, - detail="Credential does not exist or does not belong to user", + detail=error_msg, ) existing_association = ( @@ -351,7 +354,7 @@ def add_credential_to_connector( if existing_association is not None: return StatusResponse( success=False, - message=f"Connector already has Credential {credential_id}", + message=f"Connector {connector_id} already has Credential {credential_id}", data=connector_id, ) @@ -375,8 +378,8 @@ def add_credential_to_connector( db_session.commit() return StatusResponse( - success=False, - message=f"Connector already has Credential {credential_id}", + success=True, + message=f"Creating new association between Connector {connector_id} and Credential {credential_id}", data=association.id, ) diff --git a/backend/danswer/db/document_set.py b/backend/danswer/db/document_set.py index 2de61a491f9..c2900593835 100644 --- a/backend/danswer/db/document_set.py +++ b/backend/danswer/db/document_set.py @@ -524,37 +524,66 @@ def fetch_document_sets_for_documents( db_session: Session, ) -> Sequence[tuple[str, list[str]]]: """Gives back a list of (document_id, list[document_set_names]) tuples""" + + """Building subqueries""" + # NOTE: have to build these subqueries first in order to guarantee that we get one + # returned row for each specified document_id. Basically, we want to do the filters first, + # then the outer joins. 
+ + # don't include CC pairs that are being deleted + # NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them + # as we can assume their document sets are no longer relevant + valid_cc_pairs_subquery = aliased( + ConnectorCredentialPair, + select(ConnectorCredentialPair) + .where( + ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING + ) # noqa: E712 + .subquery(), + ) + + valid_document_set__cc_pairs_subquery = aliased( + DocumentSet__ConnectorCredentialPair, + select(DocumentSet__ConnectorCredentialPair) + .where(DocumentSet__ConnectorCredentialPair.is_current == True) # noqa: E712 + .subquery(), + ) + """End building subqueries""" + stmt = ( - select(Document.id, func.array_agg(DocumentSetDBModel.name)) - .join( - DocumentSet__ConnectorCredentialPair, - DocumentSetDBModel.id - == DocumentSet__ConnectorCredentialPair.document_set_id, - ) - .join( - ConnectorCredentialPair, - ConnectorCredentialPair.id - == DocumentSet__ConnectorCredentialPair.connector_credential_pair_id, + select( + Document.id, + func.coalesce( + func.array_remove(func.array_agg(DocumentSetDBModel.name), None), [] + ).label("document_set_names"), ) - .join( + # Here we select document sets by relation: + # Document -> DocumentByConnectorCredentialPair -> ConnectorCredentialPair -> + # DocumentSet__ConnectorCredentialPair -> DocumentSet + .outerjoin( DocumentByConnectorCredentialPair, + Document.id == DocumentByConnectorCredentialPair.id, + ) + .outerjoin( + valid_cc_pairs_subquery, and_( DocumentByConnectorCredentialPair.connector_id - == ConnectorCredentialPair.connector_id, + == valid_cc_pairs_subquery.connector_id, DocumentByConnectorCredentialPair.credential_id - == ConnectorCredentialPair.credential_id, + == valid_cc_pairs_subquery.credential_id, ), ) - .join( - Document, - Document.id == DocumentByConnectorCredentialPair.id, + .outerjoin( + valid_document_set__cc_pairs_subquery, + valid_cc_pairs_subquery.id + == valid_document_set__cc_pairs_subquery.connector_credential_pair_id, + ) + .outerjoin( + DocumentSetDBModel, + DocumentSetDBModel.id + == valid_document_set__cc_pairs_subquery.document_set_id, ) .where(Document.id.in_(document_ids)) - # don't include CC pairs that are being deleted - # NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them - # as we can assume their document sets are no longer relevant - .where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING) - .where(DocumentSet__ConnectorCredentialPair.is_current == True) # noqa: E712 .group_by(Document.id) ) return db_session.execute(stmt).all() # type: ignore diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py index 0932d500bbd..056e4ce968b 100644 --- a/backend/danswer/db/index_attempt.py +++ b/backend/danswer/db/index_attempt.py @@ -211,12 +211,12 @@ def get_latest_index_attempts( return db_session.execute(stmt).scalars().all() -def get_index_attempts_for_connector( +def count_index_attempts_for_connector( db_session: Session, connector_id: int, only_current: bool = True, disinclude_finished: bool = False, -) -> Sequence[IndexAttempt]: +) -> int: stmt = ( select(IndexAttempt) .join(ConnectorCredentialPair) @@ -232,23 +232,60 @@ def get_index_attempts_for_connector( stmt = stmt.join(SearchSettings).where( SearchSettings.status == IndexModelStatus.PRESENT ) + # Count total items for pagination + count_stmt = stmt.with_only_columns(func.count()).order_by(None) + total_count = 
db_session.execute(count_stmt).scalar_one() + return total_count - stmt = stmt.order_by(IndexAttempt.time_created.desc()) - return db_session.execute(stmt).scalars().all() +def get_paginated_index_attempts_for_cc_pair_id( + db_session: Session, + connector_id: int, + page: int, + page_size: int, + only_current: bool = True, + disinclude_finished: bool = False, +) -> list[IndexAttempt]: + stmt = ( + select(IndexAttempt) + .join(ConnectorCredentialPair) + .where(ConnectorCredentialPair.connector_id == connector_id) + ) + if disinclude_finished: + stmt = stmt.where( + IndexAttempt.status.in_( + [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] + ) + ) + if only_current: + stmt = stmt.join(SearchSettings).where( + SearchSettings.status == IndexModelStatus.PRESENT + ) + + stmt = stmt.order_by(IndexAttempt.time_started.desc()) -def get_latest_finished_index_attempt_for_cc_pair( + # Apply pagination + stmt = stmt.offset((page - 1) * page_size).limit(page_size) + + return list(db_session.execute(stmt).scalars().all()) + + +def get_latest_index_attempt_for_cc_pair_id( + db_session: Session, connector_credential_pair_id: int, secondary_index: bool, - db_session: Session, + only_finished: bool = True, ) -> IndexAttempt | None: - stmt = select(IndexAttempt).distinct() + stmt = select(IndexAttempt) stmt = stmt.where( IndexAttempt.connector_credential_pair_id == connector_credential_pair_id, - IndexAttempt.status.not_in( - [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] - ), ) + if only_finished: + stmt = stmt.where( + IndexAttempt.status.not_in( + [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] + ), + ) if secondary_index: stmt = stmt.join(SearchSettings).where( SearchSettings.status == IndexModelStatus.FUTURE @@ -295,14 +332,11 @@ def get_index_attempts_for_cc_pair( def delete_index_attempts( - connector_id: int, - credential_id: int, + cc_pair_id: int, db_session: Session, ) -> None: stmt = delete(IndexAttempt).where( - IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id, - ConnectorCredentialPair.connector_id == connector_id, - ConnectorCredentialPair.credential_id == credential_id, + IndexAttempt.connector_credential_pair_id == cc_pair_id, ) db_session.execute(stmt) diff --git a/backend/danswer/db/llm.py b/backend/danswer/db/llm.py index 152cb130573..a68beadc084 100644 --- a/backend/danswer/db/llm.py +++ b/backend/danswer/db/llm.py @@ -6,6 +6,7 @@ from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel from danswer.db.models import LLMProvider as LLMProviderModel from danswer.db.models import LLMProvider__UserGroup +from danswer.db.models import SearchSettings from danswer.db.models import User from danswer.db.models import User__UserGroup from danswer.server.manage.embedding.models import CloudEmbeddingProvider @@ -50,6 +51,7 @@ def upsert_cloud_embedding_provider( setattr(existing_provider, key, value) else: new_provider = CloudEmbeddingProviderModel(**provider.model_dump()) + db_session.add(new_provider) existing_provider = new_provider db_session.commit() @@ -58,7 +60,7 @@ def upsert_cloud_embedding_provider( def upsert_llm_provider( - db_session: Session, llm_provider: LLMProviderUpsertRequest + llm_provider: LLMProviderUpsertRequest, db_session: Session ) -> FullLLMProvider: existing_llm_provider = db_session.scalar( select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name) @@ -157,12 +159,19 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | def 
remove_embedding_provider( db_session: Session, provider_type: EmbeddingProvider ) -> None: + db_session.execute( + delete(SearchSettings).where(SearchSettings.provider_type == provider_type) + ) + + # Delete the embedding provider db_session.execute( delete(CloudEmbeddingProviderModel).where( CloudEmbeddingProviderModel.provider_type == provider_type ) ) + db_session.commit() + def remove_llm_provider(db_session: Session, provider_id: int) -> None: # Remove LLMProvider's dependent relationships @@ -178,7 +187,7 @@ def remove_llm_provider(db_session: Session, provider_id: int) -> None: db_session.commit() -def update_default_provider(db_session: Session, provider_id: int) -> None: +def update_default_provider(provider_id: int, db_session: Session) -> None: new_default = db_session.scalar( select(LLMProviderModel).where(LLMProviderModel.id == provider_id) ) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 3cdec323961..ffc12323a52 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -61,7 +61,7 @@ class Base(DeclarativeBase): - pass + __abstract__ = True class EncryptedString(TypeDecorator): @@ -157,6 +157,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base): notifications: Mapped[list["Notification"]] = relationship( "Notification", back_populates="user" ) + # Whether the user has logged in via web. False if user has only used Danswer through Slack bot + has_web_login: Mapped[bool] = mapped_column(Boolean, default=True) class InputPrompt(Base): @@ -448,7 +450,7 @@ class Document(Base): ) tags = relationship( "Tag", - secondary="document__tag", + secondary=Document__Tag.__table__, back_populates="documents", ) @@ -465,7 +467,7 @@ class Tag(Base): documents = relationship( "Document", - secondary="document__tag", + secondary=Document__Tag.__table__, back_populates="tags", ) @@ -576,6 +578,8 @@ class SearchSettings(Base): Enum(RerankerProvider, native_enum=False), nullable=True ) rerank_api_key: Mapped[str | None] = mapped_column(String, nullable=True) + rerank_api_url: Mapped[str | None] = mapped_column(String, nullable=True) + num_rerank: Mapped[int] = mapped_column(Integer, default=NUM_POSTPROCESSED_RESULTS) cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship( @@ -607,6 +611,10 @@ def __repr__(self) -> str: return f"" + @property + def api_url(self) -> str | None: + return self.cloud_provider.api_url if self.cloud_provider is not None else None + @property def api_key(self) -> str | None: return self.cloud_provider.api_key if self.cloud_provider is not None else None @@ -671,7 +679,11 @@ class IndexAttempt(Base): "SearchSettings", back_populates="index_attempts" ) - error_rows = relationship("IndexAttemptError", back_populates="index_attempt") + error_rows = relationship( + "IndexAttemptError", + back_populates="index_attempt", + cascade="all, delete-orphan", + ) __table_args__ = ( Index( @@ -806,7 +818,7 @@ class SearchDoc(Base): chat_messages = relationship( "ChatMessage", - secondary="chat_message__search_doc", + secondary=ChatMessage__SearchDoc.__table__, back_populates="search_docs", ) @@ -949,7 +961,7 @@ class ChatMessage(Base): ) search_docs: Mapped[list["SearchDoc"]] = relationship( "SearchDoc", - secondary="chat_message__search_doc", + secondary=ChatMessage__SearchDoc.__table__, back_populates="chat_messages", ) # NOTE: Should always be attached to the `assistant` message. 
@@ -1085,6 +1097,7 @@ class CloudEmbeddingProvider(Base): provider_type: Mapped[EmbeddingProvider] = mapped_column( Enum(EmbeddingProvider), primary_key=True ) + api_url: Mapped[str | None] = mapped_column(String, nullable=True) api_key: Mapped[str | None] = mapped_column(EncryptedString()) search_settings: Mapped[list["SearchSettings"]] = relationship( "SearchSettings", @@ -1400,7 +1413,7 @@ class TaskQueueState(Base): __tablename__ = "task_queue_jobs" id: Mapped[int] = mapped_column(primary_key=True) - # Celery task id + # Celery task id. currently only for readability/diagnostics task_id: Mapped[str] = mapped_column(String) # For any job type, this would be the same task_name: Mapped[str] = mapped_column(String) diff --git a/backend/danswer/db/search_settings.py b/backend/danswer/db/search_settings.py index 1d0c218e10a..01f458493f7 100644 --- a/backend/danswer/db/search_settings.py +++ b/backend/danswer/db/search_settings.py @@ -1,3 +1,5 @@ +from sqlalchemy import and_ +from sqlalchemy import delete from sqlalchemy import select from sqlalchemy.orm import Session @@ -13,6 +15,7 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.llm import fetch_embedding_provider from danswer.db.models import CloudEmbeddingProvider +from danswer.db.models import IndexAttempt from danswer.db.models import IndexModelStatus from danswer.db.models import SearchSettings from danswer.indexing.models import IndexingSetting @@ -89,6 +92,30 @@ def get_current_db_embedding_provider( return current_embedding_provider +def delete_search_settings(db_session: Session, search_settings_id: int) -> None: + current_settings = get_current_search_settings(db_session) + + if current_settings.id == search_settings_id: + raise ValueError("Cannot delete currently active search settings") + + # First, delete associated index attempts + index_attempts_query = delete(IndexAttempt).where( + IndexAttempt.search_settings_id == search_settings_id + ) + db_session.execute(index_attempts_query) + + # Then, delete the search settings + search_settings_query = delete(SearchSettings).where( + and_( + SearchSettings.id == search_settings_id, + SearchSettings.status != IndexModelStatus.PRESENT, + ) + ) + + db_session.execute(search_settings_query) + db_session.commit() + + def get_current_search_settings(db_session: Session) -> SearchSettings: query = ( select(SearchSettings) @@ -115,6 +142,13 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None: return latest_settings +def get_all_search_settings(db_session: Session) -> list[SearchSettings]: + query = select(SearchSettings).order_by(SearchSettings.id.desc()) + result = db_session.execute(query) + all_settings = result.scalars().all() + return list(all_settings) + + def get_multilingual_expansion(db_session: Session | None = None) -> list[str]: if db_session is None: with Session(get_sqlalchemy_engine()) as db_session: @@ -234,6 +268,7 @@ def get_old_default_embedding_model() -> IndexingSetting: passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""), index_name="danswer_chunk", multipass_indexing=False, + api_url=None, ) @@ -246,4 +281,5 @@ def get_new_default_embedding_model() -> IndexingSetting: passage_prefix=ASYM_PASSAGE_PREFIX, index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}", multipass_indexing=False, + api_url=None, ) diff --git a/backend/danswer/db/tasks.py b/backend/danswer/db/tasks.py index 23a7edc9882..a7aec90d260 100644 --- a/backend/danswer/db/tasks.py +++ b/backend/danswer/db/tasks.py @@ -44,12 
+44,11 @@ def get_latest_task_by_type( def register_task( - task_id: str, task_name: str, db_session: Session, ) -> TaskQueueState: new_task = TaskQueueState( - task_id=task_id, task_name=task_name, status=TaskStatus.PENDING + task_id="", task_name=task_name, status=TaskStatus.PENDING ) db_session.add(new_task) diff --git a/backend/danswer/db/users.py b/backend/danswer/db/users.py index d824ccfd921..61ba6e475fe 100644 --- a/backend/danswer/db/users.py +++ b/backend/danswer/db/users.py @@ -1,9 +1,11 @@ from collections.abc import Sequence from uuid import UUID +from fastapi_users.password import PasswordHelper from sqlalchemy import select from sqlalchemy.orm import Session +from danswer.auth.schemas import UserRole from danswer.db.models import User @@ -30,3 +32,22 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None: user = db_session.query(User).filter(User.id == user_id).first() # type: ignore return user + + +def add_non_web_user_if_not_exists(email: str, db_session: Session) -> User: + user = get_user_by_email(email, db_session) + if user is not None: + return user + + fastapi_users_pw_helper = PasswordHelper() + password = fastapi_users_pw_helper.generate() + hashed_pass = fastapi_users_pw_helper.hash(password) + user = User( + email=email, + hashed_password=hashed_pass, + has_web_login=False, + role=UserRole.BASIC, + ) + db_session.add(user) + db_session.commit() + return user diff --git a/backend/danswer/document_index/vespa/chunk_retrieval.py b/backend/danswer/document_index/vespa/chunk_retrieval.py index 6a7427630b8..e4b2ad83ce2 100644 --- a/backend/danswer/document_index/vespa/chunk_retrieval.py +++ b/backend/danswer/document_index/vespa/chunk_retrieval.py @@ -30,6 +30,7 @@ from danswer.document_index.vespa_constants import HIDDEN from danswer.document_index.vespa_constants import LARGE_CHUNK_REFERENCE_IDS from danswer.document_index.vespa_constants import MAX_ID_SEARCH_QUERY_SIZE +from danswer.document_index.vespa_constants import MAX_OR_CONDITIONS from danswer.document_index.vespa_constants import METADATA from danswer.document_index.vespa_constants import METADATA_SUFFIX from danswer.document_index.vespa_constants import PRIMARY_OWNERS @@ -292,12 +293,11 @@ def query_vespa( if LOG_VESPA_TIMING_INFORMATION else {}, ) - - response = requests.post( - SEARCH_ENDPOINT, - json=params, - ) try: + response = requests.post( + SEARCH_ENDPOINT, + json=params, + ) response.raise_for_status() except requests.HTTPError as e: request_info = f"Headers: {response.request.headers}\nPayload: {params}" @@ -319,6 +319,12 @@ def query_vespa( logger.debug("Vespa timing info: %s", response_json.get("timing")) hits = response_json["root"].get("children", []) + if not hits: + logger.warning( + f"No hits found for YQL Query: {query_params.get('yql', 'No YQL Query')}" + ) + logger.debug(f"Vespa Response: {response.text}") + for hit in hits: if hit["fields"].get(CONTENT) is None: identifier = hit["fields"].get("documentid") or hit["id"] @@ -379,7 +385,7 @@ def batch_search_api_retrieval( capped_requests: list[VespaChunkRequest] = [] uncapped_requests: list[VespaChunkRequest] = [] chunk_count = 0 - for request in chunk_requests: + for req_ind, request in enumerate(chunk_requests, start=1): # All requests without a chunk range are uncapped # Uncapped requests are retrieved using the Visit API range = request.range @@ -387,9 +393,10 @@ def batch_search_api_retrieval( uncapped_requests.append(request) continue - # If adding the range to the chunk count is greater than the - # max query 
size, we need to perform a retrieval to avoid hitting the limit - if chunk_count + range > MAX_ID_SEARCH_QUERY_SIZE: + if ( + chunk_count + range > MAX_ID_SEARCH_QUERY_SIZE + or req_ind % MAX_OR_CONDITIONS == 0 + ): retrieved_chunks.extend( _get_chunks_via_batch_search( index_name=index_name, diff --git a/backend/danswer/document_index/vespa_constants.py b/backend/danswer/document_index/vespa_constants.py index 0b8949b4264..07d2f3f74e0 100644 --- a/backend/danswer/document_index/vespa_constants.py +++ b/backend/danswer/document_index/vespa_constants.py @@ -25,6 +25,9 @@ 32 # since Vespa doesn't allow batching of inserts / updates, we use threads ) MAX_ID_SEARCH_QUERY_SIZE = 400 +# Suspect that adding too many "or" conditions will cause Vespa to timeout and return +# an empty list of hits (with no error status and coverage: 0 and degraded) +MAX_OR_CONDITIONS = 10 # up from 500ms for now, since we've seen quite a few timeouts # in the long term, we are looking to improve the performance of Vespa # so that we can bring this back to default diff --git a/backend/danswer/file_processing/extract_file_text.py b/backend/danswer/file_processing/extract_file_text.py index 7143b428714..36df08ac465 100644 --- a/backend/danswer/file_processing/extract_file_text.py +++ b/backend/danswer/file_processing/extract_file_text.py @@ -8,6 +8,7 @@ from email.parser import Parser as EmailParser from pathlib import Path from typing import Any +from typing import Dict from typing import IO import chardet @@ -178,6 +179,17 @@ def read_text_file( def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str: + """Extract text from a PDF file.""" + # Return only the extracted text from read_pdf_file + text, _ = read_pdf_file(file, pdf_pass) + return text + + +def read_pdf_file( + file: IO[Any], + pdf_pass: str | None = None, +) -> tuple[str, dict]: + metadata: Dict[str, Any] = {} try: pdf_reader = PdfReader(file) @@ -189,16 +201,33 @@ def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str: decrypt_success = pdf_reader.decrypt(pdf_pass) != 0 except Exception: logger.error("Unable to decrypt pdf") - else: - logger.warning("No Password available to to decrypt pdf") if not decrypt_success: # By user request, keep files that are unreadable just so they # can be discoverable by title. 
- return "" - - return TEXT_SECTION_SEPARATOR.join( - page.extract_text() for page in pdf_reader.pages + return "", metadata + else: + logger.warning("No Password available to to decrypt pdf") + + # Extract metadata from the PDF, removing leading '/' from keys if present + # This standardizes the metadata keys for consistency + metadata = {} + if pdf_reader.metadata is not None: + for key, value in pdf_reader.metadata.items(): + clean_key = key.lstrip("/") + if isinstance(value, str) and value.strip(): + metadata[clean_key] = value + + elif isinstance(value, list) and all( + isinstance(item, str) for item in value + ): + metadata[clean_key] = ", ".join(value) + + return ( + TEXT_SECTION_SEPARATOR.join( + page.extract_text() for page in pdf_reader.pages + ), + metadata, ) except PdfStreamError: logger.exception("PDF file is not a valid PDF") @@ -207,13 +236,47 @@ def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str: # File is still discoverable by title # but the contents are not included as they cannot be parsed - return "" + return "", metadata def docx_to_text(file: IO[Any]) -> str: + def is_simple_table(table: docx.table.Table) -> bool: + for row in table.rows: + # No omitted cells + if row.grid_cols_before > 0 or row.grid_cols_after > 0: + return False + + # No nested tables + if any(cell.tables for cell in row.cells): + return False + + return True + + def extract_cell_text(cell: docx.table._Cell) -> str: + cell_paragraphs = [para.text.strip() for para in cell.paragraphs] + return " ".join(p for p in cell_paragraphs if p) or "N/A" + + paragraphs = [] doc = docx.Document(file) - full_text = [para.text for para in doc.paragraphs] - return TEXT_SECTION_SEPARATOR.join(full_text) + for item in doc.iter_inner_content(): + if isinstance(item, docx.text.paragraph.Paragraph): + paragraphs.append(item.text) + + elif isinstance(item, docx.table.Table): + if not item.rows or not is_simple_table(item): + continue + + # Every row is a new line, joined with a single newline + table_content = "\n".join( + [ + ",\t".join(extract_cell_text(cell) for cell in row.cells) + for row in item.rows + ] + ) + paragraphs.append(table_content) + + # Docx already has good spacing between paragraphs + return "\n".join(paragraphs) def pptx_to_text(file: IO[Any]) -> str: diff --git a/backend/danswer/file_store/utils.py b/backend/danswer/file_store/utils.py index 4b849f70d96..b71d20bbbb4 100644 --- a/backend/danswer/file_store/utils.py +++ b/backend/danswer/file_store/utils.py @@ -1,4 +1,6 @@ +from collections.abc import Callable from io import BytesIO +from typing import Any from typing import cast from uuid import uuid4 @@ -73,5 +75,7 @@ def save_file_from_url(url: str) -> str: def save_files_from_urls(urls: list[str]) -> list[str]: - funcs = [(save_file_from_url, (url,)) for url in urls] + funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [ + (save_file_from_url, (url,)) for url in urls + ] return run_functions_tuples_in_parallel(funcs) diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py index f7d8f4e7400..d25a0659c62 100644 --- a/backend/danswer/indexing/embedder.py +++ b/backend/danswer/indexing/embedder.py @@ -32,6 +32,7 @@ def __init__( passage_prefix: str | None, provider_type: EmbeddingProvider | None, api_key: str | None, + api_url: str | None, ): self.model_name = model_name self.normalize = normalize @@ -39,6 +40,7 @@ def __init__( self.passage_prefix = passage_prefix self.provider_type = provider_type self.api_key = api_key + self.api_url = 
api_url self.embedding_model = EmbeddingModel( model_name=model_name, @@ -47,6 +49,7 @@ def __init__( normalize=normalize, api_key=api_key, provider_type=provider_type, + api_url=api_url, # The below are globally set, this flow always uses the indexing one server_host=INDEXING_MODEL_SERVER_HOST, server_port=INDEXING_MODEL_SERVER_PORT, @@ -70,9 +73,16 @@ def __init__( passage_prefix: str | None, provider_type: EmbeddingProvider | None = None, api_key: str | None = None, + api_url: str | None = None, ): super().__init__( - model_name, normalize, query_prefix, passage_prefix, provider_type, api_key + model_name, + normalize, + query_prefix, + passage_prefix, + provider_type, + api_key, + api_url, ) @log_function_time() @@ -156,7 +166,7 @@ def embed_chunks( title_embed_dict[title] = title_embedding new_embedded_chunk = IndexChunk( - **chunk.model_dump(), + **chunk.dict(), embeddings=ChunkEmbedding( full_embedding=chunk_embeddings[0], mini_chunk_embeddings=chunk_embeddings[1:], @@ -179,6 +189,7 @@ def from_db_search_settings( passage_prefix=search_settings.passage_prefix, provider_type=search_settings.provider_type, api_key=search_settings.api_key, + api_url=search_settings.api_url, ) @@ -202,4 +213,5 @@ def get_embedding_model_from_search_settings( passage_prefix=search_settings.passage_prefix, provider_type=search_settings.provider_type, api_key=search_settings.api_key, + api_url=search_settings.api_url, ) diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index 3517b55767d..afe825d11ec 100644 --- a/backend/danswer/indexing/indexing_pipeline.py +++ b/backend/danswer/indexing/indexing_pipeline.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from danswer.access.access import get_access_for_documents +from danswer.access.models import DocumentAccess from danswer.configs.app_configs import ENABLE_MULTIPASS_INDEXING from danswer.configs.app_configs import INDEXING_EXCEPTION_LIMIT from danswer.configs.constants import DEFAULT_BOOST @@ -263,6 +264,8 @@ def index_doc_batch( Note that the documents should already be batched at this point so that it does not inflate the memory requirements""" + no_access = DocumentAccess.build([], [], False) + ctx = index_doc_batch_prepare( document_batch=document_batch, index_attempt_metadata=index_attempt_metadata, @@ -307,7 +310,9 @@ def index_doc_batch( access_aware_chunks = [ DocMetadataAwareIndexChunk.from_index_chunk( index_chunk=chunk, - access=document_id_to_access_info[chunk.source_document.id], + access=document_id_to_access_info.get( + chunk.source_document.id, no_access + ), document_sets=set( document_id_to_document_set.get(chunk.source_document.id, []) ), diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index b23de0eb477..93dc0f7315d 100644 --- a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -95,10 +95,12 @@ def from_index_chunk( class EmbeddingModelDetail(BaseModel): + id: int | None = None model_name: str normalize: bool query_prefix: str | None passage_prefix: str | None + api_url: str | None = None provider_type: EmbeddingProvider | None = None api_key: str | None = None @@ -111,12 +113,14 @@ def from_db_model( search_settings: "SearchSettings", ) -> "EmbeddingModelDetail": return cls( + id=search_settings.id, model_name=search_settings.model_name, normalize=search_settings.normalize, query_prefix=search_settings.query_prefix, passage_prefix=search_settings.passage_prefix, provider_type=search_settings.provider_type, 
api_key=search_settings.api_key, + api_url=search_settings.api_url, ) diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index a664db217af..0a0a1c52afa 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -1,5 +1,6 @@ from collections.abc import Callable from collections.abc import Iterator +from typing import Any from typing import cast from uuid import uuid4 @@ -12,6 +13,8 @@ from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc +from danswer.chat.models import StreamStopInfo +from danswer.chat.models import StreamStopReason from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE from danswer.file_store.utils import InMemoryChatFile from danswer.llm.answering.models import AnswerStyleConfig @@ -35,7 +38,7 @@ from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping from danswer.llm.answering.stream_processing.utils import map_document_id_order from danswer.llm.interfaces import LLM -from danswer.llm.utils import message_generator_to_string_generator +from danswer.llm.interfaces import ToolChoiceOptions from danswer.natural_language_processing.utils import get_tokenizer from danswer.tools.custom.custom_tool_prompt_builder import ( build_user_message_for_custom_tool_for_non_tool_calling_llm, @@ -190,7 +193,9 @@ def _update_prompt_builder_for_search_tool( def _raw_output_for_explicit_tool_calling_llms( self, - ) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]: + ) -> Iterator[ + str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult + ]: prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) tool_call_chunk: AIMessageChunk | None = None @@ -225,6 +230,7 @@ def _raw_output_for_explicit_tool_calling_llms( self.tools, self.force_use_tool ) ] + for message in self.llm.stream( prompt=prompt, tools=final_tool_definitions if final_tool_definitions else None, @@ -242,6 +248,13 @@ def _raw_output_for_explicit_tool_calling_llms( if self.is_cancelled: return yield cast(str, message.content) + if ( + message.additional_kwargs.get("usage_metadata", {}).get("stop") + == "length" + ): + yield StreamStopInfo( + stop_reason=StreamStopReason.CONTEXT_LENGTH + ) if not tool_call_chunk: return # no tool call needed @@ -298,21 +311,41 @@ def _raw_output_for_explicit_tool_calling_llms( yield tool_runner.tool_final_result() prompt = prompt_builder.build(tool_call_summary=tool_call_summary) - for token in message_generator_to_string_generator( - self.llm.stream( - prompt=prompt, - tools=[tool.tool_definition() for tool in self.tools], - ) - ): - if self.is_cancelled: - return - yield token + + yield from self._process_llm_stream( + prompt=prompt, + tools=[tool.tool_definition() for tool in self.tools], + ) return + # This method processes the LLM stream and yields the content or stop information + def _process_llm_stream( + self, + prompt: Any, + tools: list[dict] | None = None, + tool_choice: ToolChoiceOptions | None = None, + ) -> Iterator[str | StreamStopInfo]: + for message in self.llm.stream( + prompt=prompt, tools=tools, tool_choice=tool_choice + ): + if isinstance(message, AIMessageChunk): + if message.content: + if self.is_cancelled: + return StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) + yield cast(str, message.content) + + if ( + message.additional_kwargs.get("usage_metadata", {}).get("stop") + == "length" + ): + yield 
StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH) + def _raw_output_for_non_explicit_tool_calling_llms( self, - ) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]: + ) -> Iterator[ + str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult + ]: prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) chosen_tool_and_args: tuple[Tool, dict] | None = None @@ -387,13 +420,10 @@ def _raw_output_for_non_explicit_tool_calling_llms( ) ) prompt = prompt_builder.build() - for token in message_generator_to_string_generator( - self.llm.stream(prompt=prompt) - ): - if self.is_cancelled: - return - yield token - + yield from self._process_llm_stream( + prompt=prompt, + tools=None, + ) return tool, tool_args = chosen_tool_and_args @@ -447,12 +477,8 @@ def _raw_output_for_non_explicit_tool_calling_llms( yield final prompt = prompt_builder.build() - for token in message_generator_to_string_generator( - self.llm.stream(prompt=prompt) - ): - if self.is_cancelled: - return - yield token + + yield from self._process_llm_stream(prompt=prompt, tools=None) @property def processed_streamed_output(self) -> AnswerStream: @@ -470,7 +496,7 @@ def processed_streamed_output(self) -> AnswerStream: ) def _process_stream( - stream: Iterator[ToolCallKickoff | ToolResponse | str], + stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo], ) -> AnswerStream: message = None @@ -524,13 +550,22 @@ def _process_stream( answer_style_configs=self.answer_style_config, ) + stream_stop_info = None + def _stream() -> Iterator[str]: - if message: - yield cast(str, message) - yield from cast(Iterator[str], stream) + nonlocal stream_stop_info + yield cast(str, message) + for item in stream: + if isinstance(item, StreamStopInfo): + stream_stop_info = item + return + yield cast(str, item) yield from process_answer_stream_fn(_stream()) + if stream_stop_info: + yield stream_stop_info + processed_stream = [] for processed_packet in _process_stream(output_generator): processed_stream.append(processed_packet) diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/llm/answering/stream_processing/citation_processing.py index de80b6f6756..a72fc70a8ff 100644 --- a/backend/danswer/llm/answering/stream_processing/citation_processing.py +++ b/backend/danswer/llm/answering/stream_processing/citation_processing.py @@ -11,7 +11,6 @@ from danswer.prompts.constants import TRIPLE_BACKTICK from danswer.utils.logger import setup_logger - logger = setup_logger() @@ -204,7 +203,9 @@ def extract_citations_from_stream( def build_citation_processor( context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping ) -> StreamProcessor: - def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn: + def stream_processor( + tokens: Iterator[str], + ) -> AnswerQuestionStreamReturn: yield from extract_citations_from_stream( tokens=tokens, context_docs=context_docs, diff --git a/backend/danswer/llm/answering/stream_processing/quotes_processing.py b/backend/danswer/llm/answering/stream_processing/quotes_processing.py index 74f37b85264..501a56b5aa7 100644 --- a/backend/danswer/llm/answering/stream_processing/quotes_processing.py +++ b/backend/danswer/llm/answering/stream_processing/quotes_processing.py @@ -285,7 +285,9 @@ def process_model_tokens( def build_quotes_processor( context_docs: list[LlmDoc], is_json_prompt: bool ) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]: - def stream_processor(tokens: 
Iterator[str]) -> AnswerQuestionStreamReturn: + def stream_processor( + tokens: Iterator[str], + ) -> AnswerQuestionStreamReturn: yield from process_model_tokens( tokens=tokens, context_docs=context_docs, diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py index 359e3239b9d..52bdfc02999 100644 --- a/backend/danswer/llm/chat_llm.py +++ b/backend/danswer/llm/chat_llm.py @@ -25,9 +25,6 @@ from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING -from danswer.configs.model_configs import GEN_AI_API_ENDPOINT -from danswer.configs.model_configs import GEN_AI_API_VERSION -from danswer.configs.model_configs import GEN_AI_LLM_PROVIDER_TYPE from danswer.configs.model_configs import GEN_AI_TEMPERATURE from danswer.llm.interfaces import LLM from danswer.llm.interfaces import LLMConfig @@ -141,7 +138,9 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: def _convert_delta_to_message_chunk( - _dict: dict[str, Any], curr_msg: BaseMessage | None + _dict: dict[str, Any], + curr_msg: BaseMessage | None, + stop_reason: str | None = None, ) -> BaseMessageChunk: """Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk""" role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else None) @@ -166,12 +165,23 @@ def _convert_delta_to_message_chunk( args=tool_call.function.arguments, index=0, # only support a single tool call atm ) + return AIMessageChunk( content=content, - additional_kwargs=additional_kwargs, tool_call_chunks=[tool_call_chunk], + additional_kwargs={ + "usage_metadata": {"stop": stop_reason}, + **additional_kwargs, + }, ) - return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + + return AIMessageChunk( + content=content, + additional_kwargs={ + "usage_metadata": {"stop": stop_reason}, + **additional_kwargs, + }, + ) elif role == "system": return SystemMessageChunk(content=content) elif role == "function": @@ -192,10 +202,10 @@ def __init__( timeout: int, model_provider: str, model_name: str, + api_base: str | None = None, + api_version: str | None = None, max_output_tokens: int | None = None, - api_base: str | None = GEN_AI_API_ENDPOINT, - api_version: str | None = GEN_AI_API_VERSION, - custom_llm_provider: str | None = GEN_AI_LLM_PROVIDER_TYPE, + custom_llm_provider: str | None = None, temperature: float = GEN_AI_TEMPERATURE, custom_config: dict[str, str] | None = None, extra_headers: dict[str, str] | None = None, @@ -209,7 +219,7 @@ def __init__( self._api_version = api_version self._custom_llm_provider = custom_llm_provider - # This can be used to store the maximum output tkoens for this model. + # This can be used to store the maximum output tokens for this model. 
# self._max_output_tokens = ( # max_output_tokens # if max_output_tokens is not None @@ -354,10 +364,16 @@ def _stream_implementation( ) try: for part in response: - if len(part["choices"]) == 0: + if not part["choices"]: continue - delta = part["choices"][0]["delta"] - message_chunk = _convert_delta_to_message_chunk(delta, output) + + choice = part["choices"][0] + message_chunk = _convert_delta_to_message_chunk( + choice["delta"], + output, + stop_reason=choice["finish_reason"], + ) + if output is None: output = message_chunk else: diff --git a/backend/danswer/llm/custom_llm.py b/backend/danswer/llm/custom_llm.py index 967e014a903..4a5ba7857c3 100644 --- a/backend/danswer/llm/custom_llm.py +++ b/backend/danswer/llm/custom_llm.py @@ -7,7 +7,6 @@ from langchain_core.messages import BaseMessage from requests import Timeout -from danswer.configs.model_configs import GEN_AI_API_ENDPOINT from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS from danswer.llm.interfaces import LLM from danswer.llm.interfaces import ToolChoiceOptions @@ -37,7 +36,7 @@ def __init__( # Not used here but you probably want a model server that isn't completely open api_key: str | None, timeout: int, - endpoint: str | None = GEN_AI_API_ENDPOINT, + endpoint: str, max_output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS, ): if not endpoint: diff --git a/backend/danswer/llm/llm_initialization.py b/backend/danswer/llm/llm_initialization.py deleted file mode 100644 index db59b836d7f..00000000000 --- a/backend/danswer/llm/llm_initialization.py +++ /dev/null @@ -1,113 +0,0 @@ -from sqlalchemy.orm import Session - -from danswer.configs.app_configs import DISABLE_GENERATIVE_AI -from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION -from danswer.configs.model_configs import GEN_AI_API_ENDPOINT -from danswer.configs.model_configs import GEN_AI_API_KEY -from danswer.configs.model_configs import GEN_AI_API_VERSION -from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER -from danswer.configs.model_configs import GEN_AI_MODEL_VERSION -from danswer.configs.model_configs import GEN_AI_LLM_PROVIDER_TYPE -from danswer.configs.model_configs import GEN_AI_DISPLAY_NAME -from danswer.db.llm import fetch_existing_llm_providers -from danswer.db.llm import update_default_provider -from danswer.db.llm import upsert_llm_provider -from danswer.llm.llm_provider_options import AZURE_PROVIDER_NAME -from danswer.llm.llm_provider_options import BEDROCK_PROVIDER_NAME -from danswer.llm.llm_provider_options import fetch_available_well_known_llms -from danswer.server.manage.llm.models import LLMProviderUpsertRequest -from danswer.utils.logger import setup_logger - - -logger = setup_logger() - - -def load_llm_providers(db_session: Session) -> None: - existing_providers = fetch_existing_llm_providers(db_session) - if existing_providers: - return - - if not GEN_AI_API_KEY or DISABLE_GENERATIVE_AI: - return - - if GEN_AI_MODEL_PROVIDER == "custom": - # Validate that all required env vars are present - for var in ( - GEN_AI_LLM_PROVIDER_TYPE, - GEN_AI_API_ENDPOINT, - GEN_AI_MODEL_VERSION, - GEN_AI_DISPLAY_NAME, - ): - if not var: - logger.error( - "Cannot auto-transition custom LLM provider due to missing env vars." 
- "The following env vars must all be set:" - "GEN_AI_LLM_PROVIDER_TYPE, GEN_AI_API_ENDPOINT, GEN_AI_MODEL_VERSION, GEN_AI_DISPLAY_NAME" - ) - return None - llm_provider_request = LLMProviderUpsertRequest( - name=GEN_AI_DISPLAY_NAME, - provider=GEN_AI_MODEL_PROVIDER, - api_key=GEN_AI_API_KEY, - api_base=GEN_AI_API_ENDPOINT, - api_version=GEN_AI_API_VERSION, - custom_config={}, - default_model_name=GEN_AI_MODEL_VERSION, - fast_default_model_name=FAST_GEN_AI_MODEL_VERSION, - ) - - else: - - well_known_provider_name_to_provider = { - provider.name: provider - for provider in fetch_available_well_known_llms() - if provider.name != BEDROCK_PROVIDER_NAME - } - - if GEN_AI_MODEL_PROVIDER not in well_known_provider_name_to_provider: - logger.error( - f"Cannot auto-transition LLM provider: {GEN_AI_MODEL_PROVIDER}" - ) - return None - - # Azure provider requires custom model names, - # OpenAI / anthropic can just use the defaults - model_names = ( - [ - name - for name in [ - GEN_AI_MODEL_VERSION, - FAST_GEN_AI_MODEL_VERSION, - ] - if name - ] - if GEN_AI_MODEL_PROVIDER == AZURE_PROVIDER_NAME - else None - ) - - well_known_provider = well_known_provider_name_to_provider[ - GEN_AI_MODEL_PROVIDER - ] - llm_provider_request = LLMProviderUpsertRequest( - name=well_known_provider.display_name, - provider=GEN_AI_MODEL_PROVIDER, - api_key=GEN_AI_API_KEY, - api_base=GEN_AI_API_ENDPOINT, - api_version=GEN_AI_API_VERSION, - custom_config={}, - default_model_name=( - GEN_AI_MODEL_VERSION - or well_known_provider.default_model - or well_known_provider.llm_names[0] - ), - fast_default_model_name=( - FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model - ), - model_names=model_names, - ) - - llm_provider = upsert_llm_provider(db_session, llm_provider_request) - update_default_provider(db_session, llm_provider.id) - logger.notice( - f"Migrated LLM provider from env variables for provider '{GEN_AI_MODEL_PROVIDER}'" - ) diff --git a/backend/danswer/llm/llm_provider_options.py b/backend/danswer/llm/llm_provider_options.py index 24feeb2f27c..1bcfdf7e506 100644 --- a/backend/danswer/llm/llm_provider_options.py +++ b/backend/danswer/llm/llm_provider_options.py @@ -95,8 +95,8 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]: api_version_required=False, custom_config_keys=[], llm_names=fetch_models_for_provider(ANTHROPIC_PROVIDER_NAME), - default_model="claude-3-opus-20240229", - default_fast_model="claude-3-sonnet-20240229", + default_model="claude-3-5-sonnet-20240620", + default_fast_model="claude-3-5-sonnet-20240620", ), WellKnownLLMProviderDescriptor( name=AZURE_PROVIDER_NAME, @@ -128,8 +128,8 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]: ), ], llm_names=fetch_models_for_provider(BEDROCK_PROVIDER_NAME), - default_model="anthropic.claude-3-sonnet-20240229-v1:0", - default_fast_model="anthropic.claude-3-haiku-20240307-v1:0", + default_model="anthropic.claude-3-5-sonnet-20240620-v1:0", + default_fast_model="anthropic.claude-3-5-sonnet-20240620-v1:0", ), ] diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 82617f3f05b..c367f0aa522 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -32,7 +32,6 @@ from danswer.configs.constants import MessageType from danswer.configs.model_configs import GEN_AI_MAX_TOKENS from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS -from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER from danswer.configs.model_configs import 
GEN_AI_NUM_RESERVED_OUTPUT_TOKENS from danswer.db.models import ChatMessage from danswer.file_store.models import ChatFileType @@ -331,7 +330,7 @@ def test_llm(llm: LLM) -> str | None: def get_llm_max_tokens( model_map: dict, model_name: str, - model_provider: str = GEN_AI_MODEL_PROVIDER, + model_provider: str, ) -> int: """Best effort attempt to get the max tokens for the LLM""" if GEN_AI_MAX_TOKENS: @@ -371,7 +370,7 @@ def get_llm_max_tokens( def get_llm_max_output_tokens( model_map: dict, model_name: str, - model_provider: str = GEN_AI_MODEL_PROVIDER, + model_provider: str, ) -> int: """Best effort attempt to get the max output tokens for the LLM""" try: diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 6652e5d3c39..a00826f11c8 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -1,4 +1,5 @@ import time +import traceback from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from typing import Any @@ -7,7 +8,9 @@ import uvicorn from fastapi import APIRouter from fastapi import FastAPI +from fastapi import HTTPException from fastapi import Request +from fastapi import status from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse @@ -36,6 +39,9 @@ from danswer.configs.constants import KV_REINDEX_KEY from danswer.configs.constants import KV_SEARCH_SETTINGS from danswer.configs.constants import POSTGRES_WEB_APP_NAME +from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION +from danswer.configs.model_configs import GEN_AI_API_KEY +from danswer.configs.model_configs import GEN_AI_MODEL_VERSION from danswer.db.connector import check_connectors_exist from danswer.db.connector import create_initial_default_connector from danswer.db.connector_credential_pair import associate_default_cc_pair @@ -48,6 +54,9 @@ from danswer.db.engine import warm_up_connections from danswer.db.index_attempt import cancel_indexing_attempts_past_model from danswer.db.index_attempt import expire_index_attempts +from danswer.db.llm import fetch_default_provider +from danswer.db.llm import update_default_provider +from danswer.db.llm import upsert_llm_provider from danswer.db.persona import delete_old_default_personas from danswer.db.search_settings import get_current_search_settings from danswer.db.search_settings import get_secondary_search_settings @@ -60,7 +69,6 @@ from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.indexing.models import IndexingSetting -from danswer.llm.llm_initialization import load_llm_providers from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder @@ -91,6 +99,7 @@ from danswer.server.manage.get_state import router as state_router from danswer.server.manage.llm.api import admin_router as llm_admin_router from danswer.server.manage.llm.api import basic_router as llm_router +from danswer.server.manage.llm.models import LLMProviderUpsertRequest from danswer.server.manage.search_settings import router as search_settings_router from danswer.server.manage.slack_bot import router as slack_bot_management_router from danswer.server.manage.standard_answer import router as standard_answer_router @@ -109,7 +118,9 @@ from 
danswer.tools.built_in_tools import auto_add_search_tool_to_personas from danswer.tools.built_in_tools import load_builtin_tools from danswer.tools.built_in_tools import refresh_built_in_tools_cache +from danswer.utils.gpu_utils import gpu_status_request from danswer.utils.logger import setup_logger +from danswer.utils.telemetry import get_or_generate_uuid from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType from danswer.utils.variable_functionality import fetch_versioned_implementation @@ -179,9 +190,6 @@ def setup_postgres(db_session: Session) -> None: logger.notice("Verifying default standard answer category exists.") create_initial_default_standard_answer_category(db_session) - logger.notice("Loading LLM providers from env variables") - load_llm_providers(db_session) - logger.notice("Loading default Prompts and Personas") delete_old_default_personas(db_session) load_chat_yamls() @@ -191,6 +199,58 @@ def setup_postgres(db_session: Session) -> None: refresh_built_in_tools_cache(db_session) auto_add_search_tool_to_personas(db_session) + if GEN_AI_API_KEY and fetch_default_provider(db_session) is None: + # Only for dev flows + logger.notice("Setting up default OpenAI LLM for dev.") + llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini" + fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini" + model_req = LLMProviderUpsertRequest( + name="DevEnvPresetOpenAI", + provider="openai", + api_key=GEN_AI_API_KEY, + api_base=None, + api_version=None, + custom_config=None, + default_model_name=llm_model, + fast_default_model_name=fast_model, + is_public=True, + groups=[], + display_model_names=[llm_model, fast_model], + model_names=[llm_model, fast_model], + ) + new_llm_provider = upsert_llm_provider( + llm_provider=model_req, db_session=db_session + ) + update_default_provider(provider_id=new_llm_provider.id, db_session=db_session) + + +def update_default_multipass_indexing(db_session: Session) -> None: + docs_exist = check_docs_exist(db_session) + connectors_exist = check_connectors_exist(db_session) + logger.debug(f"Docs exist: {docs_exist}, Connectors exist: {connectors_exist}") + + if not docs_exist and not connectors_exist: + logger.info( + "No existing docs or connectors found. Checking GPU availability for multipass indexing." + ) + gpu_available = gpu_status_request() + logger.info(f"GPU available: {gpu_available}") + + current_settings = get_current_search_settings(db_session) + + logger.notice(f"Updating multipass indexing setting to: {gpu_available}") + updated_settings = SavedSearchSettings.from_db_model(current_settings) + # Enable multipass indexing if GPU is available or if using a cloud provider + updated_settings.multipass_indexing = ( + gpu_available or current_settings.cloud_provider is not None + ) + update_current_search_settings(db_session, updated_settings) + + else: + logger.debug( + "Existing docs or connectors found. Skipping multipass indexing update." 
+ ) + def translate_saved_search_settings(db_session: Session) -> None: kv_store = get_dynamic_config_store() @@ -260,21 +320,32 @@ def setup_vespa( document_index: DocumentIndex, index_setting: IndexingSetting, secondary_index_setting: IndexingSetting | None, -) -> None: +) -> bool: # Vespa startup is a bit slow, so give it a few seconds - wait_time = 5 - for _ in range(5): + WAIT_SECONDS = 5 + VESPA_ATTEMPTS = 5 + for x in range(VESPA_ATTEMPTS): try: + logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...") document_index.ensure_indices_exist( index_embedding_dim=index_setting.model_dim, secondary_index_embedding_dim=secondary_index_setting.model_dim if secondary_index_setting else None, ) - break + + logger.notice("Vespa setup complete.") + return True except Exception: - logger.notice(f"Waiting on Vespa, retrying in {wait_time} seconds...") - time.sleep(wait_time) + logger.notice( + f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds." + ) + time.sleep(WAIT_SECONDS) + + logger.error( + f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})" + ) + return False @asynccontextmanager @@ -297,6 +368,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: # fill up Postgres connection pools await warm_up_connections() + # We cache this at the beginning so there is no delay in the first telemetry + get_or_generate_uuid() + with Session(engine) as db_session: check_index_swap(db_session=db_session) search_settings = get_current_search_settings(db_session) @@ -329,8 +403,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: logger.notice( f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}." ) - - if search_settings.rerank_model_name and not search_settings.provider_type: + if ( + search_settings.rerank_model_name + and not search_settings.provider_type + and not search_settings.rerank_provider_type + ): warm_up_cross_encoder(search_settings.rerank_model_name) logger.notice("Verifying query preprocessing (NLTK) data is downloaded") @@ -353,13 +430,18 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: if secondary_search_settings else None, ) - setup_vespa( + + success = setup_vespa( document_index, IndexingSetting.from_db_model(search_settings), IndexingSetting.from_db_model(secondary_search_settings) if secondary_search_settings else None, ) + if not success: + raise RuntimeError( + "Could not connect to Vespa within the specified timeout." 
+ ) logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}") if search_settings.provider_type is None: @@ -371,15 +453,41 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: ), ) + # update multipass indexing setting based on GPU availability + update_default_multipass_indexing(db_session) + optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__}) yield +def log_http_error(_: Request, exc: Exception) -> JSONResponse: + status_code = getattr(exc, "status_code", 500) + if status_code >= 400: + error_msg = f"{str(exc)}\n" + error_msg += "".join(traceback.format_tb(exc.__traceback__)) + logger.error(error_msg) + + detail = exc.detail if isinstance(exc, HTTPException) else str(exc) + return JSONResponse( + status_code=status_code, + content={"detail": detail}, + ) + + def get_application() -> FastAPI: application = FastAPI( title="Danswer Backend", version=__version__, lifespan=lifespan ) + # Add the custom exception handler + application.add_exception_handler(status.HTTP_400_BAD_REQUEST, log_http_error) + application.add_exception_handler(status.HTTP_401_UNAUTHORIZED, log_http_error) + application.add_exception_handler(status.HTTP_403_FORBIDDEN, log_http_error) + application.add_exception_handler(status.HTTP_404_NOT_FOUND, log_http_error) + application.add_exception_handler( + status.HTTP_500_INTERNAL_SERVER_ERROR, log_http_error + ) + include_router_with_global_prefix_prepended(application, chat_router) include_router_with_global_prefix_prepended(application, query_router) include_router_with_global_prefix_prepended(application, document_router) diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py index b7835c4e906..6dcec724345 100644 --- a/backend/danswer/natural_language_processing/search_nlp_models.py +++ b/backend/danswer/natural_language_processing/search_nlp_models.py @@ -24,6 +24,8 @@ from shared_configs.enums import EmbeddingProvider from shared_configs.enums import EmbedTextType from shared_configs.enums import RerankerProvider +from shared_configs.model_server_models import ConnectorClassificationRequest +from shared_configs.model_server_models import ConnectorClassificationResponse from shared_configs.model_server_models import Embedding from shared_configs.model_server_models import EmbedRequest from shared_configs.model_server_models import EmbedResponse @@ -90,6 +92,7 @@ def __init__( query_prefix: str | None, passage_prefix: str | None, api_key: str | None, + api_url: str | None, provider_type: EmbeddingProvider | None, retrim_content: bool = False, ) -> None: @@ -100,6 +103,7 @@ def __init__( self.normalize = normalize self.model_name = model_name self.retrim_content = retrim_content + self.api_url = api_url self.tokenizer = get_tokenizer( model_name=model_name, provider_type=provider_type ) @@ -157,6 +161,7 @@ def _batch_encode_texts( text_type=text_type, manual_query_prefix=self.query_prefix, manual_passage_prefix=self.passage_prefix, + api_url=self.api_url, ) response = self._make_model_server_request(embed_request) @@ -226,6 +231,7 @@ def from_db_model( passage_prefix=search_settings.passage_prefix, api_key=search_settings.api_key, provider_type=search_settings.provider_type, + api_url=search_settings.api_url, retrim_content=retrim_content, ) @@ -236,6 +242,7 @@ def __init__( model_name: str, provider_type: RerankerProvider | None, api_key: str | None, + api_url: str | None, model_server_host: str = MODEL_SERVER_HOST, 
model_server_port: int = MODEL_SERVER_PORT, ) -> None: @@ -244,6 +251,7 @@ def __init__( self.model_name = model_name self.provider_type = provider_type self.api_key = api_key + self.api_url = api_url def predict(self, query: str, passages: list[str]) -> list[float]: rerank_request = RerankRequest( @@ -252,6 +260,7 @@ def predict(self, query: str, passages: list[str]) -> list[float]: model_name=self.model_name, provider_type=self.provider_type, api_key=self.api_key, + api_url=self.api_url, ) response = requests.post( @@ -297,6 +306,37 @@ def predict( return response_model.is_keyword, response_model.keywords +class ConnectorClassificationModel: + def __init__( + self, + model_server_host: str = MODEL_SERVER_HOST, + model_server_port: int = MODEL_SERVER_PORT, + ): + model_server_url = build_model_server_url(model_server_host, model_server_port) + self.connector_classification_endpoint = ( + model_server_url + "/custom/connector-classification" + ) + + def predict( + self, + query: str, + available_connectors: list[str], + ) -> list[str]: + connector_classification_request = ConnectorClassificationRequest( + available_connectors=available_connectors, + query=query, + ) + response = requests.post( + self.connector_classification_endpoint, + json=connector_classification_request.dict(), + ) + response.raise_for_status() + + response_model = ConnectorClassificationResponse(**response.json()) + + return response_model.connectors + + def warm_up_retry( func: Callable[..., Any], tries: int = 20, @@ -312,8 +352,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return func(*args, **kwargs) except Exception as e: exceptions.append(e) - logger.exception( - f"Attempt {attempt + 1} failed; retrying in {delay} seconds..." + logger.info( + f"Attempt {attempt + 1}/{tries} failed; retrying in {delay} seconds..." 
) time.sleep(delay) raise Exception(f"All retries failed: {exceptions}") @@ -363,6 +403,7 @@ def warm_up_cross_encoder( reranking_model = RerankingModel( model_name=rerank_model_name, provider_type=None, + api_url=None, api_key=None, ) diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index 15387e6c63e..678877812a2 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -26,6 +26,7 @@ class RerankingDetails(BaseModel): # If model is None (or num_rerank is 0), then reranking is turned off rerank_model_name: str | None + rerank_api_url: str | None rerank_provider_type: RerankerProvider | None rerank_api_key: str | None = None @@ -42,6 +43,7 @@ def from_db_model(cls, search_settings: SearchSettings) -> "RerankingDetails": rerank_provider_type=search_settings.rerank_provider_type, rerank_api_key=search_settings.rerank_api_key, num_rerank=search_settings.num_rerank, + rerank_api_url=search_settings.rerank_api_url, ) @@ -81,6 +83,7 @@ def from_db_model(cls, search_settings: SearchSettings) -> "SavedSearchSettings" num_rerank=search_settings.num_rerank, # Multilingual Expansion multilingual_expansion=search_settings.multilingual_expansion, + rerank_api_url=search_settings.rerank_api_url, ) diff --git a/backend/danswer/search/pipeline.py b/backend/danswer/search/pipeline.py index ad3e19e149d..183c8729d67 100644 --- a/backend/danswer/search/pipeline.py +++ b/backend/danswer/search/pipeline.py @@ -209,7 +209,9 @@ def _get_sections(self) -> list[InferenceSection]: if inference_section is not None: expanded_inference_sections.append(inference_section) else: - logger.warning("Skipped creation of section, no chunks found") + logger.warning( + "Skipped creation of section for full docs, no chunks found" + ) self._retrieved_sections = expanded_inference_sections return expanded_inference_sections @@ -270,6 +272,11 @@ def _get_sections(self) -> list[InferenceSection]: (chunk.document_id, chunk.chunk_id): chunk for chunk in inference_chunks } + # In case of failed parallel calls to Vespa, at least we should have the initial retrieved chunks + doc_chunk_ind_to_chunk.update( + {(chunk.document_id, chunk.chunk_id): chunk for chunk in retrieved_chunks} + ) + # Build the surroundings for all of the initial retrieved chunks for chunk in retrieved_chunks: start_ind = max(0, chunk.chunk_id - above) @@ -360,10 +367,10 @@ def section_relevance(self) -> list[SectionRelevancePiece] | None: try: results = run_functions_in_parallel(function_calls=functions) self._section_relevance = list(results.values()) - except Exception: + except Exception as e: raise ValueError( - "An issue occured during the agentic evaluation proecss." - ) + "An issue occurred during the agentic evaluation process."
+ ) from e elif self.search_query.evaluation_type == LLMEvaluationType.BASIC: if DISABLE_LLM_DOC_RELEVANCE: diff --git a/backend/danswer/search/postprocessing/postprocessing.py b/backend/danswer/search/postprocessing/postprocessing.py index 6a3d2dc2dcd..b4a1e48bd39 100644 --- a/backend/danswer/search/postprocessing/postprocessing.py +++ b/backend/danswer/search/postprocessing/postprocessing.py @@ -100,6 +100,7 @@ def semantic_reranking( model_name=rerank_settings.rerank_model_name, provider_type=rerank_settings.rerank_provider_type, api_key=rerank_settings.rerank_api_key, + api_url=rerank_settings.rerank_api_url, ) passages = [ @@ -253,8 +254,8 @@ def search_postprocessing( if not retrieved_sections: # Avoids trying to rerank an empty list which throws an error - yield [] - yield [] + yield cast(list[InferenceSection], []) + yield cast(list[SectionRelevancePiece], []) return rerank_task_id = None diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py index 31582f90819..30347464ff8 100644 --- a/backend/danswer/search/retrieval/search_runner.py +++ b/backend/danswer/search/retrieval/search_runner.py @@ -3,7 +3,6 @@ import nltk # type:ignore from nltk.corpus import stopwords # type:ignore -from nltk.stem import WordNetLemmatizer # type:ignore from nltk.tokenize import word_tokenize # type:ignore from sqlalchemy.orm import Session @@ -40,7 +39,7 @@ def download_nltk_data() -> None: resources = { "stopwords": "corpora/stopwords", - "wordnet": "corpora/wordnet", + # "wordnet": "corpora/wordnet", # Not in use "punkt": "tokenizers/punkt", } @@ -58,15 +57,16 @@ def download_nltk_data() -> None: def lemmatize_text(keywords: list[str]) -> list[str]: - try: - query = " ".join(keywords) - lemmatizer = WordNetLemmatizer() - word_tokens = word_tokenize(query) - lemmatized_words = [lemmatizer.lemmatize(word) for word in word_tokens] - combined_keywords = list(set(keywords + lemmatized_words)) - return combined_keywords - except Exception: - return keywords + raise NotImplementedError("Lemmatization should not be used currently") + # try: + # query = " ".join(keywords) + # lemmatizer = WordNetLemmatizer() + # word_tokens = word_tokenize(query) + # lemmatized_words = [lemmatizer.lemmatize(word) for word in word_tokens] + # combined_keywords = list(set(keywords + lemmatized_words)) + # return combined_keywords + # except Exception: + # return keywords def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]: diff --git a/backend/danswer/secondary_llm_flows/agentic_evaluation.py b/backend/danswer/secondary_llm_flows/agentic_evaluation.py index 3de9db00be6..03121e3cf1d 100644 --- a/backend/danswer/secondary_llm_flows/agentic_evaluation.py +++ b/backend/danswer/secondary_llm_flows/agentic_evaluation.py @@ -58,25 +58,30 @@ def _get_metadata_str(metadata: dict[str, str | list[str]]) -> str: center_metadata=center_metadata_str, ) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) - model_output = message_to_string(llm.invoke(filled_llm_prompt)) + try: + model_output = message_to_string(llm.invoke(filled_llm_prompt)) - # Search for the "Useful Analysis" section in the model output - # This regex looks for "2. Useful Analysis" (case-insensitive) followed by an optional colon, - # then any text up to "3. 
Final Relevance" - # The (?i) flag makes it case-insensitive, and re.DOTALL allows the dot to match newlines - # If no match is found, the entire model output is used as the analysis - analysis_match = re.search( - r"(?i)2\.\s*useful analysis:?\s*(.+?)\n\n3\.\s*final relevance", - model_output, - re.DOTALL, - ) - analysis = analysis_match.group(1).strip() if analysis_match else model_output + # Search for the "Useful Analysis" section in the model output + # This regex looks for "2. Useful Analysis" (case-insensitive) followed by an optional colon, + # then any text up to "3. Final Relevance" + # The (?i) flag makes it case-insensitive, and re.DOTALL allows the dot to match newlines + # If no match is found, the entire model output is used as the analysis + analysis_match = re.search( + r"(?i)2\.\s*useful analysis:?\s*(.+?)\n\n3\.\s*final relevance", + model_output, + re.DOTALL, + ) + analysis = analysis_match.group(1).strip() if analysis_match else model_output - # Get the last non-empty line - last_line = next( - (line for line in reversed(model_output.split("\n")) if line.strip()), "" - ) - relevant = last_line.strip().lower().startswith("true") + # Get the last non-empty line + last_line = next( + (line for line in reversed(model_output.split("\n")) if line.strip()), "" + ) + relevant = last_line.strip().lower().startswith("true") + except Exception as e: + logger.exception(f"An issue occured during the agentic evaluation process. {e}") + relevant = False + analysis = "" return SectionRelevancePiece( document_id=document_id, diff --git a/backend/danswer/secondary_llm_flows/source_filter.py b/backend/danswer/secondary_llm_flows/source_filter.py index 802a14f42fa..f58a91016e0 100644 --- a/backend/danswer/secondary_llm_flows/source_filter.py +++ b/backend/danswer/secondary_llm_flows/source_filter.py @@ -3,12 +3,16 @@ from sqlalchemy.orm import Session +from danswer.configs.chat_configs import ENABLE_CONNECTOR_CLASSIFIER from danswer.configs.constants import DocumentSource from danswer.db.connector import fetch_unique_document_sources from danswer.db.engine import get_sqlalchemy_engine from danswer.llm.interfaces import LLM from danswer.llm.utils import dict_based_prompt_to_langchain_prompt from danswer.llm.utils import message_to_string +from danswer.natural_language_processing.search_nlp_models import ( + ConnectorClassificationModel, +) from danswer.prompts.constants import SOURCES_KEY from danswer.prompts.filter_extration import FILE_SOURCE_WARNING from danswer.prompts.filter_extration import SOURCE_FILTER_PROMPT @@ -42,11 +46,38 @@ def _sample_document_sources( return random.sample(valid_sources, num_sample) +def _sample_documents_using_custom_connector_classifier( + query: str, + valid_sources: list[DocumentSource], +) -> list[DocumentSource] | None: + query_joined = "".join(ch for ch in query.lower() if ch.isalnum()) + available_connectors = list( + filter( + lambda conn: conn.lower() in query_joined, + [item.value for item in valid_sources], + ) + ) + + if not available_connectors: + return None + + connectors = ConnectorClassificationModel().predict(query, available_connectors) + + return strings_to_document_sources(connectors) if connectors else None + + def extract_source_filter( query: str, llm: LLM, db_session: Session ) -> list[DocumentSource] | None: """Returns a list of valid sources for search or None if no specific sources were detected""" + valid_sources = fetch_unique_document_sources(db_session) + if not valid_sources: + return None + + if ENABLE_CONNECTOR_CLASSIFIER: 
+ return _sample_documents_using_custom_connector_classifier(query, valid_sources) + def _get_source_filter_messages( query: str, valid_sources: list[DocumentSource], @@ -146,10 +177,6 @@ def _extract_source_filters_from_llm_out( logger.warning("LLM failed to provide a valid Source Filter output") return None - valid_sources = fetch_unique_document_sources(db_session) - if not valid_sources: - return None - messages = _get_source_filter_messages(query=query, valid_sources=valid_sources) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) model_output = message_to_string(llm.invoke(filled_llm_prompt)) diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index 69ae9916348..97ed3a82812 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -1,7 +1,9 @@ +import math + from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException -from pydantic import BaseModel +from fastapi import Query from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -19,20 +21,56 @@ from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair from danswer.db.index_attempt import cancel_indexing_attempts_past_model -from danswer.db.index_attempt import get_index_attempts_for_connector +from danswer.db.index_attempt import count_index_attempts_for_connector +from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id +from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id from danswer.db.models import User -from danswer.db.models import UserRole from danswer.server.documents.models import CCPairFullInfo +from danswer.server.documents.models import CCStatusUpdateRequest from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorCredentialPairMetadata +from danswer.server.documents.models import PaginatedIndexAttempts from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger +from ee.danswer.db.user_group import validate_user_creation_permissions logger = setup_logger() router = APIRouter(prefix="/manage") +@router.get("/admin/cc-pair/{cc_pair_id}/index-attempts") +def get_cc_pair_index_attempts( + cc_pair_id: int, + page: int = Query(1, ge=1), + page_size: int = Query(10, ge=1, le=1000), + user: User | None = Depends(current_curator_or_admin_user), + db_session: Session = Depends(get_session), +) -> PaginatedIndexAttempts: + cc_pair = get_connector_credential_pair_from_id( + cc_pair_id, db_session, user, get_editable=False + ) + if not cc_pair: + raise HTTPException( + status_code=400, detail="CC Pair not found for current user permissions" + ) + total_count = count_index_attempts_for_connector( + db_session=db_session, + connector_id=cc_pair.connector_id, + ) + index_attempts = get_paginated_index_attempts_for_cc_pair_id( + db_session=db_session, + connector_id=cc_pair.connector_id, + page=page, + page_size=page_size, + ) + return PaginatedIndexAttempts.from_models( + index_attempt_models=index_attempts, + page=page, + total_pages=math.ceil(total_count / page_size), + ) + + @router.get("/admin/cc-pair/{cc_pair_id}") def get_cc_pair_full_info( cc_pair_id: int, @@ -56,11 +94,6 @@ def get_cc_pair_full_info( credential_id=cc_pair.credential_id, ) - index_attempts = get_index_attempts_for_connector( - db_session, - 
cc_pair.connector_id, - ) - document_count_info_list = list( get_document_cnts_for_cc_pairs( db_session=db_session, @@ -71,9 +104,20 @@ def get_cc_pair_full_info( document_count_info_list[0][-1] if document_count_info_list else 0 ) + latest_attempt = get_latest_index_attempt_for_cc_pair_id( + db_session=db_session, + connector_credential_pair_id=cc_pair.id, + secondary_index=False, + only_finished=False, + ) + return CCPairFullInfo.from_models( cc_pair_model=cc_pair, - index_attempt_models=list(index_attempts), + number_of_index_attempts=count_index_attempts_for_connector( + db_session=db_session, + connector_id=cc_pair.connector_id, + ), + last_index_attempt=latest_attempt, latest_deletion_attempt=get_deletion_attempt_snapshot( connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id, @@ -84,10 +128,6 @@ def get_cc_pair_full_info( ) -class CCStatusUpdateRequest(BaseModel): - status: ConnectorCredentialPairStatus - - @router.put("/admin/cc-pair/{cc_pair_id}/status") def update_cc_pair_status( cc_pair_id: int, @@ -157,11 +197,12 @@ def associate_credential_to_connector( user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: - if user and user.role != UserRole.ADMIN and metadata.is_public: - raise HTTPException( - status_code=400, - detail="Public connections cannot be created by non-admin users", - ) + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=metadata.groups, + object_is_public=metadata.is_public, + ) try: response = add_credential_to_connector( @@ -170,7 +211,7 @@ def associate_credential_to_connector( connector_id=connector_id, credential_id=credential_id, cc_pair_name=metadata.name, - is_public=metadata.is_public or True, + is_public=True if metadata.is_public is None else metadata.is_public, groups=metadata.groups, ) diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index 8d6b0ffc773..cc27d1cabaa 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -66,7 +66,7 @@ from danswer.db.engine import get_session from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import get_index_attempts_for_cc_pair -from danswer.db.index_attempt import get_latest_finished_index_attempt_for_cc_pair +from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id from danswer.db.index_attempt import get_latest_index_attempts from danswer.db.models import User from danswer.db.models import UserRole @@ -75,7 +75,6 @@ from danswer.file_store.file_store import get_default_file_store from danswer.server.documents.models import AuthStatus from danswer.server.documents.models import AuthUrl -from danswer.server.documents.models import ConnectorBase from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorIndexingStatus from danswer.server.documents.models import ConnectorSnapshot @@ -93,6 +92,7 @@ from danswer.server.documents.models import RunConnectorRequest from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger +from ee.danswer.db.user_group import validate_user_creation_permissions logger = setup_logger() @@ -387,7 +387,12 @@ def get_connector_indexing_status( ) -> list[ConnectorIndexingStatus]: indexing_statuses: list[ConnectorIndexingStatus] = [] - # TODO: make this one query + # NOTE: If the 
connector is being deleted behind the scenes, + # accessing cc_pairs can be inconsistent and members like + # connector or credential may be None. + # Additional checks are done to make sure the connector and credential still exist. + # TODO: make this one query ... possibly eager load or wrap in a read transaction + # to avoid the complexity of trying to error check throughout the function cc_pairs = get_connector_credential_pairs( db_session=db_session, user=user, ) @@ -440,14 +445,19 @@ def get_connector_indexing_status( connector = cc_pair.connector credential = cc_pair.credential + if not connector or not credential: + # This may happen if background deletion is happening + continue + latest_index_attempt = cc_pair_to_latest_index_attempt.get( (connector.id, credential.id) ) - latest_finished_attempt = get_latest_finished_index_attempt_for_cc_pair( + latest_finished_attempt = get_latest_index_attempt_for_cc_pair_id( + db_session=db_session, connector_credential_pair_id=cc_pair.id, secondary_index=secondary_index, - db_session=db_session, + only_finished=True, ) indexing_statuses.append( @@ -514,35 +524,6 @@ def _validate_connector_allowed(source: DocumentSource) -> None: ) -def _check_connector_permissions( - connector_data: ConnectorUpdateRequest, user: User | None -) -> ConnectorBase: - """ - This is not a proper permission check, but this should prevent curators creating bad situations - until a long-term solution is implemented (Replacing CC pairs/Connectors with Connections) - """ - if user and user.role != UserRole.ADMIN: - if connector_data.is_public: - raise HTTPException( - status_code=400, - detail="Public connectors can only be created by admins", - ) - if not connector_data.groups: - raise HTTPException( - status_code=400, - detail="Connectors created by curators must have groups", - ) - return ConnectorBase( - name=connector_data.name, - source=connector_data.source, - input_type=connector_data.input_type, - connector_specific_config=connector_data.connector_specific_config, - refresh_freq=connector_data.refresh_freq, - prune_freq=connector_data.prune_freq, - indexing_start=connector_data.indexing_start, - ) - - @router.post("/admin/connector") def create_connector_from_model( connector_data: ConnectorUpdateRequest, @@ -551,12 +532,19 @@ def create_connector_from_model( ) -> ObjectCreationIdResponse: try: _validate_connector_allowed(connector_data.source) - connector_base = _check_connector_permissions(connector_data, user) + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=connector_data.groups, + object_is_public=connector_data.is_public, + ) + connector_base = connector_data.to_connector_base() return create_connector( db_session=db_session, connector_data=connector_base, ) except ValueError as e: + logger.error(f"Error creating connector: {e}") raise HTTPException(status_code=400, detail=str(e)) @@ -607,12 +595,18 @@ def create_connector_with_mock_credential( def update_connector_from_model( connector_id: int, connector_data: ConnectorUpdateRequest, - user: User = Depends(current_admin_user), + user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ConnectorSnapshot | StatusResponse[int]: try: _validate_connector_allowed(connector_data.source) - connector_base = _check_connector_permissions(connector_data, user) + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=connector_data.groups, + object_is_public=connector_data.is_public, + ) + 
connector_base = connector_data.to_connector_base() except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -642,7 +636,7 @@ def update_connector_from_model( @router.delete("/admin/connector/{connector_id}", response_model=StatusResponse[int]) def delete_connector_by_id( connector_id: int, - _: User = Depends(current_admin_user), + _: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: try: diff --git a/backend/danswer/server/documents/credential.py b/backend/danswer/server/documents/credential.py index ba30b65f2f9..3d965481bf5 100644 --- a/backend/danswer/server/documents/credential.py +++ b/backend/danswer/server/documents/credential.py @@ -7,7 +7,6 @@ from danswer.auth.users import current_admin_user from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user -from danswer.auth.users import validate_curator_request from danswer.db.credentials import alter_credential from danswer.db.credentials import create_credential from danswer.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE @@ -20,7 +19,6 @@ from danswer.db.engine import get_session from danswer.db.models import DocumentSource from danswer.db.models import User -from danswer.db.models import UserRole from danswer.server.documents.models import CredentialBase from danswer.server.documents.models import CredentialDataUpdateRequest from danswer.server.documents.models import CredentialSnapshot @@ -28,6 +26,7 @@ from danswer.server.documents.models import ObjectCreationIdResponse from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger +from ee.danswer.db.user_group import validate_user_creation_permissions logger = setup_logger() @@ -80,7 +79,7 @@ def get_cc_source_full_info( ] -@router.get("/credentials/{id}") +@router.get("/credential/{id}") def list_credentials_by_id( user: User | None = Depends(current_user), db_session: Session = Depends(get_session), @@ -105,7 +104,7 @@ def delete_credential_by_id_admin( ) -@router.put("/admin/credentials/swap") +@router.put("/admin/credential/swap") def swap_credentials_for_connector( credential_swap_req: CredentialSwapRequest, user: User | None = Depends(current_user), @@ -131,14 +130,12 @@ def create_credential_from_model( user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ObjectCreationIdResponse: - if ( - user - and user.role != UserRole.ADMIN - and not _ignore_credential_permissions(credential_info.source) - ): - validate_curator_request( - groups=credential_info.groups, - is_public=credential_info.curator_public, + if not _ignore_credential_permissions(credential_info.source): + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=credential_info.groups, + object_is_public=credential_info.curator_public, ) credential = create_credential(credential_info, user, db_session) @@ -179,7 +176,7 @@ def get_credential_by_id( return CredentialSnapshot.from_credential_db_model(credential) -@router.put("/admin/credentials/{credential_id}") +@router.put("/admin/credential/{credential_id}") def update_credential_data( credential_id: int, credential_update: CredentialDataUpdateRequest, diff --git a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index ba011afc196..2bed0cf54d5 100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -4,6 +4,7 @@ 
from pydantic import BaseModel from pydantic import Field +from pydantic import model_validator from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX from danswer.configs.constants import DocumentSource @@ -48,9 +49,12 @@ class ConnectorBase(BaseModel): class ConnectorUpdateRequest(ConnectorBase): - is_public: bool | None = None + is_public: bool = True groups: list[int] = Field(default_factory=list) + def to_connector_base(self) -> ConnectorBase: + return ConnectorBase(**self.model_dump(exclude={"is_public", "groups"})) + class ConnectorSnapshot(ConnectorBase): id: int @@ -103,11 +107,6 @@ class CredentialSnapshot(CredentialBase): user_id: UUID | None time_created: datetime time_updated: datetime - name: str | None - source: DocumentSource - credential_json: dict[str, Any] - admin_public: bool - curator_public: bool @classmethod def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot": @@ -187,6 +186,28 @@ def from_db_model(cls, error: DbIndexAttemptError) -> "IndexAttemptError": ) +class PaginatedIndexAttempts(BaseModel): + index_attempts: list[IndexAttemptSnapshot] + page: int + total_pages: int + + @classmethod + def from_models( + cls, + index_attempt_models: list[IndexAttempt], + page: int, + total_pages: int, + ) -> "PaginatedIndexAttempts": + return cls( + index_attempts=[ + IndexAttemptSnapshot.from_index_attempt_db_model(index_attempt_model) + for index_attempt_model in index_attempt_models + ], + page=page, + total_pages=total_pages, + ) + + class CCPairFullInfo(BaseModel): id: int name: str @@ -194,7 +215,8 @@ class CCPairFullInfo(BaseModel): num_docs_indexed: int connector: ConnectorSnapshot credential: CredentialSnapshot - index_attempts: list[IndexAttemptSnapshot] + number_of_index_attempts: int + last_index_attempt_status: IndexingStatus | None latest_deletion_attempt: DeletionAttemptSnapshot | None is_public: bool is_editable_for_current_user: bool @@ -203,11 +225,27 @@ class CCPairFullInfo(BaseModel): def from_models( cls, cc_pair_model: ConnectorCredentialPair, - index_attempt_models: list[IndexAttempt], latest_deletion_attempt: DeletionAttemptSnapshot | None, + number_of_index_attempts: int, + last_index_attempt: IndexAttempt | None, num_docs_indexed: int, # not ideal, but this must be computed separately is_editable_for_current_user: bool, ) -> "CCPairFullInfo": + # figure out if we need to artificially deflate the number of docs indexed. + # This is required since the total number of docs indexed by a CC Pair is + # updated before the new docs for an indexing attempt. If we don't do this, + # there is a mismatch between these two numbers which may confuse users. 
+ last_indexing_status = last_index_attempt.status if last_index_attempt else None + if ( + last_indexing_status == IndexingStatus.SUCCESS + and number_of_index_attempts == 1 + and last_index_attempt + and last_index_attempt.new_docs_indexed + ): + num_docs_indexed = ( + last_index_attempt.new_docs_indexed if last_index_attempt else 0 + ) + return cls( id=cc_pair_model.id, name=cc_pair_model.name, @@ -219,10 +257,8 @@ def from_models( credential=CredentialSnapshot.from_credential_db_model( cc_pair_model.credential ), - index_attempts=[ - IndexAttemptSnapshot.from_index_attempt_db_model(index_attempt_model) - for index_attempt_model in index_attempt_models - ], + number_of_index_attempts=number_of_index_attempts, + last_index_attempt_status=last_indexing_status, latest_deletion_attempt=latest_deletion_attempt, is_public=cc_pair_model.is_public, is_editable_for_current_user=is_editable_for_current_user, @@ -261,6 +297,10 @@ class ConnectorCredentialPairMetadata(BaseModel): groups: list[int] = Field(default_factory=list) +class CCStatusUpdateRequest(BaseModel): + status: ConnectorCredentialPairStatus + + class ConnectorCredentialPairDescriptor(BaseModel): id: int name: str | None = None @@ -307,8 +347,18 @@ class GoogleServiceAccountKey(BaseModel): class GoogleServiceAccountCredentialRequest(BaseModel): - google_drive_delegated_user: str | None # email of user to impersonate - gmail_delegated_user: str | None # email of user to impersonate + google_drive_delegated_user: str | None = None # email of user to impersonate + gmail_delegated_user: str | None = None # email of user to impersonate + + @model_validator(mode="after") + def check_user_delegation(self) -> "GoogleServiceAccountCredentialRequest": + if (self.google_drive_delegated_user is None) == ( + self.gmail_delegated_user is None + ): + raise ValueError( + "Exactly one of google_drive_delegated_user or gmail_delegated_user must be set" + ) + return self class FileUploadResponse(BaseModel): diff --git a/backend/danswer/server/features/document_set/api.py b/backend/danswer/server/features/document_set/api.py index d1eff082891..c9cea2cf2a2 100644 --- a/backend/danswer/server/features/document_set/api.py +++ b/backend/danswer/server/features/document_set/api.py @@ -6,7 +6,6 @@ from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user -from danswer.auth.users import validate_curator_request from danswer.db.document_set import check_document_sets_are_public from danswer.db.document_set import fetch_all_document_sets_for_user from danswer.db.document_set import insert_document_set @@ -14,12 +13,12 @@ from danswer.db.document_set import update_document_set from danswer.db.engine import get_session from danswer.db.models import User -from danswer.db.models import UserRole from danswer.server.features.document_set.models import CheckDocSetPublicRequest from danswer.server.features.document_set.models import CheckDocSetPublicResponse from danswer.server.features.document_set.models import DocumentSet from danswer.server.features.document_set.models import DocumentSetCreationRequest from danswer.server.features.document_set.models import DocumentSetUpdateRequest +from ee.danswer.db.user_group import validate_user_creation_permissions router = APIRouter(prefix="/manage") @@ -31,11 +30,12 @@ def create_document_set( user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> int: - if user and user.role != UserRole.ADMIN: - validate_curator_request( - 
groups=document_set_creation_request.groups, - is_public=document_set_creation_request.is_public, - ) + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=document_set_creation_request.groups, + object_is_public=document_set_creation_request.is_public, + ) try: document_set_db_model, _ = insert_document_set( document_set_creation_request=document_set_creation_request, @@ -53,11 +53,12 @@ def patch_document_set( user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> None: - if user and user.role != UserRole.ADMIN: - validate_curator_request( - groups=document_set_update_request.groups, - is_public=document_set_update_request.is_public, - ) + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=document_set_update_request.groups, + object_is_public=document_set_update_request.is_public, + ) try: update_document_set( document_set_update_request=document_set_update_request, diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index 0ac90ba8d11..a2d7156892c 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -166,10 +166,14 @@ def create_deletion_attempt_for_connector_id( get_editable=True, ) if cc_pair is None: + error = ( + f"Connector with ID '{connector_id}' and credential ID " + f"'{credential_id}' does not exist. Has it already been deleted?" + ) + logger.error(error) raise HTTPException( status_code=404, - detail=f"Connector with ID '{connector_id}' and credential ID " - f"'{credential_id}' does not exist. Has it already been deleted?", + detail=error, ) # Cancel any scheduled indexing attempts diff --git a/backend/danswer/server/manage/embedding/api.py b/backend/danswer/server/manage/embedding/api.py index 90fa69401c2..eac872810ef 100644 --- a/backend/danswer/server/manage/embedding/api.py +++ b/backend/danswer/server/manage/embedding/api.py @@ -9,7 +9,9 @@ from danswer.db.llm import remove_embedding_provider from danswer.db.llm import upsert_cloud_embedding_provider from danswer.db.models import User +from danswer.db.search_settings import get_all_search_settings from danswer.db.search_settings import get_current_db_embedding_provider +from danswer.indexing.models import EmbeddingModelDetail from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.server.manage.embedding.models import CloudEmbeddingProvider from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest @@ -20,6 +22,7 @@ from shared_configs.enums import EmbeddingProvider from shared_configs.enums import EmbedTextType + logger = setup_logger() @@ -37,11 +40,12 @@ def test_embedding_configuration( server_host=MODEL_SERVER_HOST, server_port=MODEL_SERVER_PORT, api_key=test_llm_request.api_key, + api_url=test_llm_request.api_url, provider_type=test_llm_request.provider_type, + model_name=test_llm_request.model_name, normalize=False, query_prefix=None, passage_prefix=None, - model_name=None, ) test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY) @@ -56,6 +60,15 @@ def test_embedding_configuration( raise HTTPException(status_code=400, detail=error_msg) +@admin_router.get("", response_model=list[EmbeddingModelDetail]) +def list_embedding_models( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[EmbeddingModelDetail]: + search_settings = 
get_all_search_settings(db_session) + return [EmbeddingModelDetail.from_db_model(setting) for setting in search_settings] + + @admin_router.get("/embedding-provider") def list_embedding_providers( _: User | None = Depends(current_admin_user), diff --git a/backend/danswer/server/manage/embedding/models.py b/backend/danswer/server/manage/embedding/models.py index 132d311413c..b4ca7862b55 100644 --- a/backend/danswer/server/manage/embedding/models.py +++ b/backend/danswer/server/manage/embedding/models.py @@ -8,14 +8,21 @@ from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel +class SearchSettingsDeleteRequest(BaseModel): + search_settings_id: int + + class TestEmbeddingRequest(BaseModel): provider_type: EmbeddingProvider api_key: str | None = None + api_url: str | None = None + model_name: str | None = None class CloudEmbeddingProvider(BaseModel): provider_type: EmbeddingProvider api_key: str | None = None + api_url: str | None = None @classmethod def from_request( @@ -24,9 +31,11 @@ def from_request( return cls( provider_type=cloud_provider_model.provider_type, api_key=cloud_provider_model.api_key, + api_url=cloud_provider_model.api_url, ) class CloudEmbeddingProviderCreationRequest(BaseModel): provider_type: EmbeddingProvider api_key: str | None = None + api_url: str | None = None diff --git a/backend/danswer/server/manage/llm/api.py b/backend/danswer/server/manage/llm/api.py index 9ea9fe927db..4e57ec7bc35 100644 --- a/backend/danswer/server/manage/llm/api.py +++ b/backend/danswer/server/manage/llm/api.py @@ -121,7 +121,7 @@ def put_llm_provider( _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> FullLLMProvider: - return upsert_llm_provider(db_session, llm_provider) + return upsert_llm_provider(llm_provider=llm_provider, db_session=db_session) @admin_router.delete("/provider/{provider_id}") @@ -139,7 +139,7 @@ def set_provider_as_default( _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> None: - update_default_provider(db_session, provider_id) + update_default_provider(provider_id=provider_id, db_session=db_session) """Endpoints for all""" diff --git a/backend/danswer/server/manage/search_settings.py b/backend/danswer/server/manage/search_settings.py index db483eff5da..c8433467f6c 100644 --- a/backend/danswer/server/manage/search_settings.py +++ b/backend/danswer/server/manage/search_settings.py @@ -14,6 +14,7 @@ from danswer.db.models import IndexModelStatus from danswer.db.models import User from danswer.db.search_settings import create_search_settings +from danswer.db.search_settings import delete_search_settings from danswer.db.search_settings import get_current_search_settings from danswer.db.search_settings import get_embedding_provider_from_provider_type from danswer.db.search_settings import get_secondary_search_settings @@ -23,6 +24,7 @@ from danswer.natural_language_processing.search_nlp_models import clean_model_name from danswer.search.models import SavedSearchSettings from danswer.search.models import SearchSettingsCreationRequest +from danswer.server.manage.embedding.models import SearchSettingsDeleteRequest from danswer.server.manage.models import FullModelVersionResponse from danswer.server.models import IdReturn from danswer.utils.logger import setup_logger @@ -45,7 +47,7 @@ def set_new_search_settings( if search_settings_new.index_name: logger.warning("Index name was specified by request, this is not suggested") - # Validate cloud provider exists + # 
Validate cloud provider exists or create new LiteLLM provider if search_settings_new.provider_type is not None: cloud_provider = get_embedding_provider_from_provider_type( db_session, provider_type=search_settings_new.provider_type @@ -97,6 +99,7 @@ def set_new_search_settings( primary_index_name=search_settings.index_name, secondary_index_name=new_search_settings.index_name, ) + document_index.ensure_indices_exist( index_embedding_dim=search_settings.model_dim, secondary_index_embedding_dim=new_search_settings.model_dim, @@ -132,8 +135,23 @@ def cancel_new_embedding( ) +@router.delete("/delete-search-settings") +def delete_search_settings_endpoint( + deletion_request: SearchSettingsDeleteRequest, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + try: + delete_search_settings( + db_session=db_session, + search_settings_id=deletion_request.search_settings_id, + ) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + + @router.get("/get-current-search-settings") -def get_curr_search_settings( +def get_current_search_settings_endpoint( _: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> SavedSearchSettings: @@ -142,7 +160,7 @@ def get_curr_search_settings( @router.get("/get-secondary-search-settings") -def get_sec_search_settings( +def get_secondary_search_settings_endpoint( _: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> SavedSearchSettings | None: diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index d2fd981b5b5..96c79b4cbe7 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -213,6 +213,52 @@ def deactivate_user( db_session.commit() +@router.delete("/manage/admin/delete-user") +async def delete_user( + user_email: UserByEmail, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + user_to_delete = get_user_by_email( + email=user_email.user_email, db_session=db_session + ) + if not user_to_delete: + raise HTTPException(status_code=404, detail="User not found") + + if user_to_delete.is_active is True: + logger.warning( + "{} must be deactivated before deleting".format(user_to_delete.email) + ) + raise HTTPException( + status_code=400, detail="User must be deactivated before deleting" + ) + + # Detach the user from the current session + db_session.expunge(user_to_delete) + + try: + # Delete related OAuthAccounts first + for oauth_account in user_to_delete.oauth_accounts: + db_session.delete(oauth_account) + + db_session.delete(user_to_delete) + db_session.commit() + + # NOTE: edge case may exist with race conditions + # with this `invited user` scheme generally. 
+ user_emails = get_invited_users() + remaining_users = [ + user for user in user_emails if user != user_email.user_email + ] + write_invited_users(remaining_users) + + logger.info(f"Deleted user {user_to_delete.email}") + except Exception as e: + db_session.rollback() + logger.error(f"Error deleting user {user_to_delete.email}: {str(e)}") + raise HTTPException(status_code=500, detail="Error deleting user") + + @router.patch("/manage/admin/activate-user") def activate_user( user_email: UserByEmail, diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index a37758336a2..20ae7124fa1 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -269,7 +269,10 @@ def delete_chat_session_by_id( db_session: Session = Depends(get_session), ) -> None: user_id = user.id if user is not None else None - delete_chat_session(user_id, session_id, db_session) + try: + delete_chat_session(user_id, session_id, db_session) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) async def is_disconnected(request: Request) -> Callable[[], bool]: diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index 704b16d5eaa..e20de5a3027 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -11,8 +11,8 @@ from danswer.db.chat import get_chat_messages_by_session from danswer.db.chat import get_chat_session_by_id from danswer.db.chat import get_chat_sessions_by_user -from danswer.db.chat import get_first_messages_for_chat_sessions from danswer.db.chat import get_search_docs_for_chat_message +from danswer.db.chat import get_valid_messages_from_query_sessions from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.engine import get_session @@ -142,18 +142,20 @@ def get_user_search_sessions( raise HTTPException( status_code=404, detail="Chat session does not exist or has been deleted" ) - + # Extract IDs from search sessions search_session_ids = [chat.id for chat in search_sessions] - first_messages = get_first_messages_for_chat_sessions( + # Fetch first messages for each session, only including those with documents + sessions_with_documents = get_valid_messages_from_query_sessions( search_session_ids, db_session ) - first_messages_dict = dict(first_messages) + sessions_with_documents_dict = dict(sessions_with_documents) + # Prepare response with detailed information for each valid search session response = ChatSessionsResponse( sessions=[ ChatSessionDetails( id=search.id, - name=first_messages_dict.get(search.id, search.description), + name=sessions_with_documents_dict[search.id], persona_id=search.persona_id, time_created=search.time_created.isoformat(), shared_status=search.shared_status, @@ -161,8 +163,11 @@ def get_user_search_sessions( current_alternate_model=search.current_alternate_model, ) for search in search_sessions + if search.id + in sessions_with_documents_dict # Only include sessions with documents ] ) + return response diff --git a/backend/danswer/server/settings/api.py b/backend/danswer/server/settings/api.py index 3330f6cc5ff..5b8564c3d3a 100644 --- a/backend/danswer/server/settings/api.py +++ b/backend/danswer/server/settings/api.py @@ -66,7 +66,7 @@ def fetch_settings( return UserSettings( 
**general_settings.model_dump(), notifications=user_notifications, - needs_reindexing=needs_reindexing + needs_reindexing=needs_reindexing, ) diff --git a/backend/danswer/tools/tool_runner.py b/backend/danswer/tools/tool_runner.py index f962c214a03..58b94bdb0c8 100644 --- a/backend/danswer/tools/tool_runner.py +++ b/backend/danswer/tools/tool_runner.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from collections.abc import Generator from typing import Any @@ -47,7 +48,7 @@ def tool_final_result(self) -> ToolCallFinalResult: def check_which_tools_should_run_for_non_tool_calling_llm( tools: list[Tool], query: str, history: list[PreviousMessage], llm: LLM ) -> list[dict[str, Any] | None]: - tool_args_list = [ + tool_args_list: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [ (tool.get_args_for_non_tool_calling_llm, (query, history, llm)) for tool in tools ] diff --git a/backend/danswer/utils/gpu_utils.py b/backend/danswer/utils/gpu_utils.py new file mode 100644 index 00000000000..70a3dbc2c95 --- /dev/null +++ b/backend/danswer/utils/gpu_utils.py @@ -0,0 +1,30 @@ +import requests +from retry import retry + +from danswer.utils.logger import setup_logger +from shared_configs.configs import INDEXING_MODEL_SERVER_HOST +from shared_configs.configs import INDEXING_MODEL_SERVER_PORT +from shared_configs.configs import MODEL_SERVER_HOST +from shared_configs.configs import MODEL_SERVER_PORT + +logger = setup_logger() + + +@retry(tries=5, delay=5) +def gpu_status_request(indexing: bool = True) -> bool: + if indexing: + model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}" + else: + model_server_url = f"{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}" + + if "http" not in model_server_url: + model_server_url = f"http://{model_server_url}" + + try: + response = requests.get(f"{model_server_url}/api/gpu-status", timeout=10) + response.raise_for_status() + gpu_status = response.json() + return gpu_status["gpu_available"] + except requests.RequestException as e: + logger.error(f"Error: Unable to fetch GPU status. 
Error: {str(e)}") + raise # Re-raise exception to trigger a retry diff --git a/backend/danswer/utils/telemetry.py b/backend/danswer/utils/telemetry.py index 80fcba65a16..d8a021877e6 100644 --- a/backend/danswer/utils/telemetry.py +++ b/backend/danswer/utils/telemetry.py @@ -4,13 +4,20 @@ from typing import cast import requests +from sqlalchemy.orm import Session from danswer.configs.app_configs import DISABLE_TELEMETRY +from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED from danswer.configs.constants import KV_CUSTOMER_UUID_KEY +from danswer.configs.constants import KV_INSTANCE_DOMAIN_KEY +from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.models import User from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError -DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.danswer.ai/anonymous_telemetry" +_DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.danswer.ai/anonymous_telemetry" +_CACHED_UUID: str | None = None +_CACHED_INSTANCE_DOMAIN: str | None = None class RecordType(str, Enum): @@ -22,13 +29,42 @@ class RecordType(str, Enum): def get_or_generate_uuid() -> str: + global _CACHED_UUID + + if _CACHED_UUID is not None: + return _CACHED_UUID + + kv_store = get_dynamic_config_store() + + try: + _CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY)) + except ConfigNotFoundError: + _CACHED_UUID = str(uuid.uuid4()) + kv_store.store(KV_CUSTOMER_UUID_KEY, _CACHED_UUID, encrypt=True) + + return _CACHED_UUID + + +def _get_or_generate_instance_domain() -> str | None: + global _CACHED_INSTANCE_DOMAIN + + if _CACHED_INSTANCE_DOMAIN is not None: + return _CACHED_INSTANCE_DOMAIN + kv_store = get_dynamic_config_store() + try: - return cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY)) + _CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY)) except ConfigNotFoundError: - customer_id = str(uuid.uuid4()) - kv_store.store(KV_CUSTOMER_UUID_KEY, customer_id, encrypt=True) - return customer_id + with Session(get_sqlalchemy_engine()) as db_session: + first_user = db_session.query(User).first() + if first_user: + _CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1] + kv_store.store( + KV_INSTANCE_DOMAIN_KEY, _CACHED_INSTANCE_DOMAIN, encrypt=True + ) + + return _CACHED_INSTANCE_DOMAIN def optional_telemetry( @@ -41,16 +77,19 @@ def optional_telemetry( def telemetry_logic() -> None: try: + customer_uuid = get_or_generate_uuid() payload = { "data": data, "record": record_type, # If None then it's a flow that doesn't include a user # For cases where the User itself is None, a string is provided instead "user_id": user_id, - "customer_uuid": get_or_generate_uuid(), + "customer_uuid": customer_uuid, } + if ENTERPRISE_EDITION_ENABLED: + payload["instance_domain"] = _get_or_generate_instance_domain() requests.post( - DANSWER_TELEMETRY_ENDPOINT, + _DANSWER_TELEMETRY_ENDPOINT, headers={"Content-Type": "application/json"}, json=payload, ) diff --git a/backend/danswer/utils/text_processing.py b/backend/danswer/utils/text_processing.py index b0fbcdfa1e9..134859d4e74 100644 --- a/backend/danswer/utils/text_processing.py +++ b/backend/danswer/utils/text_processing.py @@ -43,6 +43,35 @@ def replace_whitespaces_w_space(s: str) -> str: return re.sub(r"\s", " ", s) +# Function to remove punctuation from a string +def remove_punctuation(s: str) -> str: + return s.translate(str.maketrans("", "", string.punctuation)) + + +def escape_quotes(original_json_str: str) -> str: + result = [] + in_string = 
False + for i, char in enumerate(original_json_str): + if char == '"': + if not in_string: + in_string = True + result.append(char) + else: + next_char = ( + original_json_str[i + 1] if i + 1 < len(original_json_str) else None + ) + if result and result[-1] == "\\": + result.append(char) + elif next_char not in [",", ":", "}", "\n"]: + result.append("\\" + char) + else: + result.append(char) + in_string = False + else: + result.append(char) + return "".join(result) + + def extract_embedded_json(s: str) -> dict: first_brace_index = s.find("{") last_brace_index = s.rfind("}") @@ -50,7 +79,15 @@ def extract_embedded_json(s: str) -> dict: if first_brace_index == -1 or last_brace_index == -1: raise ValueError("No valid json found") - return json.loads(s[first_brace_index : last_brace_index + 1], strict=False) + json_str = s[first_brace_index : last_brace_index + 1] + try: + return json.loads(json_str, strict=False) + + except json.JSONDecodeError: + try: + return json.loads(escape_quotes(json_str), strict=False) + except json.JSONDecodeError as e: + raise ValueError("Failed to parse JSON, even after escaping quotes") from e def clean_up_code_blocks(model_out_raw: str) -> str: diff --git a/backend/ee/danswer/db/user_group.py b/backend/ee/danswer/db/user_group.py index 9d172c5d716..00e7d4d5ebd 100644 --- a/backend/ee/danswer/db/user_group.py +++ b/backend/ee/danswer/db/user_group.py @@ -2,6 +2,7 @@ from operator import and_ from uuid import UUID +from fastapi import HTTPException from sqlalchemy import delete from sqlalchemy import func from sqlalchemy import select @@ -30,6 +31,50 @@ logger = setup_logger() +def validate_user_creation_permissions( + db_session: Session, + user: User | None, + target_group_ids: list[int] | None, + object_is_public: bool | None, +) -> None: + """ + All admin actions are allowed. 
+ Prevents non-admins from creating/editing: + - public objects + - objects with no groups + - objects that belong to a group they don't curate + """ + if not user or user.role == UserRole.ADMIN: + return + + if object_is_public: + detail = "User does not have permission to create public credentials" + logger.error(detail) + raise HTTPException( + status_code=400, + detail=detail, + ) + if not target_group_ids: + detail = "Curators must specify 1+ groups" + logger.error(detail) + raise HTTPException( + status_code=400, + detail=detail, + ) + user_curated_groups = fetch_user_groups_for_user( + db_session=db_session, user_id=user.id, only_curator_groups=True + ) + user_curated_group_ids = set([group.id for group in user_curated_groups]) + target_group_ids_set = set(target_group_ids) + if not target_group_ids_set.issubset(user_curated_group_ids): + detail = "Curators cannot control groups they don't curate" + logger.error(detail) + raise HTTPException( + status_code=400, + detail=detail, + ) + + def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None: stmt = select(UserGroup).where(UserGroup.id == user_group_id) return db_session.scalar(stmt) diff --git a/backend/ee/danswer/server/enterprise_settings/api.py b/backend/ee/danswer/server/enterprise_settings/api.py index 736296517db..8590fd6c5e7 100644 --- a/backend/ee/danswer/server/enterprise_settings/api.py +++ b/backend/ee/danswer/server/enterprise_settings/api.py @@ -1,14 +1,24 @@ +from datetime import datetime +from datetime import timedelta +from datetime import timezone + +import httpx from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import Response +from fastapi import status from fastapi import UploadFile from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user +from danswer.auth.users import current_user +from danswer.auth.users import get_user_manager +from danswer.auth.users import UserManager from danswer.db.engine import get_session from danswer.db.models import User from danswer.file_store.file_store import get_default_file_store +from danswer.utils.logger import setup_logger from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload from ee.danswer.server.enterprise_settings.models import EnterpriseSettings from ee.danswer.server.enterprise_settings.store import _LOGO_FILENAME @@ -18,10 +28,117 @@ from ee.danswer.server.enterprise_settings.store import store_analytics_script from ee.danswer.server.enterprise_settings.store import store_settings from ee.danswer.server.enterprise_settings.store import upload_logo +from shared_configs.configs import CUSTOM_REFRESH_URL admin_router = APIRouter(prefix="/admin/enterprise-settings") basic_router = APIRouter(prefix="/enterprise-settings") +logger = setup_logger() + + +def mocked_refresh_token() -> dict: + """ + This function mocks the response from a token refresh endpoint. + It generates a mock access token, refresh token, and user information + with an expiration time set to 1 hour from now. + This is useful for testing or development when the actual refresh endpoint is not available. 
+ """ + mock_exp = int((datetime.now() + timedelta(hours=1)).timestamp() * 1000) + data = { + "access_token": "asdf Mock access token", + "refresh_token": "asdf Mock refresh token", + "session": {"exp": mock_exp}, + "userinfo": { + "sub": "Mock email", + "familyName": "Mock name", + "givenName": "Mock name", + "fullName": "Mock name", + "userId": "Mock User ID", + "email": "test_email@danswer.ai", + }, + } + return data + + +@basic_router.get("/refresh-token") +async def refresh_access_token( + user: User = Depends(current_user), + user_manager: UserManager = Depends(get_user_manager), +) -> None: + # return + if CUSTOM_REFRESH_URL is None: + logger.error( + "Custom refresh URL is not set and client is attempting to custom refresh" + ) + raise HTTPException( + status_code=500, + detail="Custom refresh URL is not set", + ) + + try: + async with httpx.AsyncClient() as client: + logger.debug(f"Sending request to custom refresh URL for user {user.id}") + access_token = user.oauth_accounts[0].access_token + + response = await client.get( + CUSTOM_REFRESH_URL, + params={"info": "json", "access_token_refresh_interval": 3600}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + response.raise_for_status() + data = response.json() + + # NOTE: Here is where we can mock the response + # data = mocked_refresh_token() + + logger.debug(f"Received response from Meechum auth URL for user {user.id}") + + # Extract new tokens + new_access_token = data["access_token"] + new_refresh_token = data["refresh_token"] + + new_expiry = datetime.fromtimestamp( + data["session"]["exp"] / 1000, tz=timezone.utc + ) + expires_at_timestamp = int(new_expiry.timestamp()) + + logger.debug(f"Access token has been refreshed for user {user.id}") + + await user_manager.oauth_callback( + oauth_name="custom", + access_token=new_access_token, + account_id=data["userinfo"]["userId"], + account_email=data["userinfo"]["email"], + expires_at=expires_at_timestamp, + refresh_token=new_refresh_token, + associate_by_email=True, + ) + + logger.info(f"Successfully refreshed tokens for user {user.id}") + + except httpx.HTTPStatusError as e: + if e.response.status_code == 401: + logger.warning(f"Full authentication required for user {user.id}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Full authentication required", + ) + logger.error( + f"HTTP error occurred while refreshing token for user {user.id}: {str(e)}" + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to refresh token", + ) + except Exception as e: + logger.error( + f"Unexpected error occurred while refreshing token for user {user.id}: {str(e)}" + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred", + ) + @admin_router.put("") def put_settings( diff --git a/backend/ee/danswer/server/enterprise_settings/models.py b/backend/ee/danswer/server/enterprise_settings/models.py index c9831d87aeb..c770fbd73e7 100644 --- a/backend/ee/danswer/server/enterprise_settings/models.py +++ b/backend/ee/danswer/server/enterprise_settings/models.py @@ -1,4 +1,13 @@ +from typing import List + from pydantic import BaseModel +from pydantic import Field + + +class NavigationItem(BaseModel): + link: str + icon: str + title: str class EnterpriseSettings(BaseModel): @@ -10,11 +19,16 @@ class EnterpriseSettings(BaseModel): use_custom_logo: bool = False use_custom_logotype: bool = False + # custom navigation + custom_nav_items: List[NavigationItem] = 
Field(default_factory=list) + # custom Chat components + two_lines_for_chat_header: bool | None = None custom_lower_disclaimer_content: str | None = None custom_header_content: str | None = None custom_popup_header: str | None = None custom_popup_content: str | None = None + enable_consent_screen: bool | None = None def check_validity(self) -> None: return diff --git a/backend/ee/danswer/server/saml.py b/backend/ee/danswer/server/saml.py index 5bc62e98d61..38966c15756 100644 --- a/backend/ee/danswer/server/saml.py +++ b/backend/ee/danswer/server/saml.py @@ -65,6 +65,7 @@ async def upsert_saml_user(email: str) -> User: password=hashed_pass, is_verified=True, role=role, + has_web_login=True, ) ) diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py index bbca5acc20a..10dc1afb972 100644 --- a/backend/ee/danswer/server/seeding.py +++ b/backend/ee/danswer/server/seeding.py @@ -51,10 +51,12 @@ def _seed_llms( if llm_upsert_requests: logger.notice("Seeding LLMs") seeded_providers = [ - upsert_llm_provider(db_session, llm_upsert_request) + upsert_llm_provider(llm_upsert_request, db_session) for llm_upsert_request in llm_upsert_requests ] - update_default_provider(db_session, seeded_providers[0].id) + update_default_provider( + provider_id=seeded_providers[0].id, db_session=db_session + ) def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None: diff --git a/backend/ee/danswer/server/user_group/api.py b/backend/ee/danswer/server/user_group/api.py index e18487d5491..b33daddea64 100644 --- a/backend/ee/danswer/server/user_group/api.py +++ b/backend/ee/danswer/server/user_group/api.py @@ -9,6 +9,7 @@ from danswer.db.engine import get_session from danswer.db.models import User from danswer.db.models import UserRole +from danswer.utils.logger import setup_logger from ee.danswer.db.user_group import fetch_user_groups from ee.danswer.db.user_group import fetch_user_groups_for_user from ee.danswer.db.user_group import insert_user_group @@ -20,6 +21,8 @@ from ee.danswer.server.user_group.models import UserGroupCreate from ee.danswer.server.user_group.models import UserGroupUpdate +logger = setup_logger() + router = APIRouter(prefix="/manage") @@ -90,6 +93,7 @@ def set_user_curator( set_curator_request=set_curator_request, ) except ValueError as e: + logger.error(f"Error setting user curator: {e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/backend/model_server/custom_models.py b/backend/model_server/custom_models.py index 38bf4b077fa..fde3c8d0dc9 100644 --- a/backend/model_server/custom_models.py +++ b/backend/model_server/custom_models.py @@ -3,15 +3,21 @@ from fastapi import APIRouter from huggingface_hub import snapshot_download # type: ignore from transformers import AutoTokenizer # type: ignore -from transformers import BatchEncoding +from transformers import BatchEncoding # type: ignore +from transformers import PreTrainedTokenizer # type: ignore from danswer.utils.logger import setup_logger from model_server.constants import MODEL_WARM_UP_STRING +from model_server.danswer_torch_model import ConnectorClassifier from model_server.danswer_torch_model import HybridClassifier from model_server.utils import simple_log_function_time +from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO +from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG from shared_configs.configs import INDEXING_ONLY from shared_configs.configs import INTENT_MODEL_TAG from shared_configs.configs import INTENT_MODEL_VERSION 
+from shared_configs.model_server_models import ConnectorClassificationRequest +from shared_configs.model_server_models import ConnectorClassificationResponse from shared_configs.model_server_models import IntentRequest from shared_configs.model_server_models import IntentResponse @@ -19,10 +25,55 @@ router = APIRouter(prefix="/custom") +_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None +_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None + _INTENT_TOKENIZER: AutoTokenizer | None = None _INTENT_MODEL: HybridClassifier | None = None +def get_connector_classifier_tokenizer() -> AutoTokenizer: + global _CONNECTOR_CLASSIFIER_TOKENIZER + if _CONNECTOR_CLASSIFIER_TOKENIZER is None: + # The tokenizer details are not uploaded to the HF hub since it's just the + # unmodified distilbert tokenizer. + _CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained( + "distilbert-base-uncased" + ) + return _CONNECTOR_CLASSIFIER_TOKENIZER + + +def get_local_connector_classifier( + model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO, + tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG, +) -> ConnectorClassifier: + global _CONNECTOR_CLASSIFIER_MODEL + if _CONNECTOR_CLASSIFIER_MODEL is None: + try: + # Calculate where the cache should be, then load from local if available + local_path = snapshot_download( + repo_id=model_name_or_path, revision=tag, local_files_only=True + ) + _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained( + local_path + ) + except Exception as e: + logger.warning(f"Failed to load model directly: {e}") + try: + # Attempt to download the model snapshot + logger.info(f"Downloading model snapshot for {model_name_or_path}") + local_path = snapshot_download(repo_id=model_name_or_path, revision=tag) + _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained( + local_path + ) + except Exception as e: + logger.error( + f"Failed to load model even after attempted snapshot download: {e}" + ) + raise + return _CONNECTOR_CLASSIFIER_MODEL + + def get_intent_model_tokenizer() -> AutoTokenizer: global _INTENT_TOKENIZER if _INTENT_TOKENIZER is None: @@ -61,6 +112,74 @@ def get_local_intent_model( return _INTENT_MODEL +def tokenize_connector_classification_query( + connectors: list[str], + query: str, + tokenizer: PreTrainedTokenizer, + connector_token_end_id: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Tokenize the connectors & user query into one prompt for the forward pass of ConnectorClassifier models + + The attention mask is just all 1s. The prompt is CLS + each connector name suffixed with the connector end + token and then the user query. 
+ """ + + input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long) + + for connector in connectors: + connector_token_ids = tokenizer( + connector, + add_special_tokens=False, + return_tensors="pt", + ) + + input_ids = torch.cat( + ( + input_ids, + connector_token_ids["input_ids"].squeeze(dim=0), + torch.tensor([connector_token_end_id], dtype=torch.long), + ), + dim=-1, + ) + query_token_ids = tokenizer( + query, + add_special_tokens=False, + return_tensors="pt", + ) + + input_ids = torch.cat( + ( + input_ids, + query_token_ids["input_ids"].squeeze(dim=0), + torch.tensor([tokenizer.sep_token_id], dtype=torch.long), + ), + dim=-1, + ) + attention_mask = torch.ones(input_ids.numel(), dtype=torch.long) + + return input_ids.unsqueeze(0), attention_mask.unsqueeze(0) + + +def warm_up_connector_classifier_model() -> None: + logger.info( + f"Warming up connector_classifier model {CONNECTOR_CLASSIFIER_MODEL_TAG}" + ) + connector_classifier_tokenizer = get_connector_classifier_tokenizer() + connector_classifier = get_local_connector_classifier() + + input_ids, attention_mask = tokenize_connector_classification_query( + ["GitHub"], + "danswer classifier query google doc", + connector_classifier_tokenizer, + connector_classifier.connector_end_token_id, + ) + input_ids = input_ids.to(connector_classifier.device) + attention_mask = attention_mask.to(connector_classifier.device) + + connector_classifier(input_ids, attention_mask) + + def warm_up_intent_model() -> None: logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}") intent_tokenizer = get_intent_model_tokenizer() @@ -157,6 +276,35 @@ def clean_keywords(keywords: list[str]) -> list[str]: return cleaned_words +def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]: + tokenizer = get_connector_classifier_tokenizer() + model = get_local_connector_classifier() + + connector_names = req.available_connectors + + input_ids, attention_mask = tokenize_connector_classification_query( + connector_names, + req.query, + tokenizer, + model.connector_end_token_id, + ) + input_ids = input_ids.to(model.device) + attention_mask = attention_mask.to(model.device) + + global_confidence, classifier_confidence = model(input_ids, attention_mask) + + if global_confidence.item() < 0.5: + return [] + + passed_connectors = [] + + for i, connector_name in enumerate(connector_names): + if classifier_confidence.view(-1)[i].item() > 0.5: + passed_connectors.append(connector_name) + + return passed_connectors + + def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]: tokenizer = get_intent_model_tokenizer() model_input = tokenizer( @@ -189,6 +337,22 @@ def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]: return is_keyword_sequence, cleaned_keywords +@router.post("/connector-classification") +async def process_connector_classification_request( + classification_request: ConnectorClassificationRequest, +) -> ConnectorClassificationResponse: + if INDEXING_ONLY: + raise RuntimeError( + "Indexing model server should not call connector classification endpoint" + ) + + if len(classification_request.available_connectors) == 0: + return ConnectorClassificationResponse(connectors=[]) + + connectors = run_connector_classification(classification_request) + return ConnectorClassificationResponse(connectors=connectors) + + @router.post("/query-analysis") async def process_analysis_request( intent_request: IntentRequest, diff --git a/backend/model_server/danswer_torch_model.py 
b/backend/model_server/danswer_torch_model.py index 28554a4fd2d..7390a97e049 100644 --- a/backend/model_server/danswer_torch_model.py +++ b/backend/model_server/danswer_torch_model.py @@ -4,7 +4,8 @@ import torch import torch.nn as nn from transformers import DistilBertConfig # type: ignore -from transformers import DistilBertModel +from transformers import DistilBertModel # type: ignore +from transformers import DistilBertTokenizer # type: ignore class HybridClassifier(nn.Module): @@ -21,7 +22,6 @@ def __init__(self) -> None: self.distilbert.config.dim, self.distilbert.config.dim ) self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2) - self.dropout = nn.Dropout(self.distilbert.config.seq_classif_dropout) self.device = torch.device("cpu") @@ -36,8 +36,7 @@ def forward( # Intent classification on the CLS token cls_token_state = sequence_output[:, 0, :] pre_classifier_out = self.pre_classifier(cls_token_state) - dropout_out = self.dropout(pre_classifier_out) - intent_logits = self.intent_classifier(dropout_out) + intent_logits = self.intent_classifier(pre_classifier_out) # Keyword classification on all tokens token_logits = self.keyword_classifier(sequence_output) @@ -72,3 +71,70 @@ def from_pretrained(cls, load_directory: str) -> "HybridClassifier": param.requires_grad = False return model + + +class ConnectorClassifier(nn.Module): + def __init__(self, config: DistilBertConfig) -> None: + super().__init__() + + self.config = config + self.distilbert = DistilBertModel(config) + self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1) + self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1) + self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") + + # Token indicating end of connector name, and on which classifier is used + self.connector_end_token_id = self.tokenizer.get_vocab()[ + self.config.connector_end_token + ] + + self.device = torch.device("cpu") + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states = self.distilbert( + input_ids=input_ids, attention_mask=attention_mask + ).last_hidden_state + + cls_hidden_states = hidden_states[ + :, 0, : + ] # Take leap of faith that first token is always [CLS] + global_logits = self.connector_global_classifier(cls_hidden_states).view(-1) + global_confidence = torch.sigmoid(global_logits).view(-1) + + connector_end_position_ids = input_ids == self.connector_end_token_id + connector_end_hidden_states = hidden_states[connector_end_position_ids] + classifier_output = self.connector_match_classifier(connector_end_hidden_states) + classifier_confidence = torch.nn.functional.sigmoid(classifier_output).view(-1) + + return global_confidence, classifier_confidence + + @classmethod + def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier": + config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")) + device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("mps") + if torch.backends.mps.is_available() + else torch.device("cpu") + ) + state_dict = torch.load( + os.path.join(repo_dir, "pytorch_model.pt"), + map_location=device, + weights_only=True, + ) + + model = cls(config) + model.load_state_dict(state_dict) + model.to(device) + model.device = device + model.eval() + + for param in model.parameters(): + param.requires_grad = False + + return model diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 
4e97bd00f27..860151b3dc4 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -2,6 +2,7 @@ from typing import Any from typing import Optional +import httpx import openai import vertexai # type: ignore import voyageai # type: ignore @@ -83,7 +84,7 @@ def __init__( self.client = _initialize_client(api_key, self.provider, model) def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]: - if model is None: + if not model: model = DEFAULT_OPENAI_MODEL # OpenAI does not seem to provide truncation option, however @@ -110,7 +111,7 @@ def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]: def _embed_cohere( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: - if model is None: + if not model: model = DEFAULT_COHERE_MODEL final_embeddings: list[Embedding] = [] @@ -129,7 +130,7 @@ def _embed_cohere( def _embed_voyage( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: - if model is None: + if not model: model = DEFAULT_VOYAGE_MODEL # Similar to Cohere, the API server will do approximate size chunking @@ -145,7 +146,7 @@ def _embed_voyage( def _embed_vertex( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: - if model is None: + if not model: model = DEFAULT_VERTEX_MODEL embeddings = self.client.get_embeddings( @@ -171,7 +172,6 @@ def embed( try: if self.provider == EmbeddingProvider.OPENAI: return self._embed_openai(texts, model_name) - embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type) if self.provider == EmbeddingProvider.COHERE: return self._embed_cohere(texts, model_name, embedding_type) @@ -235,6 +235,25 @@ def get_local_reranking_model( return _RERANK_MODEL +def embed_with_litellm_proxy( + texts: list[str], api_url: str, model_name: str, api_key: str | None +) -> list[Embedding]: + headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"} + + with httpx.Client() as client: + response = client.post( + api_url, + json={ + "model": model_name, + "input": texts, + }, + headers=headers, + ) + response.raise_for_status() + result = response.json() + return [embedding["embedding"] for embedding in result["data"]] + + @simple_log_function_time() def embed_text( texts: list[str], @@ -245,21 +264,42 @@ def embed_text( api_key: str | None, provider_type: EmbeddingProvider | None, prefix: str | None, + api_url: str | None, ) -> list[Embedding]: + logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}") + if not all(texts): + logger.error("Empty strings provided for embedding") raise ValueError("Empty strings are not allowed for embedding.") - # Third party API based embedding model if not texts: + logger.error("No texts provided for embedding") raise ValueError("No texts provided for embedding.") + + if provider_type == EmbeddingProvider.LITELLM: + logger.debug(f"Using LiteLLM proxy for embedding with URL: {api_url}") + if not api_url: + logger.error("API URL not provided for LiteLLM proxy") + raise ValueError("API URL is required for LiteLLM proxy embedding.") + try: + return embed_with_litellm_proxy( + texts=texts, + api_url=api_url, + model_name=model_name or "", + api_key=api_key, + ) + except Exception as e: + logger.exception(f"Error during LiteLLM proxy embedding: {str(e)}") + raise + elif provider_type is not None: - logger.debug(f"Embedding text with provider: {provider_type}") + logger.debug(f"Using cloud provider {provider_type} for embedding") if api_key 
is None: + logger.error("API key not provided for cloud model") raise RuntimeError("API key not provided for cloud model") if prefix: - # This may change in the future if some providers require the user - # to manually append a prefix but this is not the case currently + logger.warning("Prefix provided for cloud model, which is not supported") raise ValueError( "Prefix string is not valid for cloud models. " "Cloud models take an explicit text type instead." @@ -274,14 +314,15 @@ def embed_text( text_type=text_type, ) - # Check for None values in embeddings if any(embedding is None for embedding in embeddings): error_message = "Embeddings contain None values\n" error_message += "Corresponding texts:\n" error_message += "\n".join(texts) + logger.error(error_message) raise ValueError(error_message) elif model_name is not None: + logger.debug(f"Using local model {model_name} for embedding") prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts local_model = get_embedding_model( @@ -296,10 +337,12 @@ def embed_text( ] else: + logger.error("Neither model name nor provider specified for embedding") raise ValueError( "Either model name or provider must be provided to run embeddings." ) + logger.info(f"Successfully embedded {len(texts)} texts") return embeddings @@ -319,6 +362,28 @@ def cohere_rerank( return [result.relevance_score for result in sorted_results] +def litellm_rerank( + query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None +) -> list[float]: + headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"} + with httpx.Client() as client: + response = client.post( + api_url, + json={ + "model": model_name, + "query": query, + "documents": docs, + }, + headers=headers, + ) + response.raise_for_status() + result = response.json() + return [ + item["relevance_score"] + for item in sorted(result["results"], key=lambda x: x["index"]) + ] + + @router.post("/bi-encoder-embed") async def process_embed_request( embed_request: EmbedRequest, @@ -344,6 +409,7 @@ async def process_embed_request( api_key=embed_request.api_key, provider_type=embed_request.provider_type, text_type=embed_request.text_type, + api_url=embed_request.api_url, prefix=prefix, ) return EmbedResponse(embeddings=embeddings) @@ -374,6 +440,20 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons model_name=rerank_request.model_name, ) return RerankResponse(scores=sim_scores) + elif rerank_request.provider_type == RerankerProvider.LITELLM: + if rerank_request.api_url is None: + raise ValueError("API URL is required for LiteLLM reranking.") + + sim_scores = litellm_rerank( + query=rerank_request.query, + docs=rerank_request.documents, + api_url=rerank_request.api_url, + model_name=rerank_request.model_name, + api_key=rerank_request.api_key, + ) + + return RerankResponse(scores=sim_scores) + elif rerank_request.provider_type == RerankerProvider.COHERE: if rerank_request.api_key is None: raise RuntimeError("Cohere Rerank Requires an API Key") diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 9427335c47d..5b9d57b9d35 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -1,4 +1,4 @@ -aiohttp==3.9.4 +aiohttp==3.10.2 alembic==1.10.4 asyncpg==0.27.0 atlassian-python-api==3.37.0 @@ -12,7 +12,7 @@ distributed==2023.8.1 fastapi==0.109.2 fastapi-users==12.1.3 fastapi-users-db-sqlalchemy==5.0.0 -filelock==3.12.0 +filelock==3.15.4 google-api-python-client==2.86.0 
google-auth-httplib2==0.1.0 google-auth-oauthlib==1.0.0 @@ -26,13 +26,12 @@ huggingface-hub==0.20.1 jira==3.5.1 jsonref==1.1.0 langchain==0.1.17 -langchain-community==0.0.36 langchain-core==0.1.50 langchain-text-splitters==0.0.1 litellm==1.43.18 llama-index==0.9.45 Mako==1.2.4 -msal==1.26.0 +msal==1.28.0 nltk==3.8.1 Office365-REST-Python-Client==2.5.9 oauthlib==3.2.2 @@ -50,10 +49,11 @@ python-pptx==0.6.23 pypdf==3.17.0 pytest-mock==3.12.0 pytest-playwright==0.3.2 -python-docx==1.1.0 +python-docx==1.1.2 python-dotenv==1.0.0 python-multipart==0.0.7 pywikibot==9.0.0 +redis==5.0.8 requests==2.32.2 requests-oauthlib==1.3.1 retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image diff --git a/backend/requirements/model_server.txt b/backend/requirements/model_server.txt index 0fb0e74b67b..18c2cefed28 100644 --- a/backend/requirements/model_server.txt +++ b/backend/requirements/model_server.txt @@ -8,7 +8,7 @@ pydantic==2.8.2 retry==0.9.2 safetensors==0.4.2 sentence-transformers==2.6.1 -torch==2.0.1 +torch==2.2.0 transformers==4.39.2 uvicorn==0.21.1 voyageai==0.2.3 diff --git a/backend/scripts/force_delete_connector_by_id.py b/backend/scripts/force_delete_connector_by_id.py index 118a4dfa4b4..0a9857304c8 100755 --- a/backend/scripts/force_delete_connector_by_id.py +++ b/backend/scripts/force_delete_connector_by_id.py @@ -83,8 +83,7 @@ def _unsafe_deletion( # Delete index attempts delete_index_attempts( db_session=db_session, - connector_id=connector_id, - credential_id=credential_id, + cc_pair_id=cc_pair.id, ) # Delete document sets diff --git a/backend/scripts/restart_containers.sh b/backend/scripts/restart_containers.sh index c60d1905eb5..838df5b5c79 100755 --- a/backend/scripts/restart_containers.sh +++ b/backend/scripts/restart_containers.sh @@ -1,15 +1,16 @@ #!/bin/bash # Usage of the script with optional volume arguments -# ./restart_containers.sh [vespa_volume] [postgres_volume] +# ./restart_containers.sh [vespa_volume] [postgres_volume] [redis_volume] VESPA_VOLUME=${1:-""} # Default is empty if not provided POSTGRES_VOLUME=${2:-""} # Default is empty if not provided +REDIS_VOLUME=${3:-""} # Default is empty if not provided # Stop and remove the existing containers echo "Stopping and removing existing containers..." -docker stop danswer_postgres danswer_vespa -docker rm danswer_postgres danswer_vespa +docker stop danswer_postgres danswer_vespa danswer_redis +docker rm danswer_postgres danswer_vespa danswer_redis # Start the PostgreSQL container with optional volume echo "Starting PostgreSQL container..." @@ -27,6 +28,14 @@ else docker run --detach --name danswer_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8 fi +# Start the Redis container with optional volume +echo "Starting Redis container..." 
+if [[ -n "$REDIS_VOLUME" ]]; then + docker run --detach --name danswer_redis --publish 6379:6379 -v $REDIS_VOLUME:/data redis +else + docker run --detach --name danswer_redis --publish 6379:6379 redis +fi + # Ensure alembic runs in the correct directory SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" PARENT_DIR="$(dirname "$SCRIPT_DIR")" diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index 5ad36cc93c4..fe933227009 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -16,9 +16,12 @@ ) # Danswer custom Deep Learning Models +CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model" +CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0" INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier" INTENT_MODEL_TAG = "v1.0.3" + # Bi-Encoder, other details DOC_EMBEDDING_CONTEXT_SIZE = 512 @@ -58,9 +61,11 @@ # Fields which should only be set on new search setting PRESERVED_SEARCH_FIELDS = [ + "id", "provider_type", "api_key", "model_name", + "api_url", "index_name", "multipass_indexing", "model_dim", @@ -68,3 +73,5 @@ "passage_prefix", "query_prefix", ] + +CUSTOM_REFRESH_URL = os.environ.get("CUSTOM_REFRESH_URL") or "/settings/refresh-token" diff --git a/backend/shared_configs/enums.py b/backend/shared_configs/enums.py index 918872d44b3..b58ac0a8928 100644 --- a/backend/shared_configs/enums.py +++ b/backend/shared_configs/enums.py @@ -6,10 +6,12 @@ class EmbeddingProvider(str, Enum): COHERE = "cohere" VOYAGE = "voyage" GOOGLE = "google" + LITELLM = "litellm" class RerankerProvider(str, Enum): COHERE = "cohere" + LITELLM = "litellm" class EmbedTextType(str, Enum): diff --git a/backend/shared_configs/model_server_models.py b/backend/shared_configs/model_server_models.py index 3014616c620..dd846ed6bad 100644 --- a/backend/shared_configs/model_server_models.py +++ b/backend/shared_configs/model_server_models.py @@ -7,6 +7,15 @@ Embedding = list[float] +class ConnectorClassificationRequest(BaseModel): + available_connectors: list[str] + query: str + + +class ConnectorClassificationResponse(BaseModel): + connectors: list[str] + + class EmbedRequest(BaseModel): texts: list[str] # Can be none for cloud embedding model requests, error handling logic exists for other cases @@ -18,6 +27,7 @@ class EmbedRequest(BaseModel): text_type: EmbedTextType manual_query_prefix: str | None = None manual_passage_prefix: str | None = None + api_url: str | None = None # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} @@ -33,6 +43,7 @@ class RerankRequest(BaseModel): model_name: str provider_type: RerankerProvider | None = None api_key: str | None = None + api_url: str | None = None # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} diff --git a/backend/tests/daily/connectors/confluence/test_confluence_basic.py b/backend/tests/daily/connectors/confluence/test_confluence_basic.py index 7f05242c50b..4eb25207814 100644 --- a/backend/tests/daily/connectors/confluence/test_confluence_basic.py +++ b/backend/tests/daily/connectors/confluence/test_confluence_basic.py @@ -8,7 +8,13 @@ @pytest.fixture def confluence_connector() -> ConfluenceConnector: - connector = ConfluenceConnector(os.environ["CONFLUENCE_TEST_SPACE_URL"]) + connector = ConfluenceConnector( + wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"], + space=os.environ["CONFLUENCE_TEST_SPACE"], + is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == 
"true", + page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""), + ) + connector.load_credentials( { "confluence_username": os.environ["CONFLUENCE_USER_NAME"], diff --git a/backend/tests/daily/embedding/test_embeddings.py b/backend/tests/daily/embedding/test_embeddings.py index a9c12b236cf..b736f374741 100644 --- a/backend/tests/daily/embedding/test_embeddings.py +++ b/backend/tests/daily/embedding/test_embeddings.py @@ -32,6 +32,7 @@ def openai_embedding_model() -> EmbeddingModel: passage_prefix=None, api_key=os.getenv("OPENAI_API_KEY"), provider_type=EmbeddingProvider.OPENAI, + api_url=None, ) @@ -51,6 +52,7 @@ def cohere_embedding_model() -> EmbeddingModel: passage_prefix=None, api_key=os.getenv("COHERE_API_KEY"), provider_type=EmbeddingProvider.COHERE, + api_url=None, ) @@ -70,6 +72,7 @@ def local_nomic_embedding_model() -> EmbeddingModel: passage_prefix="search_document: ", api_key=None, provider_type=None, + api_url=None, ) diff --git a/backend/tests/integration/README.md b/backend/tests/integration/README.md new file mode 100644 index 00000000000..bc5e388082f --- /dev/null +++ b/backend/tests/integration/README.md @@ -0,0 +1,70 @@ +# Integration Tests + +## General Testing Overview +The integration tests are designed with a "manager" class and a "test" class for each type of object being manipulated (e.g., user, persona, credential): +- **Manager Class**: Contains methods for each type of API call. Responsible for creating, deleting, and verifying the existence of an entity. +- **Test Class**: Stores data for each entity being tested. This is our "expected state" of the object. + +The idea is that each test can use the manager class to create (.create()) a "test_" object. It can then perform an operation on the object (e.g., send a request to the API) and then check if the "test_" object is in the expected state by using the manager class (.verify()) function. + +## Instructions for Running Integration Tests Locally +1. Launch danswer (using Docker or running with a debugger), ensuring the API server is running on port 8080. + a. If you'd like to set environment variables, you can do so by creating a `.env` file in the danswer/backend/tests/integration/ directory. +2. Navigate to `danswer/backend`. +3. Run the following command in the terminal: + ```sh + pytest -s tests/integration/tests/ + ``` + or to run all tests in a file: + ```sh + pytest -s tests/integration/tests/path_to/test_file.py + ``` + or to run a single test: + ```sh + pytest -s tests/integration/tests/path_to/test_file.py::test_function_name + ``` + +## Guidelines for Writing Integration Tests +- As authentication is currently required for all tests, each test should start by creating a user. +- Each test should ideally focus on a single API flow. +- The test writer should try to consider failure cases and edge cases for the flow and write the tests to check for these cases. +- Every step of the test should be commented describing what is being done and what the expected behavior is. +- A summary of the test should be given at the top of the test function as well! +- When writing new tests, manager classes, manager functions, and test classes, try to copy the style of the other ones that have already been written. +- Be careful for scope creep! + - No need to overcomplicate every test by verifying after every single API call so long as the case you would be verifying is covered elsewhere (ideally in a test focused on covering that case). 
+ - An example of this is: Creating an admin user is done at the beginning of nearly every test, but we only need to verify that the user is actually an admin in the test focused on checking admin permissions. For every other test, we can just create the admin user and assume that the permissions are working as expected. + +## Current Testing Limitations +### Test coverage +- Coverage in all of the tests is probably not as high as it could be. +- The "connector" tests in particular are super bare-bones because we will be reworking connector/cc_pair sometime soon. +- The Global Curator role is not thoroughly tested. +- The no-auth setup is not tested at all. +### Failure checking +- While we test expected auth failures, we only check that the request failed at all. +- We don't check that the return codes are what we expect. +- This means that a test could be failing for a different reason than expected. +- We should ensure that the proper codes are being returned for each failure case. +- We should also query the db after each failure to ensure that the db is in the expected state. +### Scope/focus +- The tests may be scoped sub-optimally. +- The scopes of individual tests may overlap. + +## Current Testing Coverage +The current testing coverage should be checked by reading the comments at the top of each test file. + + +## TODO: Testing Coverage +- Persona permissions testing +- Read only (and/or basic) user permissions + - Ensuring proper permission enforcement using the chat/doc_search endpoints +- No auth + +## Ideas for integration testing design +### Combine the "test" and "manager" classes +This could make test writing a bit cleaner by sparing test writers from having to pass objects into functions that those objects have a 1:1 relationship with. + +### Rework VespaClient +Right now, it's used as a fixture and has to be passed around between manager classes.
+Could just be built where its used diff --git a/backend/tests/integration/common_utils/connectors.py b/backend/tests/integration/common_utils/connectors.py deleted file mode 100644 index e7734cec3c8..00000000000 --- a/backend/tests/integration/common_utils/connectors.py +++ /dev/null @@ -1,114 +0,0 @@ -import uuid -from typing import cast - -import requests -from pydantic import BaseModel - -from danswer.configs.constants import DocumentSource -from danswer.db.enums import ConnectorCredentialPairStatus -from tests.integration.common_utils.constants import API_SERVER_URL - - -class ConnectorCreationDetails(BaseModel): - connector_id: int - credential_id: int - cc_pair_id: int - - -class ConnectorClient: - @staticmethod - def create_connector( - name_prefix: str = "test_connector", credential_id: int | None = None - ) -> ConnectorCreationDetails: - unique_id = uuid.uuid4() - - connector_name = f"{name_prefix}_{unique_id}" - connector_data = { - "name": connector_name, - "source": DocumentSource.NOT_APPLICABLE, - "input_type": "load_state", - "connector_specific_config": {}, - "refresh_freq": 60, - "disabled": True, - } - response = requests.post( - f"{API_SERVER_URL}/manage/admin/connector", - json=connector_data, - ) - response.raise_for_status() - connector_id = response.json()["id"] - - # associate the credential with the connector - if not credential_id: - print("ID not specified, creating new credential") - # Create a new credential - credential_data = { - "credential_json": {}, - "admin_public": True, - "source": DocumentSource.NOT_APPLICABLE, - } - response = requests.post( - f"{API_SERVER_URL}/manage/credential", - json=credential_data, - ) - response.raise_for_status() - credential_id = cast(int, response.json()["id"]) - - cc_pair_metadata = {"name": f"test_cc_pair_{unique_id}", "is_public": True} - response = requests.put( - f"{API_SERVER_URL}/manage/connector/{connector_id}/credential/{credential_id}", - json=cc_pair_metadata, - ) - response.raise_for_status() - - # fetch the conenector credential pair id using the indexing status API - response = requests.get( - f"{API_SERVER_URL}/manage/admin/connector/indexing-status" - ) - response.raise_for_status() - indexing_statuses = response.json() - - cc_pair_id = None - for status in indexing_statuses: - if ( - status["connector"]["id"] == connector_id - and status["credential"]["id"] == credential_id - ): - cc_pair_id = status["cc_pair_id"] - break - - if cc_pair_id is None: - raise ValueError("Could not find the connector credential pair id") - - print( - f"Created connector with connector_id: {connector_id}, credential_id: {credential_id}, cc_pair_id: {cc_pair_id}" - ) - return ConnectorCreationDetails( - connector_id=int(connector_id), - credential_id=int(credential_id), - cc_pair_id=int(cc_pair_id), - ) - - @staticmethod - def update_connector_status( - cc_pair_id: int, status: ConnectorCredentialPairStatus - ) -> None: - response = requests.put( - f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/status", - json={"status": status}, - ) - response.raise_for_status() - - @staticmethod - def delete_connector(connector_id: int, credential_id: int) -> None: - response = requests.post( - f"{API_SERVER_URL}/manage/admin/deletion-attempt", - json={"connector_id": connector_id, "credential_id": credential_id}, - ) - response.raise_for_status() - - @staticmethod - def get_connectors() -> list[dict]: - response = requests.get(f"{API_SERVER_URL}/manage/connector") - response.raise_for_status() - return response.json() diff --git 
a/backend/tests/integration/common_utils/constants.py b/backend/tests/integration/common_utils/constants.py index efc98dde7de..7d729191cf6 100644 --- a/backend/tests/integration/common_utils/constants.py +++ b/backend/tests/integration/common_utils/constants.py @@ -5,3 +5,7 @@ API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080" API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}" MAX_DELAY = 30 + +GENERAL_HEADERS = {"Content-Type": "application/json"} + +NUM_DOCS = 5 diff --git a/backend/tests/integration/common_utils/document_sets.py b/backend/tests/integration/common_utils/document_sets.py deleted file mode 100644 index dc898611108..00000000000 --- a/backend/tests/integration/common_utils/document_sets.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import cast - -import requests - -from danswer.server.features.document_set.models import DocumentSet -from danswer.server.features.document_set.models import DocumentSetCreationRequest -from tests.integration.common_utils.constants import API_SERVER_URL - - -class DocumentSetClient: - @staticmethod - def create_document_set( - doc_set_creation_request: DocumentSetCreationRequest, - ) -> int: - response = requests.post( - f"{API_SERVER_URL}/manage/admin/document-set", - json=doc_set_creation_request.model_dump(), - ) - response.raise_for_status() - return cast(int, response.json()) - - @staticmethod - def fetch_document_sets() -> list[DocumentSet]: - response = requests.get(f"{API_SERVER_URL}/manage/document-set") - response.raise_for_status() - - document_sets = [ - DocumentSet.parse_obj(doc_set_data) for doc_set_data in response.json() - ] - return document_sets diff --git a/backend/tests/integration/common_utils/llm.py b/backend/tests/integration/common_utils/llm.py index ba8b89d6b4d..f74b40073c9 100644 --- a/backend/tests/integration/common_utils/llm.py +++ b/backend/tests/integration/common_utils/llm.py @@ -1,62 +1,88 @@ import os -from typing import cast +from uuid import uuid4 import requests -from pydantic import BaseModel -from pydantic import PrivateAttr from danswer.server.manage.llm.models import LLMProviderUpsertRequest from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestLLMProvider +from tests.integration.common_utils.test_models import TestUser -class LLMProvider(BaseModel): - provider: str - api_key: str - default_model_name: str - api_base: str | None = None - api_version: str | None = None - is_default: bool = True +class LLMProviderManager: + @staticmethod + def create( + name: str | None = None, + provider: str | None = None, + api_key: str | None = None, + default_model_name: str | None = None, + api_base: str | None = None, + api_version: str | None = None, + groups: list[int] | None = None, + is_public: bool | None = None, + user_performing_action: TestUser | None = None, + ) -> TestLLMProvider: + print("Seeding LLM Providers...") - # only populated after creation - _provider_id: int | None = PrivateAttr() - - def create(self) -> int: llm_provider = LLMProviderUpsertRequest( - name=self.provider, - provider=self.provider, - default_model_name=self.default_model_name, - api_key=self.api_key, - api_base=self.api_base, - api_version=self.api_version, + name=name or f"test-provider-{uuid4()}", + provider=provider or "openai", + default_model_name=default_model_name or "gpt-4o-mini", + api_key=api_key or os.environ["OPENAI_API_KEY"], + 
api_base=api_base, + api_version=api_version, custom_config=None, - fast_default_model_name=None, - is_public=True, - groups=[], + fast_default_model_name=default_model_name or "gpt-4o-mini", + is_public=is_public or True, + groups=groups or [], display_model_names=None, model_names=None, ) - response = requests.put( + llm_response = requests.put( f"{API_SERVER_URL}/admin/llm/provider", - json=llm_provider.dict(), + json=llm_provider.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + llm_response.raise_for_status() + response_data = llm_response.json() + result_llm = TestLLMProvider( + id=response_data["id"], + name=response_data["name"], + provider=response_data["provider"], + api_key=response_data["api_key"], + default_model_name=response_data["default_model_name"], + is_public=response_data["is_public"], + groups=response_data["groups"], + api_base=response_data["api_base"], + api_version=response_data["api_version"], ) - response.raise_for_status() - self._provider_id = cast(int, response.json()["id"]) - return self._provider_id + set_default_response = requests.post( + f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + set_default_response.raise_for_status() - def delete(self) -> None: + return result_llm + + @staticmethod + def delete( + llm_provider: TestLLMProvider, + user_performing_action: TestUser | None = None, + ) -> bool: + if not llm_provider.id: + raise ValueError("LLM Provider ID is required to delete a provider") response = requests.delete( - f"{API_SERVER_URL}/admin/llm/provider/{self._provider_id}" + f"{API_SERVER_URL}/admin/llm/provider/{llm_provider.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, ) response.raise_for_status() - - -def seed_default_openai_provider() -> LLMProvider: - llm = LLMProvider( - provider="openai", - default_model_name="gpt-4o-mini", - api_key=os.environ["OPENAI_API_KEY"], - ) - llm.create() - return llm + return True diff --git a/backend/tests/integration/common_utils/managers/api_key.py b/backend/tests/integration/common_utils/managers/api_key.py new file mode 100644 index 00000000000..b6d2c29b732 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/api_key.py @@ -0,0 +1,92 @@ +from uuid import uuid4 + +import requests + +from danswer.db.models import UserRole +from ee.danswer.server.api_key.models import APIKeyArgs +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestAPIKey +from tests.integration.common_utils.test_models import TestUser + + +class APIKeyManager: + @staticmethod + def create( + name: str | None = None, + api_key_role: UserRole = UserRole.ADMIN, + user_performing_action: TestUser | None = None, + ) -> TestAPIKey: + name = f"{name}-api-key" if name else f"test-api-key-{uuid4()}" + api_key_request = APIKeyArgs( + name=name, + role=api_key_role, + ) + api_key_response = requests.post( + f"{API_SERVER_URL}/admin/api-key", + json=api_key_request.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + api_key_response.raise_for_status() + api_key = api_key_response.json() + result_api_key = TestAPIKey( + api_key_id=api_key["api_key_id"], + api_key_display=api_key["api_key_display"], + 
api_key=api_key["api_key"], + api_key_name=name, + api_key_role=api_key_role, + user_id=api_key["user_id"], + headers=GENERAL_HEADERS, + ) + result_api_key.headers["Authorization"] = f"Bearer {result_api_key.api_key}" + return result_api_key + + @staticmethod + def delete( + api_key: TestAPIKey, + user_performing_action: TestUser | None = None, + ) -> None: + api_key_response = requests.delete( + f"{API_SERVER_URL}/admin/api-key/{api_key.api_key_id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + api_key_response.raise_for_status() + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[TestAPIKey]: + api_key_response = requests.get( + f"{API_SERVER_URL}/admin/api-key", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + api_key_response.raise_for_status() + return [TestAPIKey(**api_key) for api_key in api_key_response.json()] + + @staticmethod + def verify( + api_key: TestAPIKey, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + retrieved_keys = APIKeyManager.get_all( + user_performing_action=user_performing_action + ) + for key in retrieved_keys: + if key.api_key_id == api_key.api_key_id: + if verify_deleted: + raise ValueError("API Key found when it should have been deleted") + if ( + key.api_key_name == api_key.api_key_name + and key.api_key_role == api_key.api_key_role + ): + return + + if not verify_deleted: + raise Exception("API Key not found") diff --git a/backend/tests/integration/common_utils/managers/cc_pair.py b/backend/tests/integration/common_utils/managers/cc_pair.py new file mode 100644 index 00000000000..6498252bbe8 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/cc_pair.py @@ -0,0 +1,202 @@ +import time +from typing import Any +from uuid import uuid4 + +import requests + +from danswer.connectors.models import InputType +from danswer.db.enums import ConnectorCredentialPairStatus +from danswer.server.documents.models import ConnectorCredentialPairIdentifier +from danswer.server.documents.models import ConnectorIndexingStatus +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.constants import MAX_DELAY +from tests.integration.common_utils.managers.connector import ConnectorManager +from tests.integration.common_utils.managers.credential import CredentialManager +from tests.integration.common_utils.test_models import TestCCPair +from tests.integration.common_utils.test_models import TestUser + + +def _cc_pair_creator( + connector_id: int, + credential_id: int, + name: str | None = None, + is_public: bool = True, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, +) -> TestCCPair: + name = f"{name}-cc-pair" if name else f"test-cc-pair-{uuid4()}" + + request = { + "name": name, + "is_public": is_public, + "groups": groups or [], + } + + response = requests.put( + url=f"{API_SERVER_URL}/manage/connector/{connector_id}/credential/{credential_id}", + json=request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return TestCCPair( + id=response.json()["data"], + name=name, + connector_id=connector_id, + credential_id=credential_id, + is_public=is_public, + groups=groups or [], + ) + + 
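+# CCPairManager below follows the "manager" pattern described in tests/integration/README.md: +# its static helpers wrap the connector/credential-pair endpoints (create a pair from scratch or +# from existing connector and credential IDs, pause it, request deletion, list indexing statuses), +# and verify() compares a TestCCPair against the state reported by the API.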
+class CCPairManager: + @staticmethod + def create_from_scratch( + name: str | None = None, + is_public: bool = True, + groups: list[int] | None = None, + source: DocumentSource = DocumentSource.FILE, + input_type: InputType = InputType.LOAD_STATE, + connector_specific_config: dict[str, Any] | None = None, + credential_json: dict[str, Any] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestCCPair: + connector = ConnectorManager.create( + name=name, + source=source, + input_type=input_type, + connector_specific_config=connector_specific_config, + is_public=is_public, + groups=groups, + user_performing_action=user_performing_action, + ) + credential = CredentialManager.create( + credential_json=credential_json, + name=name, + source=source, + curator_public=is_public, + groups=groups, + user_performing_action=user_performing_action, + ) + return _cc_pair_creator( + connector_id=connector.id, + credential_id=credential.id, + name=name, + is_public=is_public, + groups=groups, + user_performing_action=user_performing_action, + ) + + @staticmethod + def create( + connector_id: int, + credential_id: int, + name: str | None = None, + is_public: bool = True, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestCCPair: + return _cc_pair_creator( + connector_id=connector_id, + credential_id=credential_id, + name=name, + is_public=is_public, + groups=groups, + user_performing_action=user_performing_action, + ) + + @staticmethod + def pause_cc_pair( + cc_pair: TestCCPair, + user_performing_action: TestUser | None = None, + ) -> None: + result = requests.put( + url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status", + json={"status": "PAUSED"}, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + result.raise_for_status() + + @staticmethod + def delete( + cc_pair: TestCCPair, + user_performing_action: TestUser | None = None, + ) -> None: + cc_pair_identifier = ConnectorCredentialPairIdentifier( + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + ) + result = requests.post( + url=f"{API_SERVER_URL}/manage/admin/deletion-attempt", + json=cc_pair_identifier.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + result.raise_for_status() + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[ConnectorIndexingStatus]: + response = requests.get( + f"{API_SERVER_URL}/manage/admin/connector/indexing-status", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [ConnectorIndexingStatus(**cc_pair) for cc_pair in response.json()] + + @staticmethod + def verify( + cc_pair: TestCCPair, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + all_cc_pairs = CCPairManager.get_all(user_performing_action) + for retrieved_cc_pair in all_cc_pairs: + if retrieved_cc_pair.cc_pair_id == cc_pair.id: + if verify_deleted: + # We assume that this check will be performed after the deletion is + # already waited for + raise ValueError( + f"CC pair {cc_pair.id} found but should be deleted" + ) + if ( + retrieved_cc_pair.name == cc_pair.name + and retrieved_cc_pair.connector.id == cc_pair.connector_id + and retrieved_cc_pair.credential.id == cc_pair.credential_id + and retrieved_cc_pair.public_doc == cc_pair.is_public + and set(retrieved_cc_pair.groups) == 
set(cc_pair.groups) + ): + return + + if not verify_deleted: + raise ValueError(f"CC pair {cc_pair.id} not found") + + @staticmethod + def wait_for_deletion_completion( + user_performing_action: TestUser | None = None, + ) -> None: + start = time.time() + while True: + cc_pairs = CCPairManager.get_all(user_performing_action) + if all( + cc_pair.cc_pair_status != ConnectorCredentialPairStatus.DELETING + for cc_pair in cc_pairs + ): + return + + if time.time() - start > MAX_DELAY: + raise TimeoutError( + f"CC pairs deletion was not completed within the {MAX_DELAY} seconds" + ) + else: + print("Some CC pairs are still being deleted, waiting...") + time.sleep(2) diff --git a/backend/tests/integration/common_utils/managers/connector.py b/backend/tests/integration/common_utils/managers/connector.py new file mode 100644 index 00000000000..f72d079683b --- /dev/null +++ b/backend/tests/integration/common_utils/managers/connector.py @@ -0,0 +1,124 @@ +from typing import Any +from uuid import uuid4 + +import requests + +from danswer.connectors.models import InputType +from danswer.server.documents.models import ConnectorUpdateRequest +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestConnector +from tests.integration.common_utils.test_models import TestUser + + +class ConnectorManager: + @staticmethod + def create( + name: str | None = None, + source: DocumentSource = DocumentSource.FILE, + input_type: InputType = InputType.LOAD_STATE, + connector_specific_config: dict[str, Any] | None = None, + is_public: bool = True, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestConnector: + name = f"{name}-connector" if name else f"test-connector-{uuid4()}" + + connector_update_request = ConnectorUpdateRequest( + name=name, + source=source, + input_type=input_type, + connector_specific_config=connector_specific_config or {}, + is_public=is_public, + groups=groups or [], + ) + + response = requests.post( + url=f"{API_SERVER_URL}/manage/admin/connector", + json=connector_update_request.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + response_data = response.json() + return TestConnector( + id=response_data.get("id"), + name=name, + source=source, + input_type=input_type, + connector_specific_config=connector_specific_config or {}, + groups=groups, + is_public=is_public, + ) + + @staticmethod + def edit( + connector: TestConnector, + user_performing_action: TestUser | None = None, + ) -> None: + response = requests.patch( + url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}", + json=connector.model_dump(exclude={"id"}), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def delete( + connector: TestConnector, + user_performing_action: TestUser | None = None, + ) -> None: + response = requests.delete( + url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[TestConnector]: + response = requests.get( + 
url=f"{API_SERVER_URL}/manage/connector", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [ + TestConnector( + id=conn.get("id"), + name=conn.get("name", ""), + source=conn.get("source", DocumentSource.FILE), + input_type=conn.get("input_type", InputType.LOAD_STATE), + connector_specific_config=conn.get("connector_specific_config", {}), + ) + for conn in response.json() + ] + + @staticmethod + def get( + connector_id: int, user_performing_action: TestUser | None = None + ) -> TestConnector: + response = requests.get( + url=f"{API_SERVER_URL}/manage/connector/{connector_id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + conn = response.json() + return TestConnector( + id=conn.get("id"), + name=conn.get("name", ""), + source=conn.get("source", DocumentSource.FILE), + input_type=conn.get("input_type", InputType.LOAD_STATE), + connector_specific_config=conn.get("connector_specific_config", {}), + ) diff --git a/backend/tests/integration/common_utils/managers/credential.py b/backend/tests/integration/common_utils/managers/credential.py new file mode 100644 index 00000000000..c05cd1b5a3e --- /dev/null +++ b/backend/tests/integration/common_utils/managers/credential.py @@ -0,0 +1,129 @@ +from typing import Any +from uuid import uuid4 + +import requests + +from danswer.server.documents.models import CredentialSnapshot +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestCredential +from tests.integration.common_utils.test_models import TestUser + + +class CredentialManager: + @staticmethod + def create( + credential_json: dict[str, Any] | None = None, + admin_public: bool = True, + name: str | None = None, + source: DocumentSource = DocumentSource.FILE, + curator_public: bool = True, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestCredential: + name = f"{name}-credential" if name else f"test-credential-{uuid4()}" + + credential_request = { + "name": name, + "credential_json": credential_json or {}, + "admin_public": admin_public, + "source": source, + "curator_public": curator_public, + "groups": groups or [], + } + response = requests.post( + url=f"{API_SERVER_URL}/manage/credential", + json=credential_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + + response.raise_for_status() + return TestCredential( + id=response.json()["id"], + name=name, + credential_json=credential_json or {}, + admin_public=admin_public, + source=source, + curator_public=curator_public, + groups=groups or [], + ) + + @staticmethod + def edit( + credential: TestCredential, + user_performing_action: TestUser | None = None, + ) -> None: + request = credential.model_dump(include={"name", "credential_json"}) + response = requests.put( + url=f"{API_SERVER_URL}/manage/admin/credential/{credential.id}", + json=request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def delete( + credential: TestCredential, + user_performing_action: TestUser | None = None, + ) -> None: + response = requests.delete( + 
url=f"{API_SERVER_URL}/manage/credential/{credential.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def get( + credential_id: int, user_performing_action: TestUser | None = None + ) -> CredentialSnapshot: + response = requests.get( + url=f"{API_SERVER_URL}/manage/credential/{credential_id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return CredentialSnapshot(**response.json()) + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[CredentialSnapshot]: + response = requests.get( + f"{API_SERVER_URL}/manage/credential", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [CredentialSnapshot(**cred) for cred in response.json()] + + @staticmethod + def verify( + credential: TestCredential, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + all_credentials = CredentialManager.get_all(user_performing_action) + for fetched_credential in all_credentials: + if credential.id == fetched_credential.id: + if verify_deleted: + raise ValueError( + f"Credential {credential.id} found but should be deleted" + ) + if ( + credential.name == fetched_credential.name + and credential.admin_public == fetched_credential.admin_public + and credential.source == fetched_credential.source + and credential.curator_public == fetched_credential.curator_public + ): + return + if not verify_deleted: + raise ValueError(f"Credential {credential.id} not found") diff --git a/backend/tests/integration/common_utils/managers/document.py b/backend/tests/integration/common_utils/managers/document.py new file mode 100644 index 00000000000..3f691eca8f9 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/document.py @@ -0,0 +1,153 @@ +from uuid import uuid4 + +import requests + +from danswer.configs.constants import DocumentSource +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.constants import NUM_DOCS +from tests.integration.common_utils.managers.api_key import TestAPIKey +from tests.integration.common_utils.managers.cc_pair import TestCCPair +from tests.integration.common_utils.test_models import SimpleTestDocument +from tests.integration.common_utils.test_models import TestUser +from tests.integration.common_utils.vespa import TestVespaClient + + +def _verify_document_permissions( + retrieved_doc: dict, + cc_pair: TestCCPair, + doc_set_names: list[str] | None = None, + group_names: list[str] | None = None, + doc_creating_user: TestUser | None = None, +) -> None: + acl_keys = set(retrieved_doc["access_control_list"].keys()) + print(f"ACL keys: {acl_keys}") + if cc_pair.is_public: + if "PUBLIC" not in acl_keys: + raise ValueError( + f"Document {retrieved_doc['document_id']} is public but" + " does not have the PUBLIC ACL key" + ) + + if doc_creating_user is not None: + if f"user_id:{doc_creating_user.id}" not in acl_keys: + raise ValueError( + f"Document {retrieved_doc['document_id']} was created by user" + f" {doc_creating_user.id} but does not have the user_id:{doc_creating_user.id} ACL key" + ) + + if group_names is not None: + expected_group_keys = {f"group:{group_name}" for group_name in group_names} + found_group_keys = {key for 
key in acl_keys if key.startswith("group:")} + if found_group_keys != expected_group_keys: + raise ValueError( + f"Document {retrieved_doc['document_id']} has incorrect group ACL keys. Found: {found_group_keys}, \n" + f"Expected: {expected_group_keys}" + ) + + if doc_set_names is not None: + found_doc_set_names = set(retrieved_doc.get("document_sets", {}).keys()) + if found_doc_set_names != set(doc_set_names): + raise ValueError( + f"Document set names mismatch. \nFound: {found_doc_set_names}, \n" + f"Expected: {set(doc_set_names)}" + ) + + +def _generate_dummy_document(document_id: str, cc_pair_id: int) -> dict: + return { + "document": { + "id": document_id, + "sections": [ + { + "text": f"This is test document {document_id}", + "link": f"{document_id}", + } + ], + "source": DocumentSource.NOT_APPLICABLE, + # just for testing metadata + "metadata": {"document_id": document_id}, + "semantic_identifier": f"Test Document {document_id}", + "from_ingestion_api": True, + }, + "cc_pair_id": cc_pair_id, + } + + +class DocumentManager: + @staticmethod + def seed_and_attach_docs( + cc_pair: TestCCPair, + num_docs: int = NUM_DOCS, + document_ids: list[str] | None = None, + api_key: TestAPIKey | None = None, + ) -> TestCCPair: + # Use provided document_ids if available, otherwise generate random UUIDs + if document_ids is None: + document_ids = [f"test-doc-{uuid4()}" for _ in range(num_docs)] + else: + num_docs = len(document_ids) + # Create and ingest some documents + documents: list[dict] = [] + for document_id in document_ids: + document = _generate_dummy_document(document_id, cc_pair.id) + documents.append(document) + response = requests.post( + f"{API_SERVER_URL}/danswer-api/ingestion", + json=document, + headers=api_key.headers if api_key else GENERAL_HEADERS, + ) + response.raise_for_status() + + print("Seeding completed successfully.") + cc_pair.documents = [ + SimpleTestDocument( + id=document["document"]["id"], + content=document["document"]["sections"][0]["text"], + ) + for document in documents + ] + return cc_pair + + @staticmethod + def verify( + vespa_client: TestVespaClient, + cc_pair: TestCCPair, + # If None, will not check doc sets or groups + # If empty list, will check for empty doc sets or groups + doc_set_names: list[str] | None = None, + group_names: list[str] | None = None, + doc_creating_user: TestUser | None = None, + verify_deleted: bool = False, + ) -> None: + doc_ids = [document.id for document in cc_pair.documents] + retrieved_docs_dict = vespa_client.get_documents_by_id(doc_ids)["documents"] + retrieved_docs = { + doc["fields"]["document_id"]: doc["fields"] for doc in retrieved_docs_dict + } + # Left this here for debugging purposes. 
+ # import json + # for doc in retrieved_docs.values(): + # printable_doc = doc.copy() + # print(printable_doc.keys()) + # printable_doc.pop("embeddings") + # printable_doc.pop("title_embedding") + # print(json.dumps(printable_doc, indent=2)) + + for document in cc_pair.documents: + retrieved_doc = retrieved_docs.get(document.id) + if not retrieved_doc: + if not verify_deleted: + raise ValueError(f"Document not found: {document.id}") + continue + if verify_deleted: + raise ValueError( + f"Document found when it should be deleted: {document.id}" + ) + _verify_document_permissions( + retrieved_doc, + cc_pair, + doc_set_names, + group_names, + doc_creating_user, + ) diff --git a/backend/tests/integration/common_utils/managers/document_set.py b/backend/tests/integration/common_utils/managers/document_set.py new file mode 100644 index 00000000000..8133ccc8712 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/document_set.py @@ -0,0 +1,171 @@ +import time +from uuid import uuid4 + +import requests + +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.constants import MAX_DELAY +from tests.integration.common_utils.test_models import TestDocumentSet +from tests.integration.common_utils.test_models import TestUser + + +class DocumentSetManager: + @staticmethod + def create( + name: str | None = None, + description: str | None = None, + cc_pair_ids: list[int] | None = None, + is_public: bool = True, + users: list[str] | None = None, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestDocumentSet: + if name is None: + name = f"test_doc_set_{str(uuid4())}" + + doc_set_creation_request = { + "name": name, + "description": description or name, + "cc_pair_ids": cc_pair_ids or [], + "is_public": is_public, + "users": users or [], + "groups": groups or [], + } + + response = requests.post( + f"{API_SERVER_URL}/manage/admin/document-set", + json=doc_set_creation_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + return TestDocumentSet( + id=int(response.json()), + name=name, + description=description or name, + cc_pair_ids=cc_pair_ids or [], + is_public=is_public, + is_up_to_date=True, + users=users or [], + groups=groups or [], + ) + + @staticmethod + def edit( + document_set: TestDocumentSet, + user_performing_action: TestUser | None = None, + ) -> bool: + doc_set_update_request = { + "id": document_set.id, + "description": document_set.description, + "cc_pair_ids": document_set.cc_pair_ids, + "is_public": document_set.is_public, + "users": document_set.users, + "groups": document_set.groups, + } + response = requests.patch( + f"{API_SERVER_URL}/manage/admin/document-set", + json=doc_set_update_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return True + + @staticmethod + def delete( + document_set: TestDocumentSet, + user_performing_action: TestUser | None = None, + ) -> bool: + response = requests.delete( + f"{API_SERVER_URL}/manage/admin/document-set/{document_set.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return True + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[TestDocumentSet]: + response = 
requests.get( + f"{API_SERVER_URL}/manage/document-set", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [ + TestDocumentSet( + id=doc_set["id"], + name=doc_set["name"], + description=doc_set["description"], + cc_pair_ids=[ + cc_pair["id"] for cc_pair in doc_set["cc_pair_descriptors"] + ], + is_public=doc_set["is_public"], + is_up_to_date=doc_set["is_up_to_date"], + users=doc_set["users"], + groups=doc_set["groups"], + ) + for doc_set in response.json() + ] + + @staticmethod + def wait_for_sync( + document_sets_to_check: list[TestDocumentSet] | None = None, + user_performing_action: TestUser | None = None, + ) -> None: + # wait for document sets to be synced + start = time.time() + while True: + doc_sets = DocumentSetManager.get_all(user_performing_action) + if document_sets_to_check: + check_ids = {doc_set.id for doc_set in document_sets_to_check} + doc_set_ids = {doc_set.id for doc_set in doc_sets} + if not check_ids.issubset(doc_set_ids): + raise RuntimeError("Document set not found") + doc_sets = [doc_set for doc_set in doc_sets if doc_set.id in check_ids] + all_up_to_date = all(doc_set.is_up_to_date for doc_set in doc_sets) + + if all_up_to_date: + break + + if time.time() - start > MAX_DELAY: + raise TimeoutError( + f"Document sets were not synced within the {MAX_DELAY} seconds" + ) + else: + print("Document sets were not synced yet, waiting...") + + time.sleep(2) + + @staticmethod + def verify( + document_set: TestDocumentSet, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + doc_sets = DocumentSetManager.get_all(user_performing_action) + for doc_set in doc_sets: + if doc_set.id == document_set.id: + if verify_deleted: + raise ValueError( + f"Document set {document_set.id} found but should have been deleted" + ) + if ( + doc_set.name == document_set.name + and set(doc_set.cc_pair_ids) == set(document_set.cc_pair_ids) + and doc_set.is_public == document_set.is_public + and set(doc_set.users) == set(document_set.users) + and set(doc_set.groups) == set(document_set.groups) + ): + return + if not verify_deleted: + raise ValueError(f"Document set {document_set.id} not found") diff --git a/backend/tests/integration/common_utils/managers/persona.py b/backend/tests/integration/common_utils/managers/persona.py new file mode 100644 index 00000000000..41ff43edb6f --- /dev/null +++ b/backend/tests/integration/common_utils/managers/persona.py @@ -0,0 +1,206 @@ +from uuid import uuid4 + +import requests + +from danswer.search.enums import RecencyBiasSetting +from danswer.server.features.persona.models import PersonaSnapshot +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestPersona +from tests.integration.common_utils.test_models import TestUser + + +class PersonaManager: + @staticmethod + def create( + name: str | None = None, + description: str | None = None, + num_chunks: float = 5, + llm_relevance_filter: bool = True, + is_public: bool = True, + llm_filter_extraction: bool = True, + recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO, + prompt_ids: list[int] | None = None, + document_set_ids: list[int] | None = None, + tool_ids: list[int] | None = None, + llm_model_provider_override: str | None = None, + llm_model_version_override: str | None = None, + users: list[str] | None = None, + groups: list[int] | 
None = None, + user_performing_action: TestUser | None = None, + ) -> TestPersona: + name = name or f"test-persona-{uuid4()}" + description = description or f"Description for {name}" + + persona_creation_request = { + "name": name, + "description": description, + "num_chunks": num_chunks, + "llm_relevance_filter": llm_relevance_filter, + "is_public": is_public, + "llm_filter_extraction": llm_filter_extraction, + "recency_bias": recency_bias, + "prompt_ids": prompt_ids or [], + "document_set_ids": document_set_ids or [], + "tool_ids": tool_ids or [], + "llm_model_provider_override": llm_model_provider_override, + "llm_model_version_override": llm_model_version_override, + "users": users or [], + "groups": groups or [], + } + + response = requests.post( + f"{API_SERVER_URL}/persona", + json=persona_creation_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + persona_data = response.json() + + return TestPersona( + id=persona_data["id"], + name=name, + description=description, + num_chunks=num_chunks, + llm_relevance_filter=llm_relevance_filter, + is_public=is_public, + llm_filter_extraction=llm_filter_extraction, + recency_bias=recency_bias, + prompt_ids=prompt_ids or [], + document_set_ids=document_set_ids or [], + tool_ids=tool_ids or [], + llm_model_provider_override=llm_model_provider_override, + llm_model_version_override=llm_model_version_override, + users=users or [], + groups=groups or [], + ) + + @staticmethod + def edit( + persona: TestPersona, + name: str | None = None, + description: str | None = None, + num_chunks: float | None = None, + llm_relevance_filter: bool | None = None, + is_public: bool | None = None, + llm_filter_extraction: bool | None = None, + recency_bias: RecencyBiasSetting | None = None, + prompt_ids: list[int] | None = None, + document_set_ids: list[int] | None = None, + tool_ids: list[int] | None = None, + llm_model_provider_override: str | None = None, + llm_model_version_override: str | None = None, + users: list[str] | None = None, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestPersona: + persona_update_request = { + "name": name or persona.name, + "description": description or persona.description, + "num_chunks": num_chunks or persona.num_chunks, + "llm_relevance_filter": llm_relevance_filter + or persona.llm_relevance_filter, + "is_public": is_public or persona.is_public, + "llm_filter_extraction": llm_filter_extraction + or persona.llm_filter_extraction, + "recency_bias": recency_bias or persona.recency_bias, + "prompt_ids": prompt_ids or persona.prompt_ids, + "document_set_ids": document_set_ids or persona.document_set_ids, + "tool_ids": tool_ids or persona.tool_ids, + "llm_model_provider_override": llm_model_provider_override + or persona.llm_model_provider_override, + "llm_model_version_override": llm_model_version_override + or persona.llm_model_version_override, + "users": users or persona.users, + "groups": groups or persona.groups, + } + + response = requests.patch( + f"{API_SERVER_URL}/persona/{persona.id}", + json=persona_update_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + updated_persona_data = response.json() + + return TestPersona( + id=updated_persona_data["id"], + name=updated_persona_data["name"], + description=updated_persona_data["description"], + num_chunks=updated_persona_data["num_chunks"], + 
llm_relevance_filter=updated_persona_data["llm_relevance_filter"], + is_public=updated_persona_data["is_public"], + llm_filter_extraction=updated_persona_data["llm_filter_extraction"], + recency_bias=updated_persona_data["recency_bias"], + prompt_ids=updated_persona_data["prompts"], + document_set_ids=updated_persona_data["document_sets"], + tool_ids=updated_persona_data["tools"], + llm_model_provider_override=updated_persona_data[ + "llm_model_provider_override" + ], + llm_model_version_override=updated_persona_data[ + "llm_model_version_override" + ], + users=[user["email"] for user in updated_persona_data["users"]], + groups=updated_persona_data["groups"], + ) + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[PersonaSnapshot]: + response = requests.get( + f"{API_SERVER_URL}/admin/persona", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [PersonaSnapshot(**persona) for persona in response.json()] + + @staticmethod + def verify( + test_persona: TestPersona, + user_performing_action: TestUser | None = None, + ) -> bool: + all_personas = PersonaManager.get_all(user_performing_action) + for persona in all_personas: + if persona.id == test_persona.id: + return ( + persona.name == test_persona.name + and persona.description == test_persona.description + and persona.num_chunks == test_persona.num_chunks + and persona.llm_relevance_filter + == test_persona.llm_relevance_filter + and persona.is_public == test_persona.is_public + and persona.llm_filter_extraction + == test_persona.llm_filter_extraction + and persona.llm_model_provider_override + == test_persona.llm_model_provider_override + and persona.llm_model_version_override + == test_persona.llm_model_version_override + and set(persona.prompts) == set(test_persona.prompt_ids) + and set(persona.document_sets) == set(test_persona.document_set_ids) + and set(persona.tools) == set(test_persona.tool_ids) + and set(user.email for user in persona.users) + == set(test_persona.users) + and set(persona.groups) == set(test_persona.groups) + ) + return False + + @staticmethod + def delete( + persona: TestPersona, + user_performing_action: TestUser | None = None, + ) -> bool: + response = requests.delete( + f"{API_SERVER_URL}/persona/{persona.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + return response.ok diff --git a/backend/tests/integration/common_utils/managers/user.py b/backend/tests/integration/common_utils/managers/user.py new file mode 100644 index 00000000000..0946b8b1fca --- /dev/null +++ b/backend/tests/integration/common_utils/managers/user.py @@ -0,0 +1,122 @@ +from copy import deepcopy +from urllib.parse import urlencode +from uuid import uuid4 + +import requests + +from danswer.db.models import UserRole +from danswer.server.manage.models import AllUsersResponse +from danswer.server.models import FullUserSnapshot +from danswer.server.models import InvitedUserSnapshot +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestUser + + +class UserManager: + @staticmethod + def create( + name: str | None = None, + ) -> TestUser: + if name is None: + name = f"test{str(uuid4())}" + + email = f"{name}@test.com" + password = "test" + + body = { + "email": email, + "username": email, + "password": password, + } + response 
= requests.post( + url=f"{API_SERVER_URL}/auth/register", + json=body, + headers=GENERAL_HEADERS, + ) + response.raise_for_status() + + test_user = TestUser( + id=response.json()["id"], + email=email, + password=password, + headers=deepcopy(GENERAL_HEADERS), + ) + print(f"Created user {test_user.email}") + + test_user.headers["Cookie"] = UserManager.login_as_user(test_user) + + return test_user + + @staticmethod + def login_as_user(test_user: TestUser) -> str: + data = urlencode( + { + "username": test_user.email, + "password": test_user.password, + } + ) + headers = test_user.headers.copy() + headers["Content-Type"] = "application/x-www-form-urlencoded" + + response = requests.post( + url=f"{API_SERVER_URL}/auth/login", + data=data, + headers=headers, + ) + response.raise_for_status() + result_cookie = next(iter(response.cookies), None) + + if not result_cookie: + raise Exception("Failed to login") + + print(f"Logged in as {test_user.email}") + return f"{result_cookie.name}={result_cookie.value}" + + @staticmethod + def verify_role(user_to_verify: TestUser, target_role: UserRole) -> bool: + response = requests.get( + url=f"{API_SERVER_URL}/me", + headers=user_to_verify.headers, + ) + response.raise_for_status() + return target_role == UserRole(response.json().get("role", "")) + + @staticmethod + def set_role( + user_to_set: TestUser, + target_role: UserRole, + user_to_perform_action: TestUser | None = None, + ) -> None: + if user_to_perform_action is None: + user_to_perform_action = user_to_set + response = requests.patch( + url=f"{API_SERVER_URL}/manage/set-user-role", + json={"user_email": user_to_set.email, "new_role": target_role.value}, + headers=user_to_perform_action.headers, + ) + response.raise_for_status() + + @staticmethod + def verify(user: TestUser, user_to_perform_action: TestUser | None = None) -> None: + if user_to_perform_action is None: + user_to_perform_action = user + response = requests.get( + url=f"{API_SERVER_URL}/manage/users", + headers=user_to_perform_action.headers + if user_to_perform_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + data = response.json() + all_users = AllUsersResponse( + accepted=[FullUserSnapshot(**user) for user in data["accepted"]], + invited=[InvitedUserSnapshot(**user) for user in data["invited"]], + accepted_pages=data["accepted_pages"], + invited_pages=data["invited_pages"], + ) + for accepted_user in all_users.accepted: + if accepted_user.email == user.email and accepted_user.id == user.id: + return + raise ValueError(f"User {user.email} not found") diff --git a/backend/tests/integration/common_utils/managers/user_group.py b/backend/tests/integration/common_utils/managers/user_group.py new file mode 100644 index 00000000000..5f5ac6b0e30 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/user_group.py @@ -0,0 +1,148 @@ +import time +from uuid import uuid4 + +import requests + +from ee.danswer.server.user_group.models import UserGroup +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.constants import MAX_DELAY +from tests.integration.common_utils.test_models import TestUser +from tests.integration.common_utils.test_models import TestUserGroup + + +class UserGroupManager: + @staticmethod + def create( + name: str | None = None, + user_ids: list[str] | None = None, + cc_pair_ids: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestUserGroup: + 
name = f"{name}-user-group" if name else f"test-user-group-{uuid4()}" + + request = { + "name": name, + "user_ids": user_ids or [], + "cc_pair_ids": cc_pair_ids or [], + } + response = requests.post( + f"{API_SERVER_URL}/manage/admin/user-group", + json=request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + test_user_group = TestUserGroup( + id=response.json()["id"], + name=response.json()["name"], + user_ids=[user["id"] for user in response.json()["users"]], + cc_pair_ids=[cc_pair["id"] for cc_pair in response.json()["cc_pairs"]], + ) + return test_user_group + + @staticmethod + def edit( + user_group: TestUserGroup, + user_performing_action: TestUser | None = None, + ) -> None: + if not user_group.id: + raise ValueError("User group has no ID") + response = requests.patch( + f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}", + json=user_group.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def set_curator_status( + test_user_group: TestUserGroup, + user_to_set_as_curator: TestUser, + is_curator: bool = True, + user_performing_action: TestUser | None = None, + ) -> None: + if not user_to_set_as_curator.id: + raise ValueError("User has no ID") + set_curator_request = { + "user_id": user_to_set_as_curator.id, + "is_curator": is_curator, + } + response = requests.post( + f"{API_SERVER_URL}/manage/admin/user-group/{test_user_group.id}/set-curator", + json=set_curator_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[UserGroup]: + response = requests.get( + f"{API_SERVER_URL}/manage/admin/user-group", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [UserGroup(**ug) for ug in response.json()] + + @staticmethod + def verify( + user_group: TestUserGroup, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + all_user_groups = UserGroupManager.get_all(user_performing_action) + for fetched_user_group in all_user_groups: + if user_group.id == fetched_user_group.id: + if verify_deleted: + raise ValueError( + f"User group {user_group.id} found but should be deleted" + ) + fetched_cc_ids = {cc_pair.id for cc_pair in fetched_user_group.cc_pairs} + fetched_user_ids = {user.id for user in fetched_user_group.users} + user_group_cc_ids = set(user_group.cc_pair_ids) + user_group_user_ids = set(user_group.user_ids) + if ( + fetched_cc_ids == user_group_cc_ids + and fetched_user_ids == user_group_user_ids + ): + return + if not verify_deleted: + raise ValueError(f"User group {user_group.id} not found") + + @staticmethod + def wait_for_sync( + user_groups_to_check: list[TestUserGroup] | None = None, + user_performing_action: TestUser | None = None, + ) -> None: + start = time.time() + while True: + user_groups = UserGroupManager.get_all(user_performing_action) + if user_groups_to_check: + check_ids = {user_group.id for user_group in user_groups_to_check} + user_group_ids = {user_group.id for user_group in user_groups} + if not check_ids.issubset(user_group_ids): + raise RuntimeError("Document set not found") + user_groups = [ + user_group + for user_group in user_groups + if user_group.id in check_ids + ] 
+ if all(ug.is_up_to_date for ug in user_groups): + return + + if time.time() - start > MAX_DELAY: + raise TimeoutError( + f"User groups were not synced within the {MAX_DELAY} seconds" + ) + else: + print("User groups were not synced yet, waiting...") + time.sleep(2) diff --git a/backend/tests/integration/common_utils/reset.py b/backend/tests/integration/common_utils/reset.py index 3815aa9f972..a13ec184b45 100644 --- a/backend/tests/integration/common_utils/reset.py +++ b/backend/tests/integration/common_utils/reset.py @@ -20,7 +20,6 @@ from danswer.indexing.models import IndexingSetting from danswer.main import setup_postgres from danswer.main import setup_vespa -from tests.integration.common_utils.llm import seed_default_openai_provider def _run_migrations( @@ -32,6 +31,7 @@ def _run_migrations( # Create an Alembic configuration object alembic_cfg = Config("alembic.ini") alembic_cfg.set_section_option("logger_alembic", "level", "WARN") + alembic_cfg.attributes["configure_logger"] = False # Set the SQLAlchemy URL in the Alembic configuration alembic_cfg.set_main_option("sqlalchemy.url", database_url) @@ -131,11 +131,13 @@ def reset_vespa() -> None: search_settings = get_current_search_settings(db_session) index_name = search_settings.index_name - setup_vespa( + success = setup_vespa( document_index=VespaIndex(index_name=index_name, secondary_index_name=None), index_setting=IndexingSetting.from_db_model(search_settings), secondary_index_setting=None, ) + if not success: + raise RuntimeError("Could not connect to Vespa within the specified timeout.") for _ in range(5): try: @@ -167,6 +169,4 @@ def reset_all() -> None: reset_postgres() print("Resetting Vespa...") reset_vespa() - print("Seeding LLM Providers...") - seed_default_openai_provider() print("Finished resetting all.") diff --git a/backend/tests/integration/common_utils/seed_documents.py b/backend/tests/integration/common_utils/seed_documents.py deleted file mode 100644 index b6720c9aebe..00000000000 --- a/backend/tests/integration/common_utils/seed_documents.py +++ /dev/null @@ -1,72 +0,0 @@ -import uuid - -import requests -from pydantic import BaseModel - -from danswer.configs.constants import DocumentSource -from tests.integration.common_utils.connectors import ConnectorClient -from tests.integration.common_utils.constants import API_SERVER_URL - - -class SimpleTestDocument(BaseModel): - id: str - content: str - - -class SeedDocumentResponse(BaseModel): - cc_pair_id: int - documents: list[SimpleTestDocument] - - -class TestDocumentClient: - @staticmethod - def seed_documents( - num_docs: int = 5, cc_pair_id: int | None = None - ) -> SeedDocumentResponse: - if not cc_pair_id: - connector_details = ConnectorClient.create_connector() - cc_pair_id = connector_details.cc_pair_id - - # Create and ingest some documents - documents: list[dict] = [] - for _ in range(num_docs): - document_id = f"test-doc-{uuid.uuid4()}" - document = { - "document": { - "id": document_id, - "sections": [ - { - "text": f"This is test document {document_id}", - "link": f"{document_id}", - } - ], - "source": DocumentSource.NOT_APPLICABLE, - # just for testing metadata - "metadata": {"document_id": document_id}, - "semantic_identifier": f"Test Document {document_id}", - "from_ingestion_api": True, - }, - "cc_pair_id": cc_pair_id, - } - documents.append(document) - response = requests.post( - f"{API_SERVER_URL}/danswer-api/ingestion", - json=document, - ) - response.raise_for_status() - - print("Seeding completed successfully.") - return SeedDocumentResponse( - 
cc_pair_id=cc_pair_id, - documents=[ - SimpleTestDocument( - id=document["document"]["id"], - content=document["document"]["sections"][0]["text"], - ) - for document in documents - ], - ) - - -if __name__ == "__main__": - seed_documents_resp = TestDocumentClient.seed_documents() diff --git a/backend/tests/integration/common_utils/test_models.py b/backend/tests/integration/common_utils/test_models.py new file mode 100644 index 00000000000..04db0851e3d --- /dev/null +++ b/backend/tests/integration/common_utils/test_models.py @@ -0,0 +1,120 @@ +from typing import Any +from uuid import UUID + +from pydantic import BaseModel +from pydantic import Field + +from danswer.auth.schemas import UserRole +from danswer.search.enums import RecencyBiasSetting +from danswer.server.documents.models import DocumentSource +from danswer.server.documents.models import InputType + +""" +These data models are used to represent the data on the testing side of things. +This means the flow is: +1. Make request that changes data in db +2. Make a change to the testing model +3. Retrieve data from db +4. Compare db data with testing model to verify +""" + + +class TestAPIKey(BaseModel): + api_key_id: int + api_key_display: str + api_key: str | None = None # only present on initial creation + api_key_name: str | None = None + api_key_role: UserRole + + user_id: UUID + headers: dict + + +class TestUser(BaseModel): + id: str + email: str + password: str + headers: dict + + +class TestCredential(BaseModel): + id: int + name: str + credential_json: dict[str, Any] + admin_public: bool + source: DocumentSource + curator_public: bool + groups: list[int] + + +class TestConnector(BaseModel): + id: int + name: str + source: DocumentSource + input_type: InputType + connector_specific_config: dict[str, Any] + groups: list[int] | None = None + is_public: bool | None = None + + +class SimpleTestDocument(BaseModel): + id: str + content: str + + +class TestCCPair(BaseModel): + id: int + name: str + connector_id: int + credential_id: int + is_public: bool + groups: list[int] + documents: list[SimpleTestDocument] = Field(default_factory=list) + + +class TestUserGroup(BaseModel): + id: int + name: str + user_ids: list[str] + cc_pair_ids: list[int] + + +class TestLLMProvider(BaseModel): + id: int + name: str + provider: str + api_key: str + default_model_name: str + is_public: bool + groups: list[TestUserGroup] + api_base: str | None = None + api_version: str | None = None + + +class TestDocumentSet(BaseModel): + id: int + name: str + description: str + cc_pair_ids: list[int] = Field(default_factory=list) + is_public: bool + is_up_to_date: bool + users: list[str] = Field(default_factory=list) + groups: list[int] = Field(default_factory=list) + + +class TestPersona(BaseModel): + id: int + name: str + description: str + num_chunks: float + llm_relevance_filter: bool + is_public: bool + llm_filter_extraction: bool + recency_bias: RecencyBiasSetting + prompt_ids: list[int] + document_set_ids: list[int] + tool_ids: list[int] + llm_model_provider_override: str | None + llm_model_version_override: str | None + users: list[str] + groups: list[int] diff --git a/backend/tests/integration/common_utils/user_groups.py b/backend/tests/integration/common_utils/user_groups.py deleted file mode 100644 index 0cd44066463..00000000000 --- a/backend/tests/integration/common_utils/user_groups.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import cast - -import requests - -from ee.danswer.server.user_group.models import UserGroup -from 
ee.danswer.server.user_group.models import UserGroupCreate -from tests.integration.common_utils.constants import API_SERVER_URL - - -class UserGroupClient: - @staticmethod - def create_user_group(user_group_creation_request: UserGroupCreate) -> int: - response = requests.post( - f"{API_SERVER_URL}/manage/admin/user-group", - json=user_group_creation_request.model_dump(), - ) - response.raise_for_status() - return cast(int, response.json()["id"]) - - @staticmethod - def fetch_user_groups() -> list[UserGroup]: - response = requests.get(f"{API_SERVER_URL}/manage/admin/user-group") - response.raise_for_status() - return [UserGroup(**ug) for ug in response.json()] diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py index 6c46e9f875e..314b78ad36f 100644 --- a/backend/tests/integration/conftest.py +++ b/backend/tests/integration/conftest.py @@ -1,3 +1,4 @@ +import os from collections.abc import Generator import pytest @@ -9,6 +10,25 @@ from tests.integration.common_utils.vespa import TestVespaClient +def load_env_vars(env_file: str = ".env") -> None: + current_dir = os.path.dirname(os.path.abspath(__file__)) + env_path = os.path.join(current_dir, env_file) + try: + with open(env_path, "r") as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + key, value = line.split("=", 1) + os.environ[key] = value.strip() + print("Successfully loaded environment variables") + except FileNotFoundError: + print(f"File {env_file} not found") + + +# Load environment variables at the module level +load_env_vars() + + @pytest.fixture def db_session() -> Generator[Session, None, None]: with get_session_context_manager() as session: diff --git a/backend/tests/integration/tests/connector/test_connector_deletion.py b/backend/tests/integration/tests/connector/test_connector_deletion.py new file mode 100644 index 00000000000..08d34c4c1a8 --- /dev/null +++ b/backend/tests/integration/tests/connector/test_connector_deletion.py @@ -0,0 +1,305 @@ +""" +This file contains tests for the following: +- Ensuring deletion of a connector also: + - deletes the documents in vespa for that connector + - updates the document sets and user groups to remove the connector +- Ensure that deleting a connector that is part of an overlapping document set and/or user group works as expected +""" +from uuid import uuid4 + +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import NUM_DOCS +from tests.integration.common_utils.managers.api_key import APIKeyManager +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.document import DocumentManager +from tests.integration.common_utils.managers.document_set import DocumentSetManager +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager +from tests.integration.common_utils.test_models import TestAPIKey +from tests.integration.common_utils.test_models import TestUser +from tests.integration.common_utils.test_models import TestUserGroup +from tests.integration.common_utils.vespa import TestVespaClient + + +def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + # add api key to user + api_key: TestAPIKey = APIKeyManager.create( + 
user_performing_action=admin_user, + ) + + # create connectors + cc_pair_1 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + cc_pair_2 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + + # seed documents + cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_1, + num_docs=NUM_DOCS, + api_key=api_key, + ) + cc_pair_2 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_2, + num_docs=NUM_DOCS, + api_key=api_key, + ) + + # create document sets + doc_set_1 = DocumentSetManager.create( + name="Test Document Set 1", + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + doc_set_2 = DocumentSetManager.create( + name="Test Document Set 2", + cc_pair_ids=[cc_pair_1.id, cc_pair_2.id], + user_performing_action=admin_user, + ) + + # wait for document sets to be synced + DocumentSetManager.wait_for_sync(user_performing_action=admin_user) + + print("Document sets created and synced") + + # create user groups + user_group_1: TestUserGroup = UserGroupManager.create( + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + user_group_2: TestUserGroup = UserGroupManager.create( + cc_pair_ids=[cc_pair_1.id, cc_pair_2.id], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync(user_performing_action=admin_user) + + # delete connector 1 + CCPairManager.pause_cc_pair( + cc_pair=cc_pair_1, + user_performing_action=admin_user, + ) + CCPairManager.delete( + cc_pair=cc_pair_1, + user_performing_action=admin_user, + ) + + # Update local records to match the database for later comparison + user_group_1.cc_pair_ids = [] + user_group_2.cc_pair_ids = [cc_pair_2.id] + doc_set_1.cc_pair_ids = [] + doc_set_2.cc_pair_ids = [cc_pair_2.id] + cc_pair_1.groups = [] + cc_pair_2.groups = [user_group_2.id] + + CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user) + + # validate vespa documents + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + doc_set_names=[], + group_names=[], + doc_creating_user=admin_user, + verify_deleted=True, + ) + + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + doc_set_names=[doc_set_2.name], + group_names=[user_group_2.name], + doc_creating_user=admin_user, + verify_deleted=False, + ) + + # check that only connector 1 is deleted + CCPairManager.verify( + cc_pair=cc_pair_2, + user_performing_action=admin_user, + ) + + # validate document sets + DocumentSetManager.verify( + document_set=doc_set_1, + user_performing_action=admin_user, + ) + DocumentSetManager.verify( + document_set=doc_set_2, + user_performing_action=admin_user, + ) + + # validate user groups + UserGroupManager.verify( + user_group=user_group_1, + user_performing_action=admin_user, + ) + UserGroupManager.verify( + user_group=user_group_2, + user_performing_action=admin_user, + ) + + +def test_connector_deletion_for_overlapping_connectors( + reset: None, vespa_client: TestVespaClient +) -> None: + """Checks to make sure that connectors with overlapping documents work properly. Specifically, that the overlapping + document (1) still exists and (2) has the right document set / group post-deletion of one of the connectors. 
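+ After connector 1 is deleted, the shared document should still exist in Vespa, belong to no document set, and be visible only to user group 2.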
+ """ + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + # add api key to user + api_key: TestAPIKey = APIKeyManager.create( + user_performing_action=admin_user, + ) + + # create connectors + cc_pair_1 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + cc_pair_2 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + + doc_ids = [str(uuid4())] + cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_1, + document_ids=doc_ids, + api_key=api_key, + ) + cc_pair_2 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_2, + document_ids=doc_ids, + api_key=api_key, + ) + + # verify vespa document exists and that it is not in any document sets or groups + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + doc_set_names=[], + group_names=[], + doc_creating_user=admin_user, + ) + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + doc_set_names=[], + group_names=[], + doc_creating_user=admin_user, + ) + + # create document set + doc_set_1 = DocumentSetManager.create( + name="Test Document Set 1", + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + DocumentSetManager.wait_for_sync( + document_sets_to_check=[doc_set_1], + user_performing_action=admin_user, + ) + + print("Document set 1 created and synced") + + # verify vespa document is in the document set + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + doc_set_names=[doc_set_1.name], + doc_creating_user=admin_user, + ) + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + doc_creating_user=admin_user, + ) + + # create a user group and attach it to connector 1 + user_group_1: TestUserGroup = UserGroupManager.create( + name="Test User Group 1", + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], + user_performing_action=admin_user, + ) + cc_pair_1.groups = [user_group_1.id] + + print("User group 1 created and synced") + + # create a user group and attach it to connector 2 + user_group_2: TestUserGroup = UserGroupManager.create( + name="Test User Group 2", + cc_pair_ids=[cc_pair_2.id], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_2], + user_performing_action=admin_user, + ) + cc_pair_2.groups = [user_group_2.id] + + print("User group 2 created and synced") + + # verify vespa document is in the user group + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + group_names=[user_group_1.name, user_group_2.name], + doc_creating_user=admin_user, + ) + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + group_names=[user_group_1.name, user_group_2.name], + doc_creating_user=admin_user, + ) + + # EVERYTHING BELOW HERE IS CURRENTLY BROKEN AND NEEDS TO BE FIXED SERVER SIDE + + # delete connector 1 + CCPairManager.pause_cc_pair( + cc_pair=cc_pair_1, + user_performing_action=admin_user, + ) + CCPairManager.delete( + cc_pair=cc_pair_1, + user_performing_action=admin_user, + ) + + # wait for deletion to finish + CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user) + + print("Connector 1 deleted") + + # check that only connector 1 is deleted + # TODO: check for the CC pair rather than the 
connector once the refactor is done + CCPairManager.verify( + cc_pair=cc_pair_1, + verify_deleted=True, + user_performing_action=admin_user, + ) + CCPairManager.verify( + cc_pair=cc_pair_2, + user_performing_action=admin_user, + ) + + # verify the document is not in any document sets + # verify the document is only in user group 2 + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + doc_set_names=[], + group_names=[user_group_2.name], + doc_creating_user=admin_user, + verify_deleted=False, + ) diff --git a/backend/tests/integration/tests/connector/test_deletion.py b/backend/tests/integration/tests/connector/test_deletion.py deleted file mode 100644 index 78ad2378af9..00000000000 --- a/backend/tests/integration/tests/connector/test_deletion.py +++ /dev/null @@ -1,190 +0,0 @@ -import time - -from danswer.db.enums import ConnectorCredentialPairStatus -from danswer.server.features.document_set.models import DocumentSetCreationRequest -from tests.integration.common_utils.connectors import ConnectorClient -from tests.integration.common_utils.constants import MAX_DELAY -from tests.integration.common_utils.document_sets import DocumentSetClient -from tests.integration.common_utils.seed_documents import TestDocumentClient -from tests.integration.common_utils.user_groups import UserGroupClient -from tests.integration.common_utils.user_groups import UserGroupCreate -from tests.integration.common_utils.vespa import TestVespaClient - - -def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None: - # create connectors - c1_details = ConnectorClient.create_connector(name_prefix="tc1") - c2_details = ConnectorClient.create_connector(name_prefix="tc2") - c1_seed_res = TestDocumentClient.seed_documents( - num_docs=5, cc_pair_id=c1_details.cc_pair_id - ) - c2_seed_res = TestDocumentClient.seed_documents( - num_docs=5, cc_pair_id=c2_details.cc_pair_id - ) - - # create document sets - doc_set_1_id = DocumentSetClient.create_document_set( - DocumentSetCreationRequest( - name="Test Document Set 1", - description="Intially connector to be deleted, should be empty after test", - cc_pair_ids=[c1_details.cc_pair_id], - is_public=True, - users=[], - groups=[], - ) - ) - - doc_set_2_id = DocumentSetClient.create_document_set( - DocumentSetCreationRequest( - name="Test Document Set 2", - description="Intially both connectors, should contain undeleted connector after test", - cc_pair_ids=[c1_details.cc_pair_id, c2_details.cc_pair_id], - is_public=True, - users=[], - groups=[], - ) - ) - - # wait for document sets to be synced - start = time.time() - while True: - doc_sets = DocumentSetClient.fetch_document_sets() - doc_set_1 = next( - (doc_set for doc_set in doc_sets if doc_set.id == doc_set_1_id), None - ) - doc_set_2 = next( - (doc_set for doc_set in doc_sets if doc_set.id == doc_set_2_id), None - ) - - if not doc_set_1 or not doc_set_2: - raise RuntimeError("Document set not found") - - if doc_set_1.is_up_to_date and doc_set_2.is_up_to_date: - break - - if time.time() - start > MAX_DELAY: - raise TimeoutError("Document sets were not synced within the max delay") - - time.sleep(2) - - print("Document sets created and synced") - - # if so, create ACLs - user_group_1 = UserGroupClient.create_user_group( - UserGroupCreate( - name="Test User Group 1", user_ids=[], cc_pair_ids=[c1_details.cc_pair_id] - ) - ) - user_group_2 = UserGroupClient.create_user_group( - UserGroupCreate( - name="Test User Group 2", - user_ids=[], - cc_pair_ids=[c1_details.cc_pair_id, 
c2_details.cc_pair_id], - ) - ) - - # wait for user groups to be available - start = time.time() - while True: - user_groups = {ug.id: ug for ug in UserGroupClient.fetch_user_groups()} - - if not ( - user_group_1 in user_groups.keys() and user_group_2 in user_groups.keys() - ): - raise RuntimeError("User groups not found") - - if ( - user_groups[user_group_1].is_up_to_date - and user_groups[user_group_2].is_up_to_date - ): - break - - if time.time() - start > MAX_DELAY: - raise TimeoutError("User groups were not synced within the max delay") - - time.sleep(2) - - print("User groups created and synced") - - # delete connector 1 - ConnectorClient.update_connector_status( - cc_pair_id=c1_details.cc_pair_id, status=ConnectorCredentialPairStatus.PAUSED - ) - ConnectorClient.delete_connector( - connector_id=c1_details.connector_id, credential_id=c1_details.credential_id - ) - - start = time.time() - while True: - connectors = ConnectorClient.get_connectors() - - if c1_details.connector_id not in [c["id"] for c in connectors]: - break - - if time.time() - start > MAX_DELAY: - raise TimeoutError("Connector 1 was not deleted within the max delay") - - time.sleep(2) - - print("Connector 1 deleted") - - # validate vespa documents - c1_vespa_docs = vespa_client.get_documents_by_id( - [doc.id for doc in c1_seed_res.documents] - )["documents"] - c2_vespa_docs = vespa_client.get_documents_by_id( - [doc.id for doc in c2_seed_res.documents] - )["documents"] - - assert len(c1_vespa_docs) == 0 - assert len(c2_vespa_docs) == 5 - - for doc in c2_vespa_docs: - assert doc["fields"]["access_control_list"] == { - "PUBLIC": 1, - "group:Test User Group 2": 1, - } - assert doc["fields"]["document_sets"] == {"Test Document Set 2": 1} - - # check that only connector 1 is deleted - # TODO: check for the CC pair rather than the connector once the refactor is done - all_connectors = ConnectorClient.get_connectors() - assert len(all_connectors) == 1 - assert all_connectors[0]["id"] == c2_details.connector_id - - # validate document sets - all_doc_sets = DocumentSetClient.fetch_document_sets() - assert len(all_doc_sets) == 2 - - doc_set_1_found = False - doc_set_2_found = False - for doc_set in all_doc_sets: - if doc_set.id == doc_set_1_id: - doc_set_1_found = True - assert doc_set.cc_pair_descriptors == [] - - if doc_set.id == doc_set_2_id: - doc_set_2_found = True - assert len(doc_set.cc_pair_descriptors) == 1 - assert doc_set.cc_pair_descriptors[0].id == c2_details.cc_pair_id - - assert doc_set_1_found - assert doc_set_2_found - - # validate user groups - all_user_groups = UserGroupClient.fetch_user_groups() - assert len(all_user_groups) == 2 - - user_group_1_found = False - user_group_2_found = False - for user_group in all_user_groups: - if user_group.id == user_group_1: - user_group_1_found = True - assert user_group.cc_pairs == [] - if user_group.id == user_group_2: - user_group_2_found = True - assert len(user_group.cc_pairs) == 1 - assert user_group.cc_pairs[0].id == c2_details.cc_pair_id - - assert user_group_1_found - assert user_group_2_found diff --git a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py index b00c2e3d1e6..981a9cbd026 100644 --- a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py +++ b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py @@ -1,34 +1,59 @@ import requests -from tests.integration.common_utils.connectors import ConnectorClient +from danswer.configs.constants import 
MessageType from tests.integration.common_utils.constants import API_SERVER_URL -from tests.integration.common_utils.seed_documents import TestDocumentClient +from tests.integration.common_utils.constants import NUM_DOCS +from tests.integration.common_utils.llm import LLMProviderManager +from tests.integration.common_utils.managers.api_key import APIKeyManager +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.document import DocumentManager +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.test_models import TestAPIKey +from tests.integration.common_utils.test_models import TestCCPair +from tests.integration.common_utils.test_models import TestUser def test_send_message_simple_with_history(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + # create connectors - c1_details = ConnectorClient.create_connector(name_prefix="tc1") - c1_seed_res = TestDocumentClient.seed_documents( - num_docs=5, cc_pair_id=c1_details.cc_pair_id + cc_pair_1: TestCCPair = CCPairManager.create_from_scratch( + user_performing_action=admin_user, + ) + api_key: TestAPIKey = APIKeyManager.create( + user_performing_action=admin_user, + ) + LLMProviderManager.create(user_performing_action=admin_user) + cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_1, + num_docs=NUM_DOCS, + api_key=api_key, ) response = requests.post( f"{API_SERVER_URL}/chat/send-message-simple-with-history", json={ - "messages": [{"message": c1_seed_res.documents[0].content, "role": "user"}], + "messages": [ + { + "message": cc_pair_1.documents[0].content, + "role": MessageType.USER.value, + } + ], "persona_id": 0, "prompt_id": 0, }, + headers=admin_user.headers, ) assert response.status_code == 200 response_json = response.json() # Check that the top document is the correct document - assert response_json["simple_search_docs"][0]["id"] == c1_seed_res.documents[0].id + assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id # assert that the metadata is correct - for doc in c1_seed_res.documents: + for doc in cc_pair_1.documents: found_doc = next( (x for x in response_json["simple_search_docs"] if x["id"] == doc.id), None ) diff --git a/backend/tests/integration/tests/document_set/test_syncing.py b/backend/tests/integration/tests/document_set/test_syncing.py index 9a6b42ab5df..ab31b751471 100644 --- a/backend/tests/integration/tests/document_set/test_syncing.py +++ b/backend/tests/integration/tests/document_set/test_syncing.py @@ -1,78 +1,66 @@ -import time - -from danswer.server.features.document_set.models import DocumentSetCreationRequest -from tests.integration.common_utils.document_sets import DocumentSetClient -from tests.integration.common_utils.seed_documents import TestDocumentClient +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import NUM_DOCS +from tests.integration.common_utils.managers.api_key import APIKeyManager +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.document import DocumentManager +from tests.integration.common_utils.managers.document_set import DocumentSetManager +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.test_models import TestAPIKey +from 
tests.integration.common_utils.test_models import TestUser from tests.integration.common_utils.vespa import TestVespaClient def test_multiple_document_sets_syncing_same_connnector( reset: None, vespa_client: TestVespaClient ) -> None: - # Seed documents - seed_result = TestDocumentClient.seed_documents(num_docs=5) - cc_pair_id = seed_result.cc_pair_id + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") - # Create first document set - doc_set_1_id = DocumentSetClient.create_document_set( - DocumentSetCreationRequest( - name="Test Document Set 1", - description="First test document set", - cc_pair_ids=[cc_pair_id], - is_public=True, - users=[], - groups=[], - ) + # add api key to user + api_key: TestAPIKey = APIKeyManager.create( + user_performing_action=admin_user, ) - doc_set_2_id = DocumentSetClient.create_document_set( - DocumentSetCreationRequest( - name="Test Document Set 2", - description="Second test document set", - cc_pair_ids=[cc_pair_id], - is_public=True, - users=[], - groups=[], - ) + # create connector + cc_pair_1 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, ) - # wait for syncing to be complete - max_delay = 45 - start = time.time() - while True: - doc_sets = DocumentSetClient.fetch_document_sets() - doc_set_1 = next( - (doc_set for doc_set in doc_sets if doc_set.id == doc_set_1_id), None - ) - doc_set_2 = next( - (doc_set for doc_set in doc_sets if doc_set.id == doc_set_2_id), None - ) - - if not doc_set_1 or not doc_set_2: - raise RuntimeError("Document set not found") - - if doc_set_1.is_up_to_date and doc_set_2.is_up_to_date: - assert [ccp.id for ccp in doc_set_1.cc_pair_descriptors] == [ - ccp.id for ccp in doc_set_2.cc_pair_descriptors - ] - break + # seed documents + cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_1, + num_docs=NUM_DOCS, + api_key=api_key, + ) - if time.time() - start > max_delay: - raise TimeoutError("Document sets were not synced within the max delay") + # Create document sets + doc_set_1 = DocumentSetManager.create( + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + doc_set_2 = DocumentSetManager.create( + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) - time.sleep(2) + DocumentSetManager.wait_for_sync( + user_performing_action=admin_user, + ) - # get names so we can compare to what is in vespa - doc_sets = DocumentSetClient.fetch_document_sets() - doc_set_names = {doc_set.name for doc_set in doc_sets} + DocumentSetManager.verify( + document_set=doc_set_1, + user_performing_action=admin_user, + ) + DocumentSetManager.verify( + document_set=doc_set_2, + user_performing_action=admin_user, + ) # make sure documents are as expected - seeded_document_ids = [doc.id for doc in seed_result.documents] - - result = vespa_client.get_documents_by_id([doc.id for doc in seed_result.documents]) - documents = result["documents"] - assert len(documents) == len(seed_result.documents) - assert all(doc["fields"]["document_id"] in seeded_document_ids for doc in documents) - assert all( - set(doc["fields"]["document_sets"].keys()) == doc_set_names for doc in documents + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + doc_set_names=[doc_set_1.name, doc_set_2.name], + doc_creating_user=admin_user, ) diff --git a/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py 
b/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py new file mode 100644 index 00000000000..c52c5826eae --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py @@ -0,0 +1,179 @@ +""" +This file takes the happy path to adding a curator to a user group and then tests +the permissions of the curator manipulating connector-credential pairs. +""" +import pytest +from requests.exceptions import HTTPError + +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.connector import ConnectorManager +from tests.integration.common_utils.managers.credential import CredentialManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_cc_pair_permissions(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="curated_user_group", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + # setting the user as a curator for the user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Creating another user group that the user is not a curator of + user_group_2 = UserGroupManager.create( + name="uncurated_user_group", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # Create a connector and credentials that the curator is and is not a curator of + connector_1 = ConnectorManager.create( + name="curator_owned_connector", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + is_public=False, + user_performing_action=admin_user, + ) + # Currently we don't enforce permissions at the connector level + # pending cc_pair -> connector rework + # connector_2 = ConnectorManager.create( + # name="curator_visible_connector", + # source=DocumentSource.CONFLUENCE, + # groups=[user_group_2.id], + # is_public=False, + # user_performing_action=admin_user, + # ) + credential_1 = CredentialManager.create( + name="curator_owned_credential", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + curator_public=False, + user_performing_action=admin_user, + ) + credential_2 = CredentialManager.create( + name="curator_visible_credential", + source=DocumentSource.CONFLUENCE, + groups=[user_group_2.id], + curator_public=False, + user_performing_action=admin_user, + ) + + # END OF HAPPY PATH + + """Tests for things Curators should not be able to do""" + + # Curators should not be able to create a public cc pair + with pytest.raises(HTTPError): + CCPairManager.create( + connector_id=connector_1.id, + credential_id=credential_1.id, + name="invalid_cc_pair_1", + groups=[user_group_1.id], + is_public=True, + user_performing_action=curator, + ) + + # Curators should not be able to create a cc + # pair for a user group they are not a 
curator of + with pytest.raises(HTTPError): + CCPairManager.create( + connector_id=connector_1.id, + credential_id=credential_1.id, + name="invalid_cc_pair_2", + groups=[user_group_1.id, user_group_2.id], + is_public=False, + user_performing_action=curator, + ) + + # Curators should not be able to create a cc + # pair without an attached user group + with pytest.raises(HTTPError): + CCPairManager.create( + connector_id=connector_1.id, + credential_id=credential_1.id, + name="invalid_cc_pair_2", + groups=[], + is_public=False, + user_performing_action=curator, + ) + + # # This test is currently disabled because permissions are + # # not enforced at the connector level + # # Curators should not be able to create a cc pair + # # for a user group that the connector does not belong to (NOT WORKING) + # with pytest.raises(HTTPError): + # CCPairManager.create( + # connector_id=connector_2.id, + # credential_id=credential_1.id, + # name="invalid_cc_pair_3", + # groups=[user_group_1.id], + # is_public=False, + # user_performing_action=curator, + # ) + + # Curators should not be able to create a cc + # pair for a user group that the credential does not belong to + with pytest.raises(HTTPError): + CCPairManager.create( + connector_id=connector_1.id, + credential_id=credential_2.id, + name="invalid_cc_pair_4", + groups=[user_group_1.id], + is_public=False, + user_performing_action=curator, + ) + + """Tests for things Curators should be able to do""" + + # Curators should be able to create a private + # cc pair for a user group they are a curator of + valid_cc_pair = CCPairManager.create( + name="valid_cc_pair", + connector_id=connector_1.id, + credential_id=credential_1.id, + groups=[user_group_1.id], + is_public=False, + user_performing_action=curator, + ) + + # Verify the created cc pair + CCPairManager.verify( + cc_pair=valid_cc_pair, + user_performing_action=curator, + ) + + # Test pausing the cc pair + CCPairManager.pause_cc_pair(valid_cc_pair, user_performing_action=curator) + + # Test deleting the cc pair + CCPairManager.delete(valid_cc_pair, user_performing_action=curator) + CCPairManager.wait_for_deletion_completion(user_performing_action=curator) + + CCPairManager.verify( + cc_pair=valid_cc_pair, + verify_deleted=True, + user_performing_action=curator, + ) diff --git a/backend/tests/integration/tests/permissions/test_connector_permissions.py b/backend/tests/integration/tests/permissions/test_connector_permissions.py new file mode 100644 index 00000000000..279c0568bfb --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_connector_permissions.py @@ -0,0 +1,136 @@ +""" +This file takes the happy path to adding a curator to a user group and then tests +the permissions of the curator manipulating connectors. 
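+Only private connectors scoped to a group the curator manages should be allowed; public connectors and connectors for other groups should be rejected.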
+""" +import pytest +from requests.exceptions import HTTPError + +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.connector import ConnectorManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_connector_permissions(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="user_group_1", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + # setting the user as a curator for the user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Creating another user group that the user is not a curator of + user_group_2 = UserGroupManager.create( + name="user_group_2", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # END OF HAPPY PATH + + """Tests for things Curators should not be able to do""" + + # Curators should not be able to create a public connector + with pytest.raises(HTTPError): + ConnectorManager.create( + name="invalid_connector_1", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + is_public=True, + user_performing_action=curator, + ) + + # Curators should not be able to create a cc pair for a + # user group they are not a curator of + with pytest.raises(HTTPError): + ConnectorManager.create( + name="invalid_connector_2", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id, user_group_2.id], + is_public=False, + user_performing_action=curator, + ) + + """Tests for things Curators should be able to do""" + + # Curators should be able to create a private + # connector for a user group they are a curator of + valid_connector = ConnectorManager.create( + name="valid_connector", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + is_public=False, + user_performing_action=curator, + ) + assert valid_connector.id is not None + + # Verify the created connector + created_connector = ConnectorManager.get( + valid_connector.id, user_performing_action=curator + ) + assert created_connector.name == valid_connector.name + assert created_connector.source == valid_connector.source + + # Verify that the connector can be found in the list of all connectors + all_connectors = ConnectorManager.get_all(user_performing_action=curator) + assert any(conn.id == valid_connector.id for conn in all_connectors) + + # Test editing the connector + valid_connector.name = "updated_valid_connector" + ConnectorManager.edit(valid_connector, user_performing_action=curator) + + # Verify the edit + updated_connector = ConnectorManager.get( + valid_connector.id, user_performing_action=curator + ) + assert updated_connector.name == "updated_valid_connector" + + # Test deleting the connector + ConnectorManager.delete(connector=valid_connector, user_performing_action=curator) + + # Verify the deletion 
+ all_connectors_after_delete = ConnectorManager.get_all( + user_performing_action=curator + ) + assert all(conn.id != valid_connector.id for conn in all_connectors_after_delete) + + # Test that curator cannot create a connector for a group they are not a curator of + with pytest.raises(HTTPError): + ConnectorManager.create( + name="invalid_connector_3", + source=DocumentSource.CONFLUENCE, + groups=[user_group_2.id], + is_public=False, + user_performing_action=curator, + ) + + # Test that curator cannot create a public connector + with pytest.raises(HTTPError): + ConnectorManager.create( + name="invalid_connector_4", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + is_public=True, + user_performing_action=curator, + ) diff --git a/backend/tests/integration/tests/permissions/test_credential_permissions.py b/backend/tests/integration/tests/permissions/test_credential_permissions.py new file mode 100644 index 00000000000..1311f1a3d2d --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_credential_permissions.py @@ -0,0 +1,108 @@ +""" +This file takes the happy path to adding a curator to a user group and then tests +the permissions of the curator manipulating credentials. +""" +import pytest +from requests.exceptions import HTTPError + +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.credential import CredentialManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_credential_permissions(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="user_group_1", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + # setting the user as a curator for the user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Creating another user group that the user is not a curator of + user_group_2 = UserGroupManager.create( + name="user_group_2", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # END OF HAPPY PATH + + """Tests for things Curators should not be able to do""" + + # Curators should not be able to create a public credential + with pytest.raises(HTTPError): + CredentialManager.create( + name="invalid_credential_1", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + curator_public=True, + user_performing_action=curator, + ) + + # Curators should not be able to create a credential for a user group they are not a curator of + with pytest.raises(HTTPError): + CredentialManager.create( + name="invalid_credential_2", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id, user_group_2.id], + curator_public=False, + user_performing_action=curator, + ) + + """Tests for things Curators should be able to do""" + # Curators should be able to create a private 
credential for a user group they are a curator of + valid_credential = CredentialManager.create( + name="valid_credential", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + curator_public=False, + user_performing_action=curator, + ) + + # Verify the created credential + CredentialManager.verify( + credential=valid_credential, + user_performing_action=curator, + ) + + # Test editing the credential + valid_credential.name = "updated_valid_credential" + CredentialManager.edit(valid_credential, user_performing_action=curator) + + # Verify the edit + CredentialManager.verify( + credential=valid_credential, + user_performing_action=curator, + ) + + # Test deleting the credential + CredentialManager.delete(valid_credential, user_performing_action=curator) + + # Verify the deletion + CredentialManager.verify( + credential=valid_credential, + verify_deleted=True, + user_performing_action=curator, + ) diff --git a/backend/tests/integration/tests/permissions/test_doc_set_permissions.py b/backend/tests/integration/tests/permissions/test_doc_set_permissions.py new file mode 100644 index 00000000000..412b5d41fad --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_doc_set_permissions.py @@ -0,0 +1,190 @@ +import pytest +from requests.exceptions import HTTPError + +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.document_set import DocumentSetManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_doc_set_permissions_setup(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # Creating a second user (curator) + curator: TestUser = UserManager.create(name="curator") + + # Creating the first user group + user_group_1 = UserGroupManager.create( + name="curated_user_group", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # Setting the curator as a curator for the first user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Creating a second user group + user_group_2 = UserGroupManager.create( + name="uncurated_user_group", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # Admin creates a cc_pair + private_cc_pair = CCPairManager.create_from_scratch( + is_public=False, + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + + # Admin creates a public cc_pair + public_cc_pair = CCPairManager.create_from_scratch( + is_public=True, + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + + # END OF HAPPY PATH + + """Tests for things Curators/Admins should not be able to do""" + + # Test that curator cannot create a document set for the group they don't curate + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 1", + groups=[user_group_2.id], + 
cc_pair_ids=[public_cc_pair.id], + user_performing_action=curator, + ) + + # Test that curator cannot create a document set attached to both groups + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 2", + is_public=False, + cc_pair_ids=[public_cc_pair.id], + groups=[user_group_1.id, user_group_2.id], + user_performing_action=curator, + ) + + # Test that curator cannot create a document set with no groups + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 3", + is_public=False, + cc_pair_ids=[public_cc_pair.id], + groups=[], + user_performing_action=curator, + ) + + # Test that curator cannot create a document set with no cc_pairs + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 4", + is_public=False, + cc_pair_ids=[], + groups=[user_group_1.id], + user_performing_action=curator, + ) + + # Test that admin cannot create a document set with no cc_pairs + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 4", + is_public=False, + cc_pair_ids=[], + groups=[user_group_1.id], + user_performing_action=admin_user, + ) + + """Tests for things Curators should be able to do""" + # Test that curator can create a document set for the group they curate + valid_doc_set = DocumentSetManager.create( + name="Valid Document Set", + is_public=False, + cc_pair_ids=[public_cc_pair.id], + groups=[user_group_1.id], + user_performing_action=curator, + ) + + DocumentSetManager.wait_for_sync( + document_sets_to_check=[valid_doc_set], user_performing_action=admin_user + ) + + # Verify that the valid document set was created + DocumentSetManager.verify( + document_set=valid_doc_set, + user_performing_action=admin_user, + ) + + # Verify that only one document set exists + all_doc_sets = DocumentSetManager.get_all(user_performing_action=admin_user) + assert len(all_doc_sets) == 1 + + # Add the private_cc_pair to the doc set on our end for later comparison + valid_doc_set.cc_pair_ids.append(private_cc_pair.id) + + # Confirm the curator can't add the private_cc_pair to the doc set + with pytest.raises(HTTPError): + DocumentSetManager.edit( + document_set=valid_doc_set, + user_performing_action=curator, + ) + # Confirm the admin can't add the private_cc_pair to the doc set + with pytest.raises(HTTPError): + DocumentSetManager.edit( + document_set=valid_doc_set, + user_performing_action=admin_user, + ) + + # Verify the document set has not been updated in the db + with pytest.raises(ValueError): + DocumentSetManager.verify( + document_set=valid_doc_set, + user_performing_action=admin_user, + ) + + # Add the private_cc_pair to the user group on our end for later comparison + user_group_1.cc_pair_ids.append(private_cc_pair.id) + + # Admin adds the cc_pair to the group the curator curates + UserGroupManager.edit( + user_group=user_group_1, + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + UserGroupManager.verify( + user_group=user_group_1, + user_performing_action=admin_user, + ) + + # Confirm the curator can now add the cc_pair to the doc set + DocumentSetManager.edit( + document_set=valid_doc_set, + user_performing_action=curator, + ) + DocumentSetManager.wait_for_sync( + document_sets_to_check=[valid_doc_set], user_performing_action=admin_user + ) + # Verify the updated document set + DocumentSetManager.verify( + document_set=valid_doc_set, + 
user_performing_action=admin_user, + ) diff --git a/backend/tests/integration/tests/permissions/test_user_role_permissions.py b/backend/tests/integration/tests/permissions/test_user_role_permissions.py new file mode 100644 index 00000000000..5da91a57af8 --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_user_role_permissions.py @@ -0,0 +1,93 @@ +""" +This file tests the ability of different user types to set the role of other users. +""" +import pytest +from requests.exceptions import HTTPError + +from danswer.db.models import UserRole +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_user_role_setting_permissions(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + assert UserManager.verify_role(admin_user, UserRole.ADMIN) + + # Creating a basic user + basic_user: TestUser = UserManager.create(name="basic_user") + assert UserManager.verify_role(basic_user, UserRole.BASIC) + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + assert UserManager.verify_role(curator, UserRole.BASIC) + + # Creating a curator without adding to a group should not work + with pytest.raises(HTTPError): + UserManager.set_role( + user_to_set=curator, + target_role=UserRole.CURATOR, + user_to_perform_action=admin_user, + ) + + global_curator: TestUser = UserManager.create(name="global_curator") + assert UserManager.verify_role(global_curator, UserRole.BASIC) + + # Setting the role of a global curator should not work for a basic user + with pytest.raises(HTTPError): + UserManager.set_role( + user_to_set=global_curator, + target_role=UserRole.GLOBAL_CURATOR, + user_to_perform_action=basic_user, + ) + + # Setting the role of a global curator should work for an admin user + UserManager.set_role( + user_to_set=global_curator, + target_role=UserRole.GLOBAL_CURATOR, + user_to_perform_action=admin_user, + ) + assert UserManager.verify_role(global_curator, UserRole.GLOBAL_CURATOR) + + # Setting the role of a global curator should not work for an invalid curator + with pytest.raises(HTTPError): + UserManager.set_role( + user_to_set=global_curator, + target_role=UserRole.BASIC, + user_to_perform_action=global_curator, + ) + assert UserManager.verify_role(global_curator, UserRole.GLOBAL_CURATOR) + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="user_group_1", + user_ids=[], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # This should fail because the curator is not in the user group + with pytest.raises(HTTPError): + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Adding the curator to the user group + user_group_1.user_ids = [curator.id] + UserGroupManager.edit(user_group=user_group_1, user_performing_action=admin_user) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # This should work because the curator is in the user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) diff 
--git a/backend/tests/integration/tests/permissions/test_whole_curator_flow.py b/backend/tests/integration/tests/permissions/test_whole_curator_flow.py new file mode 100644 index 00000000000..878ba1e17e8 --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_whole_curator_flow.py @@ -0,0 +1,86 @@ +""" +This test tests the happy path for curator permissions +""" +from danswer.db.models import UserRole +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.connector import ConnectorManager +from tests.integration.common_utils.managers.credential import CredentialManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_whole_curator_flow(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + assert UserManager.verify_role(admin_user, UserRole.ADMIN) + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="user_group_1", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + # Making curator a curator of user_group_1 + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + assert UserManager.verify_role(curator, UserRole.CURATOR) + + # Creating a credential as curator + test_credential = CredentialManager.create( + name="curator_test_credential", + source=DocumentSource.FILE, + curator_public=False, + groups=[user_group_1.id], + user_performing_action=curator, + ) + + # Creating a connector as curator + test_connector = ConnectorManager.create( + name="curator_test_connector", + source=DocumentSource.FILE, + is_public=False, + groups=[user_group_1.id], + user_performing_action=curator, + ) + + # Test editing the connector + test_connector.name = "updated_test_connector" + ConnectorManager.edit(connector=test_connector, user_performing_action=curator) + + # Creating a CC pair as curator + test_cc_pair = CCPairManager.create( + connector_id=test_connector.id, + credential_id=test_credential.id, + name="curator_test_cc_pair", + groups=[user_group_1.id], + is_public=False, + user_performing_action=curator, + ) + + CCPairManager.verify(cc_pair=test_cc_pair, user_performing_action=admin_user) + + # Verify that the curator can pause and unpause the CC pair + CCPairManager.pause_cc_pair(cc_pair=test_cc_pair, user_performing_action=curator) + + # Verify that the curator can delete the CC pair + CCPairManager.delete(cc_pair=test_cc_pair, user_performing_action=curator) + CCPairManager.wait_for_deletion_completion(user_performing_action=curator) + + # Verify that the CC pair has been deleted + CCPairManager.verify( + cc_pair=test_cc_pair, + verify_deleted=True, + user_performing_action=admin_user, + ) diff --git a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py b/backend/tests/unit/danswer/direct_qa/test_qa_utils.py index d3974fe47ab..ad6baaa365a 100644 --- a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py +++ 
b/backend/tests/unit/danswer/direct_qa/test_qa_utils.py @@ -12,6 +12,37 @@ from danswer.search.models import InferenceChunk +def test_passed_in_quotes() -> None: + # Test case 1: Basic quote separation + test_answer = """{ + "answer": "I can assist "James" with that", + "quotes": [ + "Danswer can just ingest PDFs as they are. How GOOD it embeds them depends on the formatting of your PDFs.", + "the ` danswer. llm ` package aims to provide a comprehensive framework." + ] + }""" + + answer, quotes = separate_answer_quotes(test_answer, is_json_prompt=True) + assert answer == 'I can assist "James" with that' + assert quotes == [ + "Danswer can just ingest PDFs as they are. How GOOD it embeds them depends on the formatting of your PDFs.", + "the ` danswer. llm ` package aims to provide a comprehensive framework.", + ] + + # Test case 2: Additional quotes + test_answer = """{ + "answer": "She said the response was "1" and I said the response was "2".", + "quotes": [ + "Danswer can efficiently ingest PDFs, with the quality of embedding depending on the PDF's formatting." + ] + }""" + answer, quotes = separate_answer_quotes(test_answer, is_json_prompt=True) + assert answer == 'She said the response was "1" and I said the response was "2".' + assert quotes == [ + "Danswer can efficiently ingest PDFs, with the quality of embedding depending on the PDF's formatting.", + ] + + def test_separate_answer_quotes() -> None: # Test case 1: Basic quote separation test_answer = textwrap.dedent( diff --git a/ct.yaml b/ct.yaml new file mode 100644 index 00000000000..764af160daf --- /dev/null +++ b/ct.yaml @@ -0,0 +1,12 @@ +# See https://github.com/helm/chart-testing#configuration + +chart-dirs: + - deployment/helm/charts + +chart-repos: + - vespa=https://unoplat.github.io/vespa-helm-charts + - postgresql=https://charts.bitnami.com/bitnami + +helm-extra-args: --timeout 900s + +validate-maintainers: false diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index bda8ffa65d5..eb5ba5efc88 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -1,4 +1,3 @@ -version: '3' services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} @@ -12,6 +11,7 @@ services: depends_on: - relational_db - index + - cache - inference_model_server restart: always ports: @@ -35,13 +35,6 @@ services: - OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-} - TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-} # Gen AI Settings - - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-} - - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-} - - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-} - - GEN_AI_API_KEY=${GEN_AI_API_KEY:-} - - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-} - - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-} - - GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-} - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-} - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-} @@ -69,6 +62,7 @@ services: # Other services - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose # Don't change the NLP model configs unless you know what you're doing - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-} @@ -114,19 +108,13 @@ services: depends_on: - relational_db - index + - cache - inference_model_server - indexing_model_server restart: always environment: - ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-} # Gen AI
Settings (Needed by DanswerBot) - - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-} - - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-} - - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-} - - GEN_AI_API_KEY=${GEN_AI_API_KEY:-} - - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-} - - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-} - - GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-} - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-} - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-} @@ -151,6 +139,7 @@ services: - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-} - POSTGRES_DB=${POSTGRES_DB:-} - VESPA_HOST=index + - REDIS_HOST=cache - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors # Don't change the NLP model configs unless you know what you're doing - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-} @@ -186,6 +175,7 @@ services: - NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-} - DANSWER_BOT_MAX_QPM=${DANSWER_BOT_MAX_QPM:-} - DANSWER_BOT_MAX_WAIT_TIME=${DANSWER_BOT_MAX_WAIT_TIME:-} + - CUSTOM_REFRESH_URL=${CUSTOM_REFRESH_URL:-} # Logging # Leave this on pretty please? Nothing sensitive is collected! # https://docs.danswer.dev/more/telemetry @@ -234,6 +224,7 @@ services: # Enterprise Edition only - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false} + - CUSTOM_REFRESH_URL=${CUSTOM_REFRESH_URL:-} inference_model_server: image: danswer/danswer-model-server:${IMAGE_TAG:-latest} @@ -342,9 +333,19 @@ services: # in order to make this work on both Unix-like systems and windows command: > /bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh - && /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev" - + && /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev" + + cache: + image: redis:7.4-alpine + restart: always + ports: + - '6379:6379' + command: redis-server + volumes: + - cache_volume:/data + volumes: + cache_volume: db_volume: vespa_volume: # Created by the container itself diff --git a/deployment/docker_compose/docker-compose.gpu-dev.yml b/deployment/docker_compose/docker-compose.gpu-dev.yml index 9079bd10dff..74da119737e 100644 --- a/deployment/docker_compose/docker-compose.gpu-dev.yml +++ b/deployment/docker_compose/docker-compose.gpu-dev.yml @@ -1,4 +1,3 @@ -version: '3' services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} @@ -12,6 +11,7 @@ services: depends_on: - relational_db - index + - cache - inference_model_server restart: always ports: @@ -32,13 +32,6 @@ services: - EMAIL_FROM=${EMAIL_FROM:-} - TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-} # Gen AI Settings - - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-} - - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-} - - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-} - - GEN_AI_API_KEY=${GEN_AI_API_KEY:-} - - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-} - - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-} - - GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-} - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-} - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-} @@ -65,6 +58,7 @@ services: # Other services - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose # Don't change the NLP model configs unless you know what you're doing - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-} @@ -106,19 +100,13 @@ services: depends_on: - relational_db - index + - cache - inference_model_server - 
indexing_model_server restart: always environment: - ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-} # Gen AI Settings (Needed by DanswerBot) - - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-} - - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-} - - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-} - - GEN_AI_API_KEY=${GEN_AI_API_KEY:-} - - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-} - - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-} - - GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-} - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-} - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-} @@ -143,6 +131,7 @@ services: - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-} - POSTGRES_DB=${POSTGRES_DB:-} - VESPA_HOST=index + - REDIS_HOST=cache - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors # Don't change the NLP model configs unless you know what you're doing - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-} @@ -355,9 +344,20 @@ services: command: > /bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh && /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev" - + + + cache: + image: redis:7.4-alpine + restart: always + ports: + - '6379:6379' + command: redis-server + volumes: + - cache_volume:/data + volumes: + cache_volume: db_volume: vespa_volume: # Created by the container itself diff --git a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml index 655243d6cb9..c06e9ae3480 100644 --- a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml +++ b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml @@ -1,4 +1,3 @@ -version: '3' services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} @@ -12,6 +11,7 @@ services: depends_on: - relational_db - index + - cache - inference_model_server restart: always env_file: @@ -20,6 +20,7 @@ services: - AUTH_TYPE=${AUTH_TYPE:-oidc} - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} extra_hosts: - "host.docker.internal:host-gateway" @@ -39,6 +40,7 @@ services: depends_on: - relational_db - index + - cache - inference_model_server - indexing_model_server restart: always @@ -48,6 +50,7 @@ services: - AUTH_TYPE=${AUTH_TYPE:-oidc} - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server} extra_hosts: @@ -204,7 +207,18 @@ services: - .env.nginx + cache: + image: redis:7.4-alpine + restart: always + ports: + - '6379:6379' + command: redis-server + volumes: + - cache_volume:/data + + volumes: + cache_volume: db_volume: vespa_volume: # Created by the container itself diff --git a/deployment/docker_compose/docker-compose.prod.yml b/deployment/docker_compose/docker-compose.prod.yml index 40f018eadd9..53bfa646b55 100644 --- a/deployment/docker_compose/docker-compose.prod.yml +++ b/deployment/docker_compose/docker-compose.prod.yml @@ -1,4 +1,3 @@ -version: '3' services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} @@ -12,6 +11,7 @@ services: depends_on: - relational_db - index + - cache - inference_model_server restart: always env_file: @@ -20,6 +20,7 @@ services: - AUTH_TYPE=${AUTH_TYPE:-oidc} - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} 
extra_hosts: - "host.docker.internal:host-gateway" @@ -39,6 +40,7 @@ services: depends_on: - relational_db - index + - cache - inference_model_server - indexing_model_server restart: always @@ -48,6 +50,7 @@ services: - AUTH_TYPE=${AUTH_TYPE:-oidc} - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server} extra_hosts: @@ -221,7 +224,18 @@ services: entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'" + cache: + image: redis:7.4-alpine + restart: always + ports: + - '6379:6379' + command: redis-server + volumes: + - cache_volume:/data + + volumes: + cache_volume: db_volume: vespa_volume: # Created by the container itself diff --git a/deployment/docker_compose/docker-compose.search-testing.yml b/deployment/docker_compose/docker-compose.search-testing.yml index efb387eb083..ecd796f6716 100644 --- a/deployment/docker_compose/docker-compose.search-testing.yml +++ b/deployment/docker_compose/docker-compose.search-testing.yml @@ -1,4 +1,3 @@ -version: '3' services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} @@ -12,6 +11,7 @@ services: depends_on: - relational_db - index + - cache restart: always ports: - "8080" @@ -21,6 +21,7 @@ services: - AUTH_TYPE=disabled - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-} - ENV_SEED_CONFIGURATION=${ENV_SEED_CONFIGURATION:-} @@ -43,6 +44,7 @@ services: depends_on: - relational_db - index + - cache restart: always env_file: - .env_eval @@ -50,6 +52,7 @@ services: - AUTH_TYPE=disabled - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-} - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server} @@ -200,7 +203,18 @@ services: && /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev" + cache: + image: redis:7.4-alpine + restart: always + ports: + - '6379:6379' + command: redis-server + volumes: + - cache_volume:/data + + volumes: + cache_volume: db_volume: driver: local driver_opts: diff --git a/deployment/docker_compose/env.multilingual.template b/deployment/docker_compose/env.multilingual.template index e218305153f..1a66dbfbbde 100644 --- a/deployment/docker_compose/env.multilingual.template +++ b/deployment/docker_compose/env.multilingual.template @@ -1,38 +1,8 @@ -# This env template shows how to configure Danswer for multilingual use -# In this case, it is configured for French and English -# To use it, copy it to .env in the docker_compose directory. -# Feel free to combine it with the other templates to suit your needs +# This env template shows how to configure Danswer for custom multilingual use +# Note that for most use cases it will be enough to configure Danswer multilingual purely through the UI +# See "Search Settings" -> "Advanced" for UI options. 
+# To use it, copy it to .env in the docker_compose directory (or the equivalent environment settings file for your deployment) - -# Rephrase the user query in specified languages using LLM, use comma separated values -MULTILINGUAL_QUERY_EXPANSION="English, French" -# Change the below to suit your specific needs, can be more explicit about the language of the response -LANGUAGE_HINT="IMPORTANT: Respond in the same language as my query!" +# The following is included with the user prompt. Here's one example but feel free to customize it to your needs: +LANGUAGE_HINT="IMPORTANT: ALWAYS RESPOND IN FRENCH! Even if the documents and the user query are in English, your response must be in French." LANGUAGE_CHAT_NAMING_HINT="The name of the conversation must be in the same language as the user query." - -# A recent MIT license multilingual model: https://huggingface.co/intfloat/multilingual-e5-small -DOCUMENT_ENCODER_MODEL="intfloat/multilingual-e5-small" - -# The model above is trained with the following prefix for queries and passages to improve retrieval -# by letting the model know which of the two type is currently being embedded -ASYM_QUERY_PREFIX="query: " -ASYM_PASSAGE_PREFIX="passage: " - -# Depends model by model, the one shown above is tuned with this as True -NORMALIZE_EMBEDDINGS="True" - -# Use LLM to determine if chunks are relevant to the query -# May not work well for languages that do not have much training data in the LLM training set -# If using a common language like Spanish, French, Chinese, etc. this can be kept turned on -DISABLE_LLM_DOC_RELEVANCE="True" - -# Enables fine-grained embeddings for better retrieval -# At the cost of indexing speed (~5x slower), query time is same speed -# Since reranking is turned off and multilingual retrieval is generally harder -# it is advised to turn this one on -ENABLE_MULTIPASS_INDEXING="True" - -# Using a stronger LLM will help with multilingual tasks -# Since documents may be in multiple languages, and there are additional instructions to respond -# in the user query's language, it is advised to use the best model possible -GEN_AI_MODEL_VERSION="gpt-4" diff --git a/deployment/docker_compose/env.prod.template b/deployment/docker_compose/env.prod.template index 818bd1ed1bf..890939deb49 100644 --- a/deployment/docker_compose/env.prod.template +++ b/deployment/docker_compose/env.prod.template @@ -7,16 +7,7 @@ WEB_DOMAIN=http://localhost:3000 -# Generative AI settings, uncomment as needed, will work with defaults -GEN_AI_MODEL_PROVIDER=openai -GEN_AI_MODEL_VERSION=gpt-4 -# Provide this as a global default/backup, this can also be set via the UI -#GEN_AI_API_KEY= -# Set to use Azure OpenAI or other services, such as https://danswer.openai.azure.com/ -#GEN_AI_API_ENDPOINT= -# Set up to use a specific API version, such as 2023-09-15-preview (example taken from Azure) -#GEN_AI_API_VERSION= - +# NOTE: Generative AI configurations are done via the UI now # If you want to setup a slack bot to answer questions automatically in Slack # channels it is added to, you must specify the two below. diff --git a/deployment/docker_compose/init-letsencrypt.sh b/deployment/docker_compose/init-letsencrypt.sh index 9eec409fada..66161e4dfbe 100755 --- a/deployment/docker_compose/init-letsencrypt.sh +++ b/deployment/docker_compose/init-letsencrypt.sh @@ -112,5 +112,14 @@ $COMPOSE_CMD -f docker-compose.prod.yml run --name danswer-stack --rm --entrypoi --force-renewal" certbot echo +echo "### Renaming certificate directory if needed ..." 
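+# Certbot may have written the renewed certificates to a numbered live directory (e.g. <domain>-0001) instead of reusing the existing one; the step below moves the newest such directory back to /etc/letsencrypt/live/<domain> so nginx finds the certificates at the expected path (best-effort cleanup, assumes the highest-numbered directory is the one to keep)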
+$COMPOSE_CMD -f docker-compose.prod.yml run --name danswer-stack --rm --entrypoint "\ + sh -c 'for domain in $domains; do \ + numbered_dir=\$(find /etc/letsencrypt/live -maxdepth 1 -type d -name \"\$domain-00*\" | sort -r | head -n1); \ + if [ -n \"\$numbered_dir\" ]; then \ + mv \"\$numbered_dir\" /etc/letsencrypt/live/\$domain; \ + fi; \ + done'" certbot + echo "### Reloading nginx ..." $COMPOSE_CMD -f docker-compose.prod.yml -p danswer-stack up --force-recreate -d diff --git a/deployment/helm/.gitignore b/deployment/helm/charts/danswer/.gitignore similarity index 100% rename from deployment/helm/.gitignore rename to deployment/helm/charts/danswer/.gitignore diff --git a/deployment/helm/.helmignore b/deployment/helm/charts/danswer/.helmignore similarity index 100% rename from deployment/helm/.helmignore rename to deployment/helm/charts/danswer/.helmignore diff --git a/deployment/helm/Chart.lock b/deployment/helm/charts/danswer/Chart.lock similarity index 100% rename from deployment/helm/Chart.lock rename to deployment/helm/charts/danswer/Chart.lock diff --git a/deployment/helm/Chart.yaml b/deployment/helm/charts/danswer/Chart.yaml similarity index 92% rename from deployment/helm/Chart.yaml rename to deployment/helm/charts/danswer/Chart.yaml index cc08a21556f..a84dbcba765 100644 --- a/deployment/helm/Chart.yaml +++ b/deployment/helm/charts/danswer/Chart.yaml @@ -5,8 +5,8 @@ home: https://www.danswer.ai/ sources: - "https://github.com/danswer-ai/danswer" type: application -version: 0.2.2 -appVersion: v0.5.10 +version: 0.2.2-merge-test.1 +appVersion: v0.5.15 dependencies: - name: postgresql version: 14.3.1 @@ -20,5 +20,3 @@ dependencies: version: 15.14.0 repository: oci://registry-1.docker.io/bitnamicharts condition: nginx.enabled - - diff --git a/deployment/helm/azimuth-ui.schema.yaml b/deployment/helm/charts/danswer/azimuth-ui.schema.yaml similarity index 100% rename from deployment/helm/azimuth-ui.schema.yaml rename to deployment/helm/charts/danswer/azimuth-ui.schema.yaml diff --git a/deployment/helm/templates/NOTES.txt b/deployment/helm/charts/danswer/templates/NOTES.txt similarity index 100% rename from deployment/helm/templates/NOTES.txt rename to deployment/helm/charts/danswer/templates/NOTES.txt diff --git a/deployment/helm/templates/_helpers.tpl b/deployment/helm/charts/danswer/templates/_helpers.tpl similarity index 100% rename from deployment/helm/templates/_helpers.tpl rename to deployment/helm/charts/danswer/templates/_helpers.tpl diff --git a/deployment/helm/templates/api-deployment.yaml b/deployment/helm/charts/danswer/templates/api-deployment.yaml similarity index 100% rename from deployment/helm/templates/api-deployment.yaml rename to deployment/helm/charts/danswer/templates/api-deployment.yaml diff --git a/deployment/helm/templates/api-hpa.yaml b/deployment/helm/charts/danswer/templates/api-hpa.yaml similarity index 100% rename from deployment/helm/templates/api-hpa.yaml rename to deployment/helm/charts/danswer/templates/api-hpa.yaml diff --git a/deployment/helm/templates/api-service.yaml b/deployment/helm/charts/danswer/templates/api-service.yaml similarity index 100% rename from deployment/helm/templates/api-service.yaml rename to deployment/helm/charts/danswer/templates/api-service.yaml diff --git a/deployment/helm/templates/background-deployment.yaml b/deployment/helm/charts/danswer/templates/background-deployment.yaml similarity index 100% rename from deployment/helm/templates/background-deployment.yaml rename to 
deployment/helm/charts/danswer/templates/background-deployment.yaml diff --git a/deployment/helm/templates/background-hpa.yaml b/deployment/helm/charts/danswer/templates/background-hpa.yaml similarity index 100% rename from deployment/helm/templates/background-hpa.yaml rename to deployment/helm/charts/danswer/templates/background-hpa.yaml diff --git a/deployment/helm/templates/configmap.yaml b/deployment/helm/charts/danswer/templates/configmap.yaml similarity index 100% rename from deployment/helm/templates/configmap.yaml rename to deployment/helm/charts/danswer/templates/configmap.yaml diff --git a/deployment/helm/templates/danswer-secret.yaml b/deployment/helm/charts/danswer/templates/danswer-secret.yaml similarity index 100% rename from deployment/helm/templates/danswer-secret.yaml rename to deployment/helm/charts/danswer/templates/danswer-secret.yaml diff --git a/deployment/helm/templates/indexing-model-deployment.yaml b/deployment/helm/charts/danswer/templates/indexing-model-deployment.yaml similarity index 100% rename from deployment/helm/templates/indexing-model-deployment.yaml rename to deployment/helm/charts/danswer/templates/indexing-model-deployment.yaml diff --git a/deployment/helm/templates/indexing-model-pvc.yaml b/deployment/helm/charts/danswer/templates/indexing-model-pvc.yaml similarity index 100% rename from deployment/helm/templates/indexing-model-pvc.yaml rename to deployment/helm/charts/danswer/templates/indexing-model-pvc.yaml diff --git a/deployment/helm/templates/indexing-model-service.yaml b/deployment/helm/charts/danswer/templates/indexing-model-service.yaml similarity index 100% rename from deployment/helm/templates/indexing-model-service.yaml rename to deployment/helm/charts/danswer/templates/indexing-model-service.yaml diff --git a/deployment/helm/templates/inference-model-deployment.yaml b/deployment/helm/charts/danswer/templates/inference-model-deployment.yaml similarity index 100% rename from deployment/helm/templates/inference-model-deployment.yaml rename to deployment/helm/charts/danswer/templates/inference-model-deployment.yaml diff --git a/deployment/helm/templates/inference-model-pvc.yaml b/deployment/helm/charts/danswer/templates/inference-model-pvc.yaml similarity index 100% rename from deployment/helm/templates/inference-model-pvc.yaml rename to deployment/helm/charts/danswer/templates/inference-model-pvc.yaml diff --git a/deployment/helm/templates/inference-model-service.yaml b/deployment/helm/charts/danswer/templates/inference-model-service.yaml similarity index 100% rename from deployment/helm/templates/inference-model-service.yaml rename to deployment/helm/charts/danswer/templates/inference-model-service.yaml diff --git a/deployment/helm/templates/nginx-conf.yaml b/deployment/helm/charts/danswer/templates/nginx-conf.yaml similarity index 100% rename from deployment/helm/templates/nginx-conf.yaml rename to deployment/helm/charts/danswer/templates/nginx-conf.yaml diff --git a/deployment/helm/templates/serviceaccount.yaml b/deployment/helm/charts/danswer/templates/serviceaccount.yaml similarity index 100% rename from deployment/helm/templates/serviceaccount.yaml rename to deployment/helm/charts/danswer/templates/serviceaccount.yaml diff --git a/deployment/helm/templates/stackhpc/client.yaml b/deployment/helm/charts/danswer/templates/stackhpc/client.yaml similarity index 100% rename from deployment/helm/templates/stackhpc/client.yaml rename to deployment/helm/charts/danswer/templates/stackhpc/client.yaml diff --git 
a/deployment/helm/templates/stackhpc/hooks.yml b/deployment/helm/charts/danswer/templates/stackhpc/hooks.yml similarity index 100% rename from deployment/helm/templates/stackhpc/hooks.yml rename to deployment/helm/charts/danswer/templates/stackhpc/hooks.yml diff --git a/deployment/helm/templates/stackhpc/reservation.yaml b/deployment/helm/charts/danswer/templates/stackhpc/reservation.yaml similarity index 100% rename from deployment/helm/templates/stackhpc/reservation.yaml rename to deployment/helm/charts/danswer/templates/stackhpc/reservation.yaml diff --git a/deployment/helm/templates/tests/test-connection.yaml b/deployment/helm/charts/danswer/templates/tests/test-connection.yaml similarity index 100% rename from deployment/helm/templates/tests/test-connection.yaml rename to deployment/helm/charts/danswer/templates/tests/test-connection.yaml diff --git a/deployment/helm/templates/webserver-deployment.yaml b/deployment/helm/charts/danswer/templates/webserver-deployment.yaml similarity index 100% rename from deployment/helm/templates/webserver-deployment.yaml rename to deployment/helm/charts/danswer/templates/webserver-deployment.yaml diff --git a/deployment/helm/templates/webserver-hpa.yaml b/deployment/helm/charts/danswer/templates/webserver-hpa.yaml similarity index 100% rename from deployment/helm/templates/webserver-hpa.yaml rename to deployment/helm/charts/danswer/templates/webserver-hpa.yaml diff --git a/deployment/helm/templates/webserver-service.yaml b/deployment/helm/charts/danswer/templates/webserver-service.yaml similarity index 100% rename from deployment/helm/templates/webserver-service.yaml rename to deployment/helm/charts/danswer/templates/webserver-service.yaml diff --git a/deployment/helm/values.schema.json b/deployment/helm/charts/danswer/values.schema.json similarity index 100% rename from deployment/helm/values.schema.json rename to deployment/helm/charts/danswer/values.schema.json diff --git a/deployment/helm/values.yaml b/deployment/helm/charts/danswer/values.yaml similarity index 95% rename from deployment/helm/values.yaml rename to deployment/helm/charts/danswer/values.yaml index edd4fe87d62..dfa18b38a2f 100644 --- a/deployment/helm/values.yaml +++ b/deployment/helm/charts/danswer/values.yaml @@ -11,7 +11,7 @@ appVersionOverride: # e.g "v0.3.93" # tags to refer to downstream StackHPC-modified images. # The full image ref will be: # {{ image-name }}:{{ image-tag or appVersion }}-{{ tagSuffix }} -tagSuffix: stackhpc.4 +tagSuffix: merge-test.1 zenithClient: iconUrl: https://raw.githubusercontent.com/danswer-ai/danswer/1fabd9372d66cd54238847197c33f091a724803b/Danswer.png @@ -114,7 +114,7 @@ postgresql: auth: existingSecret: danswer-secrets secretKeys: - adminPasswordKey: postgres_password #overwriting as postgres typically expects 'postgres-password' + adminPasswordKey: postgres_password # overwriting as postgres typically expects 'postgres-password' nginx: containerPorts: @@ -375,7 +375,7 @@ vespa: storage: 10Gi -#ingress: +# ingress: # enabled: false # className: "" # annotations: {} @@ -403,47 +403,43 @@ persistence: auth: # for storing smtp, oauth, slack, and other secrets # keys are lowercased version of env vars (e.g. 
SMTP_USER -> smtp_user) - existingSecret: "" # danswer-secrets + existingSecret: "" # danswer-secrets # optionally override the secret keys to reference in the secret + # this is used to populate the env vars in individual deployments + # the values here reference the keys in secrets below secretKeys: postgres_password: "postgres_password" smtp_pass: "" oauth_client_id: "" oauth_client_secret: "" oauth_cookie_secret: "" - gen_ai_api_key: "" danswer_bot_slack_app_token: "" danswer_bot_slack_bot_token: "" + redis_password: "redis_password" # will be overridden by the existingSecret if set secretName: "danswer-secrets" # set values as strings, they will be base64 encoded + # this is used to populate the secrets yaml secrets: postgres_password: "postgres" smtp_pass: "" oauth_client_id: "" oauth_client_secret: "" oauth_cookie_secret: "" - gen_ai_api_key: "" danswer_bot_slack_app_token: "" danswer_bot_slack_bot_token: "" + redis_password: "password" configMap: AUTH_TYPE: "basic" # Basic auth required for x-remote-user header integration SESSION_EXPIRE_TIME_SECONDS: "86400" # 1 Day Default VALID_EMAIL_DOMAINS: "" # Can be something like danswer.ai, as an extra double-check - SMTP_SERVER: "" # For sending verification emails, if unspecified then defaults to 'smtp.gmail.com' - SMTP_PORT: "" # For sending verification emails, if unspecified then defaults to '587' + SMTP_SERVER: "" # For sending verification emails, if unspecified then defaults to 'smtp.gmail.com' + SMTP_PORT: "" # For sending verification emails, if unspecified then defaults to '587' SMTP_USER: "" # 'your-email@company.com' # SMTP_PASS: "" # 'your-gmail-password' EMAIL_FROM: "" # 'your-email@company.com' SMTP_USER missing used instead # Gen AI Settings - GEN_AI_MODEL_PROVIDER: "" - GEN_AI_MODEL_VERSION: "" - FAST_GEN_AI_MODEL_VERSION: "" - # GEN_AI_API_KEY: "" - GEN_AI_API_ENDPOINT: "" - GEN_AI_API_VERSION: "" - GEN_AI_LLM_PROVIDER_TYPE: "" GEN_AI_MAX_TOKENS: "" QA_TIMEOUT: "60" MAX_CHUNKS_FED_TO_CHAT: "" diff --git a/deployment/kubernetes/api_server-service-deployment.yaml b/deployment/kubernetes/api_server-service-deployment.yaml index eeac5fecc96..ccbbc906d61 100644 --- a/deployment/kubernetes/api_server-service-deployment.yaml +++ b/deployment/kubernetes/api_server-service-deployment.yaml @@ -52,6 +52,11 @@ spec: secretKeyRef: name: danswer-secrets key: google_oauth_client_secret + - name: REDIS_PASSWORD + valueFrom: + secretKeyRef: + name: danswer-secrets + key: redis_password envFrom: - configMapRef: name: env-configmap diff --git a/deployment/kubernetes/background-deployment.yaml b/deployment/kubernetes/background-deployment.yaml index 18521b0f5ad..1a6ef61c104 100644 --- a/deployment/kubernetes/background-deployment.yaml +++ b/deployment/kubernetes/background-deployment.yaml @@ -19,6 +19,12 @@ spec: command: ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"] # There are some extra values since this is shared between services # There are no conflicts though, extra env variables are simply ignored + env: + - name: REDIS_PASSWORD + valueFrom: + secretKeyRef: + name: danswer-secrets + key: redis_password envFrom: - configMapRef: name: env-configmap diff --git a/deployment/kubernetes/env-configmap.yaml b/deployment/kubernetes/env-configmap.yaml index 907fae1c836..cfba42a843c 100644 --- a/deployment/kubernetes/env-configmap.yaml +++ b/deployment/kubernetes/env-configmap.yaml @@ -14,13 +14,6 @@ data: SMTP_PASS: "" # 'your-gmail-password' EMAIL_FROM: "" # 'your-email@company.com' SMTP_USER missing used instead 
# Gen AI Settings - GEN_AI_MODEL_PROVIDER: "" - GEN_AI_MODEL_VERSION: "" - FAST_GEN_AI_MODEL_VERSION: "" - GEN_AI_API_KEY: "" - GEN_AI_API_ENDPOINT: "" - GEN_AI_API_VERSION: "" - GEN_AI_LLM_PROVIDER_TYPE: "" GEN_AI_MAX_TOKENS: "" QA_TIMEOUT: "60" MAX_CHUNKS_FED_TO_CHAT: "" @@ -38,6 +31,7 @@ data: # Other Services POSTGRES_HOST: "relational-db-service" VESPA_HOST: "document-index-service" + REDIS_HOST: "redis-service" # Internet Search Tool BING_API_KEY: "" # Don't change the NLP models unless you know what you're doing diff --git a/deployment/kubernetes/redis-service-deployment.yaml b/deployment/kubernetes/redis-service-deployment.yaml new file mode 100644 index 00000000000..ab5113e5f49 --- /dev/null +++ b/deployment/kubernetes/redis-service-deployment.yaml @@ -0,0 +1,41 @@ +apiVersion: v1 +kind: Service +metadata: + name: redis-service +spec: + selector: + app: redis + ports: + - name: redis + protocol: TCP + port: 6379 + targetPort: 6379 + type: ClusterIP +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: redis-deployment +spec: + replicas: 1 + selector: + matchLabels: + app: redis + template: + metadata: + labels: + app: redis + spec: + containers: + - name: redis + image: redis:7.4-alpine + ports: + - containerPort: 6379 + env: + - name: REDIS_PASSWORD + valueFrom: + secretKeyRef: + name: danswer-secrets + key: redis_password + command: ["redis-server"] + args: ["--requirepass", "$(REDIS_PASSWORD)"] diff --git a/deployment/kubernetes/secrets.yaml b/deployment/kubernetes/secrets.yaml index c135a29f676..d4cc9e2a739 100644 --- a/deployment/kubernetes/secrets.yaml +++ b/deployment/kubernetes/secrets.yaml @@ -8,4 +8,6 @@ data: postgres_user: cG9zdGdyZXM= # "postgres" base64 encoded postgres_password: cGFzc3dvcmQ= # "password" base64 encoded google_oauth_client_id: ZXhhbXBsZS1jbGllbnQtaWQ= # "example-client-id" base64 encoded. You will need to provide this, use echo -n "your-client-id" | base64 - google_oauth_client_secret: example_google_oauth_secret # "example-client-secret" base64 encoded. You will need to provide this, use echo -n "your-client-id" | base64 + google_oauth_client_secret: ZXhhbXBsZV9nb29nbGVfb2F1dGhfc2VjcmV0 # "example-client-secret" base64 encoded. You will need to provide this, use echo -n "your-client-secret" | base64 + redis_password: cGFzc3dvcmQ= # "password" base64 encoded + \ No newline at end of file diff --git a/deployment/kubernetes/web_server-service-deployment.yaml b/deployment/kubernetes/web_server-service-deployment.yaml index b19b8e37986..b54c1b7f3d0 100644 --- a/deployment/kubernetes/web_server-service-deployment.yaml +++ b/deployment/kubernetes/web_server-service-deployment.yaml @@ -33,6 +33,12 @@ spec: - containerPort: 3000 # There are some extra values since this is shared between services # There are no conflicts though, extra env variables are simply ignored + env: + - name: REDIS_PASSWORD + valueFrom: + secretKeyRef: + name: danswer-secrets + key: redis_password envFrom: - configMapRef: name: env-configmap diff --git a/web/Dockerfile b/web/Dockerfile index 4ffced0da49..710cf653f25 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -58,6 +58,7 @@ ENV NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_T ARG NEXT_PUBLIC_DISABLE_LOGOUT ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT} + RUN npx next build # Step 2.
Production image, copy all the files and run next diff --git a/web/next.config.js b/web/next.config.js index 1586af8d178..92812c513b7 100644 --- a/web/next.config.js +++ b/web/next.config.js @@ -8,47 +8,6 @@ const version = env_version || package_version; const nextConfig = { output: "standalone", swcMinify: true, - rewrites: async () => { - // In production, something else (nginx in the one box setup) should take - // care of this rewrite. TODO (chris): better support setups where - // web_server and api_server are on different machines. - if (process.env.NODE_ENV === "production") return []; - - return [ - { - source: "/api/:path*", - destination: "http://127.0.0.1:8080/:path*", // Proxy to Backend - }, - ]; - }, - redirects: async () => { - // In production, something else (nginx in the one box setup) should take - // care of this redirect. TODO (chris): better support setups where - // web_server and api_server are on different machines. - const defaultRedirects = []; - - if (process.env.NODE_ENV === "production") return defaultRedirects; - - return defaultRedirects.concat([ - { - source: "/api/chat/send-message:params*", - destination: "http://127.0.0.1:8080/chat/send-message:params*", // Proxy to Backend - permanent: true, - }, - { - source: "/api/query/stream-answer-with-quote:params*", - destination: - "http://127.0.0.1:8080/query/stream-answer-with-quote:params*", // Proxy to Backend - permanent: true, - }, - { - source: "/api/query/stream-query-validation:params*", - destination: - "http://127.0.0.1:8080/query/stream-query-validation:params*", // Proxy to Backend - permanent: true, - }, - ]); - }, publicRuntimeConfig: { version, }, diff --git a/web/package-lock.json b/web/package-lock.json index 48ac21d6477..338cf0a9f0f 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -2555,11 +2555,11 @@ } }, "node_modules/braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dependencies": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" }, "engines": { "node": ">=8" @@ -4061,9 +4061,9 @@ } }, "node_modules/fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dependencies": { "to-regex-range": "^5.0.1" }, diff --git a/web/public/LiteLLM.jpg b/web/public/LiteLLM.jpg new file mode 100644 index 00000000000..d6a77b2d105 Binary files /dev/null and b/web/public/LiteLLM.jpg differ diff --git a/web/src/app/admin/add-connector/page.tsx b/web/src/app/admin/add-connector/page.tsx index bf7032b5f90..8d73131e69a 100644 --- a/web/src/app/admin/add-connector/page.tsx +++ b/web/src/app/admin/add-connector/page.tsx @@ -112,7 +112,7 @@ export default function Page() { value={searchTerm} onChange={(e) => setSearchTerm(e.target.value)} onKeyDown={handleKeyPress} - className="flex mt-2 max-w-sm h-9 w-full rounded-md border-2 border border-input 
bg-transparent px-3 py-1 text-sm shadow-sm transition-colors placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring" + className="ml-1 w-96 h-9 flex-none rounded-md border border-border bg-background-50 px-3 py-1 text-sm shadow-sm transition-colors placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring" /> {Object.entries(categorizedSources) diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index d478922e516..f8bdf55f745 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -928,9 +928,9 @@ export function AssistantEditor({ { const value = e.target.value; if ( diff --git a/web/src/app/admin/bot/page.tsx b/web/src/app/admin/bot/page.tsx index 14f270ee9bc..c3ef70ccbf6 100644 --- a/web/src/app/admin/bot/page.tsx +++ b/web/src/app/admin/bot/page.tsx @@ -66,7 +66,7 @@ const SlackBotConfigsTable = ({ Channels - Persona + Assistant Document Sets Delete diff --git a/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx index 80ff1f456b9..204d054a991 100644 --- a/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx +++ b/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx @@ -211,7 +211,7 @@ export function CustomLLMProviderUpdateForm({ > {({ values, setFieldValue }) => { return ( -
+ void; existingLlmProvider?: FullLLMProvider; shouldMarkAsDefault?: boolean; + hideAdvanced?: boolean; setPopup?: (popup: PopupSpec) => void; }) { const { mutate } = useSWRConfig(); @@ -52,7 +54,7 @@ export function LLMProviderUpdateForm({ // Define the initial values based on the provider's requirements const initialValues = { - name: existingLlmProvider?.name ?? "", + name: existingLlmProvider?.name || (hideAdvanced ? "Default" : ""), api_key: existingLlmProvider?.api_key ?? "", api_base: existingLlmProvider?.api_base ?? "", api_version: existingLlmProvider?.api_version ?? "", @@ -218,17 +220,20 @@ export function LLMProviderUpdateForm({ }} > {({ values, setFieldValue }) => ( - - + + {!hideAdvanced && ( + + )} {llmProviderDescriptor.api_key_required && ( (
))} - - - {llmProviderDescriptor.llm_names.length > 0 ? ( - ({ - name: getDisplayNameForModel(name), - value: name, - }))} - maxHeight="max-h-56" - /> - ) : ( - - )} + {!hideAdvanced && ( + <> + + + {llmProviderDescriptor.llm_names.length > 0 ? ( + ({ + name: getDisplayNameForModel(name), + value: name, + }))} + maxHeight="max-h-56" + /> + ) : ( + + )} - {llmProviderDescriptor.llm_names.length > 0 ? ( - 0 ? ( + ({ - name: getDisplayNameForModel(name), - value: name, - }))} - includeDefault - maxHeight="max-h-56" - /> - ) : ( - ({ + name: getDisplayNameForModel(name), + value: name, + }))} + includeDefault + maxHeight="max-h-56" + /> + ) : ( + - )} - - + label="[Optional] Fast Model" + placeholder="E.g. gpt-4" + /> + )} - {llmProviderDescriptor.name != "azure" && ( - - )} + - {showAdvancedOptions && ( - <> - {llmProviderDescriptor.llm_names.length > 0 && ( -
- ({ - value: name, - label: getDisplayNameForModel(name), - }))} - onChange={(selected) => - setFieldValue("display_model_names", selected) - } - /> -
+ {llmProviderDescriptor.name != "azure" && ( + )} - {isPaidEnterpriseFeaturesEnabled && userGroups && ( + {showAdvancedOptions && ( <> - - - {userGroups && userGroups.length > 0 && !values.is_public && ( -
- - Select which User Groups should have access to this LLM - Provider. - -
- {userGroups.map((userGroup) => { - const isSelected = values.groups.includes( - userGroup.id - ); - return ( - { - if (isSelected) { - setFieldValue( - "groups", - values.groups.filter( - (id) => id !== userGroup.id - ) - ); - } else { - setFieldValue("groups", [ - ...values.groups, - userGroup.id, - ]); - } - }} - > -
- -
{userGroup.name}
-
-
- ); - })} -
+ {llmProviderDescriptor.llm_names.length > 0 && ( +
+ ({ + value: name, + label: getDisplayNameForModel(name), + }) + )} + onChange={(selected) => + setFieldValue("display_model_names", selected) + } + />
)} + + {isPaidEnterpriseFeaturesEnabled && userGroups && ( + <> + + + {userGroups && + userGroups.length > 0 && + !values.is_public && ( +
+ + Select which User Groups should have access to + this LLM Provider. + +
+ {userGroups.map((userGroup) => { + const isSelected = values.groups.includes( + userGroup.id + ); + return ( + { + if (isSelected) { + setFieldValue( + "groups", + values.groups.filter( + (id) => id !== userGroup.id + ) + ); + } else { + setFieldValue("groups", [ + ...values.groups, + userGroup.id, + ]); + } + }} + > +
+ +
+ {userGroup.name} +
+
+
+ ); + })} +
+
+ )} + + )} )} @@ -432,6 +450,27 @@ export function LLMProviderUpdateForm({ return; } + // If the deleted provider was the default, set the first remaining provider as default + const remainingProvidersResponse = await fetch( + LLM_PROVIDERS_ADMIN_URL + ); + if (remainingProvidersResponse.ok) { + const remainingProviders = + await remainingProvidersResponse.json(); + + if (remainingProviders.length > 0) { + const setDefaultResponse = await fetch( + `${LLM_PROVIDERS_ADMIN_URL}/${remainingProviders[0].id}/default`, + { + method: "POST", + } + ); + if (!setDefaultResponse.ok) { + console.error("Failed to set new default provider"); + } + } + } + mutate(LLM_PROVIDERS_ADMIN_URL); onClose(); }} diff --git a/web/src/app/admin/configuration/llm/constants.ts b/web/src/app/admin/configuration/llm/constants.ts index a265f4a2b2d..d7e3449b34d 100644 --- a/web/src/app/admin/configuration/llm/constants.ts +++ b/web/src/app/admin/configuration/llm/constants.ts @@ -2,3 +2,5 @@ export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider"; export const EMBEDDING_PROVIDERS_ADMIN_URL = "/api/admin/embedding/embedding-provider"; + +export const EMBEDDING_MODELS_ADMIN_URL = "/api/admin/embedding"; diff --git a/web/src/app/admin/configuration/llm/interfaces.ts b/web/src/app/admin/configuration/llm/interfaces.ts index 2d0d49196b4..33fa94d7f15 100644 --- a/web/src/app/admin/configuration/llm/interfaces.ts +++ b/web/src/app/admin/configuration/llm/interfaces.ts @@ -1,3 +1,13 @@ +import { + AnthropicIcon, + AWSIcon, + AzureIcon, + CPUIcon, + OpenAIIcon, + OpenSourceIcon, +} from "@/components/icons/icons"; +import { FaRobot } from "react-icons/fa"; + export interface CustomConfigKey { name: string; description: string | null; @@ -53,3 +63,18 @@ export interface LLMProviderDescriptor { groups: number[]; display_model_names: string[] | null; } + +export const getProviderIcon = (providerName: string) => { + switch (providerName) { + case "openai": + return OpenAIIcon; + case "anthropic": + return AnthropicIcon; + case "bedrock": + return AWSIcon; + case "azure": + return AzureIcon; + default: + return CPUIcon; + } +}; diff --git a/web/src/app/admin/configuration/search/UpgradingPage.tsx b/web/src/app/admin/configuration/search/UpgradingPage.tsx index da379656336..ff707e932c1 100644 --- a/web/src/app/admin/configuration/search/UpgradingPage.tsx +++ b/web/src/app/admin/configuration/search/UpgradingPage.tsx @@ -1,9 +1,9 @@ import { ThreeDotsLoader } from "@/components/Loading"; import { Modal } from "@/components/Modal"; import { errorHandlingFetcher } from "@/lib/fetcher"; -import { ConnectorIndexingStatus } from "@/lib/types"; +import { ConnectorIndexingStatus, ValidStatuses } from "@/lib/types"; import { Button, Text, Title } from "@tremor/react"; -import { useState } from "react"; +import { useMemo, useState } from "react"; import useSWR, { mutate } from "swr"; import { ReindexingProgressTable } from "../../../../components/embedding/ReindexingProgressTable"; import { ErrorCallout } from "@/components/ErrorCallout"; @@ -48,6 +48,29 @@ export default function UpgradingPage({ } setIsCancelling(false); }; + const statusOrder: Record = { + failed: 0, + completed_with_errors: 1, + not_started: 2, + in_progress: 3, + success: 4, + }; + + const sortedReindexingProgress = useMemo(() => { + return [...(ongoingReIndexingStatus || [])].sort((a, b) => { + const statusComparison = + statusOrder[a.latest_index_attempt?.status || "not_started"] - + statusOrder[b.latest_index_attempt?.status || "not_started"]; + + if 
(statusComparison !== 0) { + return statusComparison; + } + + return ( + (a.latest_index_attempt?.id || 0) - (b.latest_index_attempt?.id || 0) + ); + }); + }, [ongoingReIndexingStatus]); return ( <> @@ -101,9 +124,9 @@ export default function UpgradingPage({ {isLoadingOngoingReIndexingStatus ? ( - ) : ongoingReIndexingStatus ? ( + ) : sortedReindexingProgress ? ( ) : ( diff --git a/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx b/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx index b9861a29759..d1e6d01964b 100644 --- a/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx +++ b/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx @@ -1,5 +1,6 @@ "use client"; +import { useEffect, useRef } from "react"; import { Table, TableHead, @@ -8,31 +9,172 @@ import { TableBody, TableCell, Text, - Button, - Divider, } from "@tremor/react"; -import { IndexAttemptStatus } from "@/components/Status"; import { CCPairFullInfo } from "./types"; +import { IndexAttemptStatus } from "@/components/Status"; import { useState } from "react"; import { PageSelector } from "@/components/PageSelector"; +import { ThreeDotsLoader } from "@/components/Loading"; +import { buildCCPairInfoUrl } from "./lib"; import { localizeAndPrettify } from "@/lib/time"; import { getDocsProcessedPerMinute } from "@/lib/indexAttempt"; -import { Modal } from "@/components/Modal"; -import { CheckmarkIcon, CopyIcon, SearchIcon } from "@/components/icons/icons"; +import { ErrorCallout } from "@/components/ErrorCallout"; +import { SearchIcon } from "@/components/icons/icons"; import Link from "next/link"; import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal"; +import { PaginatedIndexAttempts } from "./types"; +import { useRouter } from "next/navigation"; +// This is the number of index attempts to display per page const NUM_IN_PAGE = 8; +// This is the number of pages to fetch at a time +const BATCH_SIZE = 8; export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) { - const [page, setPage] = useState(1); const [indexAttemptTracePopupId, setIndexAttemptTracePopupId] = useState< number | null >(null); - const indexAttemptToDisplayTraceFor = ccPair.index_attempts.find( + + const totalPages = Math.ceil(ccPair.number_of_index_attempts / NUM_IN_PAGE); + + const router = useRouter(); + const [page, setPage] = useState(() => { + if (typeof window !== "undefined") { + const urlParams = new URLSearchParams(window.location.search); + return parseInt(urlParams.get("page") || "1", 10); + } + return 1; + }); + + const [currentPageData, setCurrentPageData] = + useState(null); + const [currentPageError, setCurrentPageError] = useState(null); + const [isCurrentPageLoading, setIsCurrentPageLoading] = useState(false); + + // This is a cache of the data for each "batch" which is a set of pages + const [cachedBatches, setCachedBatches] = useState<{ + [key: number]: PaginatedIndexAttempts[]; + }>({}); + + // This is a set of the batches that are currently being fetched + // we use it to avoid duplicate requests + const ongoingRequestsRef = useRef>(new Set()); + + const batchRetrievalUrlBuilder = (batchNum: number) => + `${buildCCPairInfoUrl(ccPair.id)}/index-attempts?page=${batchNum}&page_size=${BATCH_SIZE * NUM_IN_PAGE}`; + + // This fetches and caches the data for a given batch number + const fetchBatchData = async (batchNum: number) => { + if (ongoingRequestsRef.current.has(batchNum)) return; + ongoingRequestsRef.current.add(batchNum); + + try { + const response 
= await fetch(batchRetrievalUrlBuilder(batchNum + 1)); + if (!response.ok) { + throw new Error("Failed to fetch data"); + } + const data = await response.json(); + + const newBatchData: PaginatedIndexAttempts[] = []; + for (let i = 0; i < BATCH_SIZE; i++) { + const startIndex = i * NUM_IN_PAGE; + const endIndex = startIndex + NUM_IN_PAGE; + const pageIndexAttempts = data.index_attempts.slice( + startIndex, + endIndex + ); + newBatchData.push({ + ...data, + index_attempts: pageIndexAttempts, + }); + } + + setCachedBatches((prev) => ({ + ...prev, + [batchNum]: newBatchData, + })); + } catch (error) { + setCurrentPageError( + error instanceof Error ? error : new Error("An error occurred") + ); + } finally { + ongoingRequestsRef.current.delete(batchNum); + } + }; + + // This fetches and caches the data for the current batch and the next and previous batches + useEffect(() => { + const batchNum = Math.floor((page - 1) / BATCH_SIZE); + + if (!cachedBatches[batchNum]) { + setIsCurrentPageLoading(true); + fetchBatchData(batchNum); + } else { + setIsCurrentPageLoading(false); + } + + const nextBatchNum = Math.min( + batchNum + 1, + Math.ceil(totalPages / BATCH_SIZE) - 1 + ); + if (!cachedBatches[nextBatchNum]) { + fetchBatchData(nextBatchNum); + } + + const prevBatchNum = Math.max(batchNum - 1, 0); + if (!cachedBatches[prevBatchNum]) { + fetchBatchData(prevBatchNum); + } + + // Always fetch the first batch if it's not cached + if (!cachedBatches[0]) { + fetchBatchData(0); + } + }, [ccPair.id, page, cachedBatches, totalPages]); + + // This updates the data on the current page + useEffect(() => { + const batchNum = Math.floor((page - 1) / BATCH_SIZE); + const batchPageNum = (page - 1) % BATCH_SIZE; + + if (cachedBatches[batchNum] && cachedBatches[batchNum][batchPageNum]) { + setCurrentPageData(cachedBatches[batchNum][batchPageNum]); + setIsCurrentPageLoading(false); + } else { + setIsCurrentPageLoading(true); + } + }, [page, cachedBatches]); + + // This updates the page number and manages the URL + const updatePage = (newPage: number) => { + setPage(newPage); + router.push(`/admin/connector/${ccPair.id}?page=${newPage}`, { + scroll: false, + }); + window.scrollTo({ + top: 0, + left: 0, + behavior: "smooth", + }); + }; + + if (isCurrentPageLoading || !currentPageData) { + return ; + } + + if (currentPageError) { + return ( + + ); + } + + // This is the index attempt that the user wants to view the trace for + const indexAttemptToDisplayTraceFor = currentPageData?.index_attempts?.find( (indexAttempt) => indexAttempt.id === indexAttemptTracePopupId ); - const [copyClicked, setCopyClicked] = useState(false); return ( <> @@ -55,101 +197,92 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) { - {ccPair.index_attempts - .slice(NUM_IN_PAGE * (page - 1), NUM_IN_PAGE * page) - .map((indexAttempt) => { - const docsPerMinute = - getDocsProcessedPerMinute(indexAttempt)?.toFixed(2); - return ( - - - {indexAttempt.time_started - ? localizeAndPrettify(indexAttempt.time_started) - : "-"} - - - - {docsPerMinute && ( -
- {docsPerMinute} docs / min -
- )} -
- -
-
-
{indexAttempt.new_docs_indexed}
- {indexAttempt.docs_removed_from_index > 0 && ( -
- (also removed {indexAttempt.docs_removed_from_index}{" "} - docs that were detected as deleted in the source) -
- )} -
+ {currentPageData.index_attempts.map((indexAttempt) => { + const docsPerMinute = + getDocsProcessedPerMinute(indexAttempt)?.toFixed(2); + return ( + + + {indexAttempt.time_started + ? localizeAndPrettify(indexAttempt.time_started) + : "-"} + + + + {docsPerMinute && ( +
+ {docsPerMinute} docs / min
-
- {indexAttempt.total_docs_indexed} - -
- {indexAttempt.error_count > 0 && ( - - - -  View Errors - - + )} + + +
+
+
{indexAttempt.new_docs_indexed}
+ {indexAttempt.docs_removed_from_index > 0 && ( +
+ (also removed {indexAttempt.docs_removed_from_index}{" "} + docs that were detected as deleted in the source) +
)} +
+
+
+ {indexAttempt.total_docs_indexed} + +
+ {indexAttempt.error_count > 0 && ( + + + +  View Errors + + + )} - {indexAttempt.status === "success" && ( + {indexAttempt.status === "success" && ( + + {"-"} + + )} + + {indexAttempt.status === "failed" && + indexAttempt.error_msg && ( - {"-"} + {indexAttempt.error_msg} )} - {indexAttempt.status === "failed" && - indexAttempt.error_msg && ( - - {indexAttempt.error_msg} - - )} - - {indexAttempt.full_exception_trace && ( -
{ - setIndexAttemptTracePopupId(indexAttempt.id); - }} - className="mt-2 text-link cursor-pointer select-none" - > - View Full Trace -
- )} -
-
- - ); - })} + {indexAttempt.full_exception_trace && ( +
{ + setIndexAttemptTracePopupId(indexAttempt.id); + }} + className="mt-2 text-link cursor-pointer select-none" + > + View Full Trace +
+ )} +
+
+
+ ); + })} - {ccPair.index_attempts.length > NUM_IN_PAGE && ( + {totalPages > 1 && (
{ - setPage(newPage); - window.scrollTo({ - top: 0, - left: 0, - behavior: "smooth", - }); - }} + onPageChange={updatePage} />
diff --git a/web/src/app/admin/connector/[ccPairId]/page.tsx b/web/src/app/admin/connector/[ccPairId]/page.tsx index f5da225a867..f2e8a8de8cc 100644 --- a/web/src/app/admin/connector/[ccPairId]/page.tsx +++ b/web/src/app/admin/connector/[ccPairId]/page.tsx @@ -1,7 +1,6 @@ "use client"; import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types"; -import { HealthCheckBanner } from "@/components/health/healthcheck"; import { CCPairStatus } from "@/components/Status"; import { BackButton } from "@/components/BackButton"; import { Button, Divider, Title } from "@tremor/react"; @@ -11,7 +10,6 @@ import { ModifyStatusButtonCluster } from "./ModifyStatusButtonCluster"; import { DeletionButton } from "./DeletionButton"; import { ErrorCallout } from "@/components/ErrorCallout"; import { ReIndexButton } from "./ReIndexButton"; -import { isCurrentlyDeleting } from "@/lib/documentDeletion"; import { ValidSources } from "@/lib/types"; import useSWR, { mutate } from "swr"; import { errorHandlingFetcher } from "@/lib/fetcher"; @@ -86,24 +84,13 @@ function Main({ ccPairId }: { ccPairId: number }) { return ( ); } - const lastIndexAttempt = ccPair.index_attempts[0]; const isDeleting = ccPair.status === ConnectorCredentialPairStatus.DELETING; - // figure out if we need to artificially deflate the number of docs indexed. - // This is required since the total number of docs indexed by a CC Pair is - // updated before the new docs for an indexing attempt. If we don't do this, - // there is a mismatch between these two numbers which may confuse users. - const totalDocsIndexed = - lastIndexAttempt?.status === "in_progress" && - ccPair.index_attempts.length === 1 - ? lastIndexAttempt.total_docs_indexed - : ccPair.num_docs_indexed; - const refresh = () => { mutate(buildCCPairInfoUrl(ccPairId)); }; @@ -182,13 +169,13 @@ function Main({ ccPairId }: { ccPairId: number }) { )}
Total Documents Indexed:{" "} - {totalDocsIndexed} + {ccPair.num_docs_indexed}
{!ccPair.is_editable_for_current_user && (
diff --git a/web/src/app/admin/connector/[ccPairId]/types.ts b/web/src/app/admin/connector/[ccPairId]/types.ts index 1cc43311e21..f44b958b095 100644 --- a/web/src/app/admin/connector/[ccPairId]/types.ts +++ b/web/src/app/admin/connector/[ccPairId]/types.ts @@ -1,6 +1,10 @@ import { Connector } from "@/lib/connectors/connectors"; import { Credential } from "@/lib/connectors/credentials"; -import { DeletionAttemptSnapshot, IndexAttemptSnapshot } from "@/lib/types"; +import { + DeletionAttemptSnapshot, + IndexAttemptSnapshot, + ValidStatuses, +} from "@/lib/types"; export enum ConnectorCredentialPairStatus { ACTIVE = "ACTIVE", @@ -15,8 +19,15 @@ export interface CCPairFullInfo { num_docs_indexed: number; connector: Connector; credential: Credential; - index_attempts: IndexAttemptSnapshot[]; + number_of_index_attempts: number; + last_index_attempt_status: ValidStatuses | null; latest_deletion_attempt: DeletionAttemptSnapshot | null; is_public: boolean; is_editable_for_current_user: boolean; } + +export interface PaginatedIndexAttempts { + index_attempts: IndexAttemptSnapshot[]; + page: number; + total_pages: number; +} diff --git a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx index dd8d19ca720..895abf95d1a 100644 --- a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx +++ b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx @@ -22,7 +22,7 @@ import AdvancedFormPage from "./pages/Advanced"; import DynamicConnectionForm from "./pages/DynamicConnectorCreationForm"; import CreateCredential from "@/components/credentials/actions/CreateCredential"; import ModifyCredential from "@/components/credentials/actions/ModifyCredential"; -import { ValidSources } from "@/lib/types"; +import { ConfigurableSources, ValidSources } from "@/lib/types"; import { Credential, credentialTemplates } from "@/lib/connectors/credentials"; import { ConnectionConfiguration, @@ -44,7 +44,6 @@ import { IsPublicGroupSelectorFormType, } from "@/components/IsPublicGroupSelector"; import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; -import { AdminBooleanFormField } from "@/components/credentials/CredentialFields"; export type AdvancedConfigFinal = { pruneFreq: number | null; @@ -55,7 +54,7 @@ export type AdvancedConfigFinal = { export default function AddConnector({ connector, }: { - connector: ValidSources; + connector: ConfigurableSources; }) { const [currentCredential, setCurrentCredential] = useState | null>(null); @@ -92,10 +91,12 @@ export default function AddConnector({ >({ name: "", groups: [], - is_public: false, + is_public: true, ...configuration.values.reduce( (acc, field) => { - if (field.type === "list") { + if (field.type === "select") { + acc[field.name] = null; + } else if (field.type === "list") { acc[field.name] = field.default || []; } else if (field.type === "checkbox") { acc[field.name] = field.default || false; @@ -196,7 +197,7 @@ export default function AddConnector({ }; // google sites-specific handling - if (connector == "google_site") { + if (connector == "google_sites") { const response = await submitGoogleSite( selectedFiles, formValues?.base_url, @@ -338,11 +339,13 @@ export default function AddConnector({ ...configuration.values.reduce( (acc, field) => { let schema: any = - field.type === "list" - ? Yup.array().of(Yup.string()) - : field.type === "checkbox" - ? Yup.boolean() - : Yup.string(); + field.type === "select" + ? 
Yup.string() + : field.type === "list" + ? Yup.array().of(Yup.string()) + : field.type === "checkbox" + ? Yup.boolean() + : Yup.string(); if (!field.optional) { schema = schema.required(`${field.label} is required`); @@ -444,26 +447,29 @@ export default function AddConnector({ )} - {!(connector == "google_drive") && createConnectorToggle && ( - setCreateConnectorToggle(false)} - > - <> - - Create a {getSourceDisplayName(connector)} credential - - setCreateConnectorToggle(false)} - /> - - - )} + {/* NOTE: connector will never be google_drive, since the ternary above will + prevent that, but still keeping this here for safety in case the above changes. */} + {(connector as ValidSources) !== "google_drive" && + createConnectorToggle && ( + setCreateConnectorToggle(false)} + > + <> + + Create a {getSourceDisplayName(connector)} credential + + setCreateConnectorToggle(false)} + /> + + + )}
) : ( - + )}
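Context for the AddConnectorPage hunks above: initial Formik values and the Yup validation schema are both derived from each connector field's declared type, with the new "select" type starting as null and validating as a plain string. The sketch below is a minimal, self-contained restatement of that pattern; the ConnectionField shape and helper names are illustrative assumptions, not the project's actual definitions.

// Minimal sketch of the field-type -> defaults/validation mapping used above.
// ConnectionField is an assumed shape for illustration only.
import * as Yup from "yup";

type ConnectionField = {
  name: string;
  label: string;
  type: "text" | "select" | "list" | "checkbox";
  optional?: boolean;
  default?: string[] | boolean;
};

function buildInitialValues(fields: ConnectionField[]) {
  return fields.reduce<Record<string, unknown>>((acc, field) => {
    if (field.type === "select") {
      acc[field.name] = null; // selects start unselected
    } else if (field.type === "list") {
      acc[field.name] = field.default || [];
    } else if (field.type === "checkbox") {
      acc[field.name] = field.default || false;
    } else {
      acc[field.name] = "";
    }
    return acc;
  }, {});
}

function buildValidationSchema(fields: ConnectionField[]) {
  const shape: Record<string, any> = {};
  for (const field of fields) {
    let schema: any =
      field.type === "select"
        ? Yup.string()
        : field.type === "list"
          ? Yup.array().of(Yup.string())
          : field.type === "checkbox"
            ? Yup.boolean()
            : Yup.string();
    if (!field.optional) {
      schema = schema.required(`${field.label} is required`);
    }
    shape[field.name] = schema;
  }
  return Yup.object().shape(shape);
}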
diff --git a/web/src/app/admin/connectors/[connector]/Sidebar.tsx b/web/src/app/admin/connectors/[connector]/Sidebar.tsx index 5d678938dc7..e0c85029fb0 100644 --- a/web/src/app/admin/connectors/[connector]/Sidebar.tsx +++ b/web/src/app/admin/connectors/[connector]/Sidebar.tsx @@ -25,9 +25,10 @@ export default function Sidebar() { ]; return ( -
+
; + return ( + + ); } diff --git a/web/src/app/admin/connectors/[connector]/pages/Advanced.tsx b/web/src/app/admin/connectors/[connector]/pages/Advanced.tsx index 470ab8d2a77..8bb96d54db1 100644 --- a/web/src/app/admin/connectors/[connector]/pages/Advanced.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/Advanced.tsx @@ -25,7 +25,6 @@ const AdvancedFormPage = forwardRef, AdvancedFormPageProps>(
, AdvancedFormPageProps>( diff --git a/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/NumberInput.tsx b/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/NumberInput.tsx index 5a9f5041b5d..a62864495ef 100644 --- a/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/NumberInput.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/NumberInput.tsx @@ -1,5 +1,5 @@ import { SubLabel } from "@/components/admin/connectors/Field"; -import { Field } from "formik"; +import { Field, useFormikContext } from "formik"; export default function NumberInput({ label, @@ -8,6 +8,7 @@ export default function NumberInput({ description, name, showNeverIfZero, + onChange, }: { value?: number; label: string; @@ -15,7 +16,10 @@ export default function NumberInput({ optional?: boolean; description?: string; showNeverIfZero?: boolean; + onChange?: (value: number) => void; }) { + const { setFieldValue } = useFormikContext(); + return (
); } diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx index 4494e4b22ee..247b64e61b4 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx @@ -8,8 +8,6 @@ import { ErrorCallout } from "@/components/ErrorCallout"; import { LoadingAnimation } from "@/components/Loading"; import { usePopup } from "@/components/admin/connectors/Popup"; import { ConnectorIndexingStatus } from "@/lib/types"; -import { getCurrentUser } from "@/lib/user"; -import { User, UserRole } from "@/lib/types"; import { usePublicCredentials } from "@/lib/hooks"; import { Title } from "@tremor/react"; import { DriveJsonUploadSection, DriveOAuthSection } from "./Credential"; @@ -109,6 +107,7 @@ const GDriveMain = ({}: {}) => { | undefined = credentialsData.find( (credential) => credential.credential_json?.google_drive_service_account_key ); + const googleDriveConnectorIndexingStatuses: ConnectorIndexingStatus< GoogleDriveConfig, GoogleDriveCredentialJson diff --git a/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx b/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx index 814af4e2863..2778103e345 100644 --- a/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx +++ b/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx @@ -125,14 +125,12 @@ export const DocumentSetCreationForm = ({ placeholder="Describe what the document set represents" autoCompleteDisabled={true} /> - {isPaidEnterpriseFeaturesEnabled && - userGroups && - userGroups.length > 0 && ( - - )} + {isPaidEnterpriseFeaturesEnabled && ( + + )} diff --git a/web/src/app/admin/documents/sets/page.tsx b/web/src/app/admin/documents/sets/page.tsx index 718b81ab0b6..41104f9c343 100644 --- a/web/src/app/admin/documents/sets/page.tsx +++ b/web/src/app/admin/documents/sets/page.tsx @@ -67,10 +67,11 @@ const EditRow = ({
)}
{ if (documentSet.is_up_to_date) { router.push(`/admin/documents/sets/${documentSet.id}`); @@ -87,8 +88,8 @@ const EditRow = ({ } }} > - - {documentSet.name} + + {documentSet.name}
); diff --git a/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx b/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx index 1b9fffda428..2b4394c56b5 100644 --- a/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx +++ b/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx @@ -24,10 +24,14 @@ import { ChangeCredentialsModal } from "./modals/ChangeCredentialsModal"; import { ModelSelectionConfirmationModal } from "./modals/ModelSelectionModal"; import { AlreadyPickedModal } from "./modals/AlreadyPickedModal"; import { ModelOption } from "../../../components/embedding/ModelSelector"; -import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../configuration/llm/constants"; +import { + EMBEDDING_MODELS_ADMIN_URL, + EMBEDDING_PROVIDERS_ADMIN_URL, +} from "../configuration/llm/constants"; export interface EmbeddingDetails { - api_key: string; + api_key?: string; + api_url?: string; custom_config: any; provider_type: EmbeddingProvider; } @@ -77,12 +81,20 @@ export function EmbeddingModelSelection({ const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] = useState(false); + const [showAddConnectorPopup, setShowAddConnectorPopup] = useState(false); + const { data: embeddingModelDetails } = useSWR( + EMBEDDING_MODELS_ADMIN_URL, + errorHandlingFetcher, + { refreshInterval: 5000 } // 5 seconds + ); + const { data: embeddingProviderDetails } = useSWR( EMBEDDING_PROVIDERS_ADMIN_URL, - errorHandlingFetcher + errorHandlingFetcher, + { refreshInterval: 5000 } // 5 seconds ); const { data: connectors } = useSWR[]>( @@ -175,6 +187,7 @@ export function EmbeddingModelSelection({ {showTentativeProvider && ( { setShowTentativeProvider(showUnconfiguredProvider); @@ -189,8 +202,10 @@ export function EmbeddingModelSelection({ }} /> )} + {changeCredentialsProvider && ( { clientsideRemoveProvider(changeCredentialsProvider); @@ -277,9 +292,10 @@ export function EmbeddingModelSelection({ {modelTab == "cloud" && ( { const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false); + const [showLiteLLMConfigurationModal, setShowLiteLLMConfigurationModal] = + useState(false); return ( -
-

- Post-processing -

-
- {originalRerankingDetails.rerank_model_name && ( - - )} -
- -
+ () + .nullable() + .oneOf(Object.values(RerankerProvider)) + .optional(), + api_key: Yup.string().nullable(), + num_rerank: Yup.number().min(1, "Must be at least 1"), + rerank_api_url: Yup.string() + .url("Must be a valid URL") + .matches(/^https?:\/\//, "URL must start with http:// or https://") + .nullable(), + })} + onSubmit={async (_, { setSubmitting }) => { + setSubmitting(false); + }} + enableReinitialize={true} + > + {({ values, setFieldValue, resetForm }) => { + const resetRerankingValues = () => { + setRerankingDetails({ + ...values, + rerank_provider_type: null!, + rerank_model_name: null, + }); + setFieldValue("rerank_provider_type", null); + setFieldValue("rerank_model_name", null); + setFieldValue("rerank_api_key", null); + }; -
- -
-
+ return ( +
+

+ Post-processing +

+
+ {originalRerankingDetails.rerank_model_name && ( + + )} +
+ +
- () - .nullable() - .oneOf(Object.values(RerankerProvider)) - .optional(), - api_key: Yup.string().nullable(), - num_rerank: Yup.number().min(1, "Must be at least 1"), - })} - onSubmit={async (_, { setSubmitting }) => { - setSubmitting(false); - }} - enableReinitialize={true} - > - {({ values, setFieldValue }) => ( -
-
- {(modelTab - ? rerankingModels.filter( - (model) => model.cloud == (modelTab == "cloud") - ) - : rerankingModels.filter( - (modelCard) => - modelCard.modelName == - originalRerankingDetails.rerank_model_name - ) - ).map((card) => { - const isSelected = - values.rerank_provider_type === card.rerank_provider_type && - values.rerank_model_name === card.modelName; - return ( -
{ - if (card.rerank_provider_type) { - setIsApiKeyModalOpen(true); - } - setRerankingDetails({ - ...values, - rerank_provider_type: card.rerank_provider_type!, - rerank_model_name: card.modelName, - }); - setFieldValue( - "rerank_provider_type", - card.rerank_provider_type - ); - setFieldValue("rerank_model_name", card.modelName); - }} +
+ +
+ {values.rerank_model_name && ( +
+ +
+ )} +
+ + +
+ {(modelTab + ? rerankingModels.filter( + (model) => model.cloud == (modelTab == "cloud") + ) + : rerankingModels.filter( + (modelCard) => + (modelCard.modelName == + originalRerankingDetails.rerank_model_name && + modelCard.rerank_provider_type == + originalRerankingDetails.rerank_provider_type) || + (modelCard.rerank_provider_type == + RerankerProvider.LITELLM && + originalRerankingDetails.rerank_provider_type == + RerankerProvider.LITELLM) + ) + ).map((card) => { + const isSelected = + values.rerank_provider_type === + card.rerank_provider_type && + (card.modelName == null || + values.rerank_model_name === card.modelName); + + return ( +
{ + if ( + card.rerank_provider_type == RerankerProvider.COHERE + ) { + setIsApiKeyModalOpen(true); + } else if ( + card.rerank_provider_type == + RerankerProvider.LITELLM + ) { + setShowLiteLLMConfigurationModal(true); + } + + if (!isSelected) { + setRerankingDetails({ + ...values, + rerank_provider_type: card.rerank_provider_type!, + rerank_model_name: card.modelName || null, + rerank_api_key: null, + rerank_api_url: null, + }); + setFieldValue( + "rerank_provider_type", + card.rerank_provider_type + ); + setFieldValue("rerank_model_name", card.modelName); + } + }} + > +
+
+ {card.rerank_provider_type === + RerankerProvider.LITELLM ? ( + + ) : RerankerProvider.COHERE ? ( + + ) : ( + + )} +

+ {card.displayName} +

+
+ {card.link && ( + e.stopPropagation()} + className="text-blue-500 hover:text-blue-700 transition-colors duration-200" + > + + )} -

- {card.displayName} -

- {card.link && ( - e.stopPropagation()} - className="text-blue-500 hover:text-blue-700 transition-colors duration-200" - > - - - )} +

+ {card.description} +

+
+ {card.cloud ? "Cloud-based" : "Self-hosted"} +
-

- {card.description} -

-
- {card.cloud ? "Cloud-based" : "Self-hosted"} + ); + })} +
+ + {showLiteLLMConfigurationModal && ( + { + resetForm(); + setShowLiteLLMConfigurationModal(false); + }} + width="w-[800px]" + title="API Key Configuration" + > +
+ ) => { + const value = e.target.value; + setRerankingDetails({ + ...values, + rerank_api_url: value, + }); + setFieldValue("rerank_api_url", value); + }} + type="text" + label="LiteLLM Proxy URL" + name="rerank_api_url" + /> + + ) => { + const value = e.target.value; + setRerankingDetails({ + ...values, + rerank_api_key: value, + }); + setFieldValue("rerank_api_key", value); + }} + type="password" + label="LiteLLM Proxy Key" + name="rerank_api_key" + optional + /> + + ) => { + const value = e.target.value; + setRerankingDetails({ + ...values, + rerank_model_name: value, + }); + setFieldValue("rerank_model_name", value); + }} + label="LiteLLM Model Name" + name="rerank_model_name" + optional + /> + +
+
- ); - })} -
+ + )} - {isApiKeyModalOpen && ( - { - Object.keys(originalRerankingDetails).forEach((key) => { - setFieldValue( - key, - originalRerankingDetails[key as keyof RerankingDetails] - ); - }); - - setIsApiKeyModalOpen(false); - }} - width="w-[800px]" - title="API Key Configuration" - > -
- ) => { - const value = e.target.value; - setRerankingDetails({ ...values, api_key: value }); - setFieldValue("api_key", value); - }} - type="password" - label="Cohere API Key" - name="api_key" - /> -
- - + type="password" + label="Cohere API Key" + name="api_key" + /> +
+ +
-
-
- )} - - )} - -
+ + )} + +
+ ); + }} + ); } ); diff --git a/web/src/app/admin/embeddings/interfaces.ts b/web/src/app/admin/embeddings/interfaces.ts index 335292f73e6..70afb9830f8 100644 --- a/web/src/app/admin/embeddings/interfaces.ts +++ b/web/src/app/admin/embeddings/interfaces.ts @@ -4,12 +4,14 @@ import { NonNullChain } from "typescript"; export interface RerankingDetails { rerank_model_name: string | null; rerank_provider_type: RerankerProvider | null; - api_key: string | null; + rerank_api_key: string | null; + rerank_api_url: string | null; num_rerank: number; } export enum RerankerProvider { COHERE = "cohere", + LITELLM = "litellm", } export interface AdvancedSearchConfiguration { model_name: string; @@ -21,6 +23,7 @@ export interface AdvancedSearchConfiguration { multipass_indexing: boolean; multilingual_expansion: string[]; disable_rerank_for_streaming: boolean; + api_url: string | null; } export interface SavedSearchSettings extends RerankingDetails { @@ -33,12 +36,13 @@ export interface SavedSearchSettings extends RerankingDetails { multipass_indexing: boolean; multilingual_expansion: string[]; disable_rerank_for_streaming: boolean; + api_url: string | null; provider_type: EmbeddingProvider | null; } export interface RerankingModel { rerank_provider_type: RerankerProvider | null; - modelName: string; + modelName?: string; displayName: string; description: string; link: string; @@ -46,6 +50,13 @@ export interface RerankingModel { } export const rerankingModels: RerankingModel[] = [ + { + rerank_provider_type: RerankerProvider.LITELLM, + cloud: true, + displayName: "LiteLLM", + description: "Host your own reranker or router with LiteLLM proxy", + link: "https://docs.litellm.ai/docs/proxy", + }, { rerank_provider_type: null, cloud: false, diff --git a/web/src/app/admin/embeddings/modals/ChangeCredentialsModal.tsx b/web/src/app/admin/embeddings/modals/ChangeCredentialsModal.tsx index c2f3923e5cd..636aa562474 100644 --- a/web/src/app/admin/embeddings/modals/ChangeCredentialsModal.tsx +++ b/web/src/app/admin/embeddings/modals/ChangeCredentialsModal.tsx @@ -15,14 +15,19 @@ export function ChangeCredentialsModal({ onCancel, onDeleted, useFileUpload, + isProxy = false, }: { provider: CloudEmbeddingProvider; onConfirm: () => void; onCancel: () => void; onDeleted: () => void; useFileUpload: boolean; + isProxy?: boolean; }) { const [apiKey, setApiKey] = useState(""); + const [apiUrl, setApiUrl] = useState(""); + const [modelName, setModelName] = useState(""); + const [testError, setTestError] = useState(""); const [fileName, setFileName] = useState(""); const fileInputRef = useRef(null); @@ -74,7 +79,7 @@ export function ChangeCredentialsModal({ try { const response = await fetch( - `${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type}`, + `${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type.toLowerCase()}`, { method: "DELETE", } @@ -99,13 +104,18 @@ export function ChangeCredentialsModal({ const handleSubmit = async () => { setTestError(""); + const normalizedProviderType = provider.provider_type + .toLowerCase() + .split(" ")[0]; try { const testResponse = await fetch("/api/admin/embedding/test-embedding", { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ - provider_type: provider.provider_type.toLowerCase().split(" ")[0], + provider_type: normalizedProviderType, api_key: apiKey, + api_url: apiUrl, + model_name: modelName, }), }); @@ -118,8 +128,9 @@ export function ChangeCredentialsModal({ method: "PUT", headers: { "Content-Type": "application/json" }, body: 
JSON.stringify({ - provider_type: provider.provider_type.toLowerCase().split(" ")[0], + provider_type: normalizedProviderType, api_key: apiKey, + api_url: apiUrl, is_default_provider: false, is_configured: true, }), @@ -128,7 +139,8 @@ export function ChangeCredentialsModal({ if (!updateResponse.ok) { const errorData = await updateResponse.json(); throw new Error( - errorData.detail || "Failed to update provider- check your API key" + errorData.detail || + `Failed to update provider- check your ${isProxy ? "API URL" : "API key"}` ); } @@ -144,26 +156,20 @@ export function ChangeCredentialsModal({ -
- - Want to swap out your key? - - - Visit API - - -
+ <> +

+ You can modify your configuration by providing a new API key + {isProxy ? " or API URL." : "."} +

+ +
+ {useFileUpload ? ( <> - + )} -
- {testError && ( - - {testError} - - )} + {isProxy && ( + <> + + + setApiUrl(e.target.value)} + placeholder="Paste your API URL here" + /> + + {deletionError && ( + + {deletionError} + + )} + +
+ +

+ Since you are using a liteLLM proxy, we'll need a model + name to test the connection with. +

+
+ setModelName(e.target.value)} + placeholder="Paste your API URL here" + /> + + {deletionError && ( + + {deletionError} + + )} + + )} + + {testError && ( + + {testError} + + )} -
+ + + + + You can also delete your configuration. + + + This is only possible if you have already switched to a different + embedding type! + + + + {deletionError && ( + + {deletionError} + + )}
- - - - You can also delete your key. - - - This is only possible if you have already switched to a different - embedding type! - - - - {deletionError && ( - - {deletionError} - - )} -
+ ); } diff --git a/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx b/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx index 4b2ad9c51fc..54ca5d72e78 100644 --- a/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx +++ b/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx @@ -1,6 +1,6 @@ import React, { useRef, useState } from "react"; import { Text, Button, Callout } from "@tremor/react"; -import { Formik, Form, Field } from "formik"; +import { Formik, Form } from "formik"; import * as Yup from "yup"; import { Label, TextFormField } from "@/components/admin/connectors/Field"; import { LoadingAnimation } from "@/components/Loading"; @@ -13,11 +13,13 @@ export function ProviderCreationModal({ onConfirm, onCancel, existingProvider, + isProxy, }: { selectedProvider: CloudEmbeddingProvider; onConfirm: () => void; onCancel: () => void; existingProvider?: CloudEmbeddingProvider; + isProxy?: boolean; }) { const useFileUpload = selectedProvider.provider_type == "Google"; @@ -29,17 +31,27 @@ export function ProviderCreationModal({ provider_type: existingProvider?.provider_type || selectedProvider.provider_type, api_key: existingProvider?.api_key || "", + api_url: existingProvider?.api_url || "", custom_config: existingProvider?.custom_config ? Object.entries(existingProvider.custom_config) : [], model_id: 0, + model_name: null, }; const validationSchema = Yup.object({ provider_type: Yup.string().required("Provider type is required"), - api_key: useFileUpload + api_key: isProxy ? Yup.string() - : Yup.string().required("API Key is required"), + : useFileUpload + ? Yup.string() + : Yup.string().required("API Key is required"), + model_name: isProxy + ? Yup.string().required("Model name is required") + : Yup.string().nullable(), + api_url: isProxy + ? Yup.string().required("API URL is required") + : Yup.string(), custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)), }); @@ -87,6 +99,8 @@ export function ProviderCreationModal({ body: JSON.stringify({ provider_type: values.provider_type.toLowerCase().split(" ")[0], api_key: values.api_key, + api_url: values.api_url, + model_name: values.model_name, }), } ); @@ -169,11 +183,28 @@ export function ProviderCreationModal({ target="_blank" href={selectedProvider.apiLink} > - API KEY + {isProxy ? "API URL" : "API KEY"} -
+
+ {isProxy && ( + <> + + + + )} + {useFileUpload ? ( <> @@ -189,7 +220,7 @@ export function ProviderCreationModal({ ) : ( diff --git a/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx b/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx index 89d885a1368..4f4df0a465c 100644 --- a/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx +++ b/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx @@ -4,7 +4,7 @@ import * as Yup from "yup"; import CredentialSubText from "@/components/credentials/CredentialFields"; import { TrashIcon } from "@/components/icons/icons"; import { FaPlus } from "react-icons/fa"; -import { AdvancedSearchConfiguration, RerankingDetails } from "../interfaces"; +import { AdvancedSearchConfiguration } from "../interfaces"; import { BooleanFormField } from "@/components/admin/connectors/Field"; import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput"; @@ -15,6 +15,7 @@ interface AdvancedEmbeddingFormPageProps { ) => void; advancedEmbeddingDetails: AdvancedSearchConfiguration; numRerank: number; + updateNumRerank: (value: number) => void; } const AdvancedEmbeddingFormPage = forwardRef< @@ -22,7 +23,12 @@ const AdvancedEmbeddingFormPage = forwardRef< AdvancedEmbeddingFormPageProps >( ( - { updateAdvancedEmbeddingDetails, advancedEmbeddingDetails, numRerank }, + { + updateAdvancedEmbeddingDetails, + advancedEmbeddingDetails, + numRerank, + updateNumRerank, + }, ref ) => { return ( @@ -154,6 +160,10 @@ const AdvancedEmbeddingFormPage = forwardRef< name="disableRerankForStreaming" /> { + updateNumRerank(value); + setFieldValue("num_rerank", value); + }} description="Number of results to rerank" optional={false} value={values.num_rerank} diff --git a/web/src/app/admin/embeddings/pages/CloudEmbeddingPage.tsx b/web/src/app/admin/embeddings/pages/CloudEmbeddingPage.tsx index a7a7a1553a5..a6c71530f24 100644 --- a/web/src/app/admin/embeddings/pages/CloudEmbeddingPage.tsx +++ b/web/src/app/admin/embeddings/pages/CloudEmbeddingPage.tsx @@ -1,6 +1,6 @@ "use client"; -import { Text, Title } from "@tremor/react"; +import { Button, Card, Text, Title } from "@tremor/react"; import { CloudEmbeddingProvider, @@ -8,15 +8,22 @@ import { AVAILABLE_CLOUD_PROVIDERS, CloudEmbeddingProviderFull, EmbeddingModelDescriptor, + EmbeddingProvider, + LITELLM_CLOUD_PROVIDER, } from "../../../../components/embedding/interfaces"; import { EmbeddingDetails } from "../EmbeddingModelSelectionForm"; -import { FiExternalLink, FiInfo } from "react-icons/fi"; +import { FiExternalLink, FiInfo, FiTrash } from "react-icons/fi"; import { HoverPopup } from "@/components/HoverPopup"; -import { Dispatch, SetStateAction } from "react"; +import { Dispatch, SetStateAction, useEffect, useState } from "react"; +import { LiteLLMModelForm } from "@/components/embedding/LiteLLMModelForm"; +import { deleteSearchSettings } from "./utils"; +import { usePopup } from "@/components/admin/connectors/Popup"; +import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal"; export default function CloudEmbeddingPage({ currentModel, embeddingProviderDetails, + embeddingModelDetails, newEnabledProviders, newUnenabledProviders, setShowTentativeProvider, @@ -30,6 +37,7 @@ export default function CloudEmbeddingPage({ currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel; setAlreadySelectedModel: Dispatch>; newUnenabledProviders: string[]; + embeddingModelDetails?: CloudEmbeddingModel[]; embeddingProviderDetails?: EmbeddingDetails[]; 
newEnabledProviders: string[]; setShowTentativeProvider: React.Dispatch< @@ -61,6 +69,17 @@ export default function CloudEmbeddingPage({ ))!), }) ); + const [liteLLMProvider, setLiteLLMProvider] = useState< + EmbeddingDetails | undefined + >(undefined); + + useEffect(() => { + const foundProvider = embeddingProviderDetails?.find( + (provider) => + provider.provider_type === EmbeddingProvider.LITELLM.toLowerCase() + ); + setLiteLLMProvider(foundProvider); + }, [embeddingProviderDetails]); return (
@@ -122,6 +141,127 @@ export default function CloudEmbeddingPage({
))} + + + Alternatively, you can use a self-hosted model using the LiteLLM + proxy. This allows you to leverage various LLM providers through a + unified interface that you control.{" "} + + Learn more about LiteLLM + + + +
+
+ {LITELLM_CLOUD_PROVIDER.icon({ size: 40 })} +

+ {LITELLM_CLOUD_PROVIDER.provider_type}{" "} + {LITELLM_CLOUD_PROVIDER.provider_type == "Cohere" && + "(recommended)"} +

+ + } + popupContent={ +
+
+ {LITELLM_CLOUD_PROVIDER.description} +
+
+ } + style="dark" + /> +
+
+ {!liteLLMProvider ? ( + + ) : ( + + )} + + {!liteLLMProvider && ( + +
+ + API URL Required + + + Before you can add models, you need to provide an API URL + for your LiteLLM proxy. Click the "Provide API + URL" button above to set up your LiteLLM configuration. + +
+ + + Once configured, you'll be able to add and manage + your LiteLLM models here. + +
+
+
+ )} + {liteLLMProvider && ( + <> +
+ {embeddingModelDetails + ?.filter( + (model) => + model.provider_type === + EmbeddingProvider.LITELLM.toLowerCase() + ) + .map((model) => ( + + ))} +
+ + + + + + )} +
+
); @@ -146,7 +286,32 @@ export function CloudModelCard({ React.SetStateAction >; }) { - const enabled = model.model_name === currentModel.model_name; + const { popup, setPopup } = usePopup(); + const [showDeleteModel, setShowDeleteModel] = useState(false); + const enabled = + model.model_name === currentModel.model_name && + model.provider_type?.toLowerCase() == + currentModel.provider_type?.toLowerCase(); + + const deleteModel = async () => { + if (!model.id) { + setPopup({ message: "Model cannot be deleted", type: "error" }); + return; + } + + const response = await deleteSearchSettings(model.id); + + if (response.ok) { + setPopup({ message: "Model deleted successfully", type: "success" }); + setShowDeleteModel(false); + } else { + setPopup({ + message: + "Failed to delete model. Ensure you are not attempting to delete a curently active model.", + type: "error", + }); + } + }; return (
+ {popup} + {showDeleteModel && ( + deleteModel()} + onClose={() => setShowDeleteModel(false)} + /> + )} +

{model.model_name}

- e.stopPropagation()} - className="text-blue-500 hover:text-blue-700 transition-colors duration-200" - > - - +
+ {model.provider_type == EmbeddingProvider.LITELLM.toLowerCase() && ( + + )} + e.stopPropagation()} + className="text-blue-500 hover:text-blue-700 transition-colors duration-200" + > + + +

{model.description}

-
- ${model.pricePerMillion}/M tokens -
+ {model?.provider_type?.toLowerCase() != + EmbeddingProvider.LITELLM.toLowerCase() && ( +
+ ${model.pricePerMillion}/M tokens +
+ )}
+
-
- setSearchTerm(e.target.value)} - className="ml-2 w-96 h-9 flex-none rounded-md border border-border bg-background-50 px-3 py-1 text-sm shadow-sm transition-colors placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring" - /> - - -
{sortedSources - .filter((source) => source != "not_applicable") + .filter( + (source) => + source != "not_applicable" && source != "ingestion_api" + ) .map((source, ind) => { const sourceMatches = source .toLowerCase() @@ -479,7 +467,7 @@ export function CCPairIndexingStatusTable({ if (sourceMatches || matchingConnectors.length > 0) { return ( -
+
- - Name - - - Last Indexed - - - Activity - + Name + Last Indexed + Activity {isPaidEnterpriseFeaturesEnabled && ( - - Permissions - + Permissions )} - - Total Docs - - - Last Status - - + Total Docs + Last Status + {(sourceMatches ? groupedStatuses[source] diff --git a/web/src/app/admin/settings/interfaces.ts b/web/src/app/admin/settings/interfaces.ts index 247bfd09d83..2c315f320ec 100644 --- a/web/src/app/admin/settings/interfaces.ts +++ b/web/src/app/admin/settings/interfaces.ts @@ -15,16 +15,27 @@ export interface Notification { first_shown: string; } +export interface NavigationItem { + link: string; + icon: string; + title: string; +} + export interface EnterpriseSettings { application_name: string | null; use_custom_logo: boolean; use_custom_logotype: boolean; + // custom navigation + custom_nav_items: NavigationItem[]; + // custom Chat components custom_lower_disclaimer_content: string | null; custom_header_content: string | null; + two_lines_for_chat_header: boolean | null; custom_popup_header: string | null; custom_popup_content: string | null; + enable_consent_screen: boolean | null; } export interface CombinedSettings { diff --git a/web/src/app/api/[...path]/route.ts b/web/src/app/api/[...path]/route.ts new file mode 100644 index 00000000000..550ebaf6d1f --- /dev/null +++ b/web/src/app/api/[...path]/route.ts @@ -0,0 +1,116 @@ +import { INTERNAL_URL } from "@/lib/constants"; +import { NextRequest, NextResponse } from "next/server"; + +/* NextJS is annoying and makes use use a separate function for +each request type >:( */ + +export async function GET( + request: NextRequest, + { params }: { params: { path: string[] } } +) { + return handleRequest(request, params.path); +} + +export async function POST( + request: NextRequest, + { params }: { params: { path: string[] } } +) { + return handleRequest(request, params.path); +} + +export async function PUT( + request: NextRequest, + { params }: { params: { path: string[] } } +) { + return handleRequest(request, params.path); +} + +export async function PATCH( + request: NextRequest, + { params }: { params: { path: string[] } } +) { + return handleRequest(request, params.path); +} + +export async function DELETE( + request: NextRequest, + { params }: { params: { path: string[] } } +) { + return handleRequest(request, params.path); +} + +export async function HEAD( + request: NextRequest, + { params }: { params: { path: string[] } } +) { + return handleRequest(request, params.path); +} + +export async function OPTIONS( + request: NextRequest, + { params }: { params: { path: string[] } } +) { + return handleRequest(request, params.path); +} + +async function handleRequest(request: NextRequest, path: string[]) { + if (process.env.NODE_ENV !== "development") { + return NextResponse.json( + { + message: + "This API is only available in development mode. In production, something else (e.g. 
nginx) should handle this.", + }, + { status: 404 } + ); + } + + try { + const backendUrl = new URL(`${INTERNAL_URL}/${path.join("/")}`); + + // Get the URL parameters from the request + const urlParams = new URLSearchParams(request.url.split("?")[1]); + + // Append the URL parameters to the backend URL + urlParams.forEach((value, key) => { + backendUrl.searchParams.append(key, value); + }); + + const response = await fetch(backendUrl, { + method: request.method, + headers: request.headers, + body: request.body, + // @ts-ignore + duplex: "half", + }); + + // Check if the response is a stream + if ( + response.headers.get("Transfer-Encoding") === "chunked" || + response.headers.get("Content-Type")?.includes("stream") + ) { + // If it's a stream, create a TransformStream to pass the data through + const { readable, writable } = new TransformStream(); + response.body?.pipeTo(writable); + + return new NextResponse(readable, { + status: response.status, + headers: response.headers, + }); + } else { + return new NextResponse(response.body, { + status: response.status, + headers: response.headers, + }); + } + } catch (error: unknown) { + console.error("Proxy error:", error); + return NextResponse.json( + { + message: "Proxy error", + error: + error instanceof Error ? error.message : "An unknown error occurred", + }, + { status: 500 } + ); + } +} diff --git a/web/src/app/assistants/ToolsDisplay.tsx b/web/src/app/assistants/ToolsDisplay.tsx index 10c25b640c9..2be7670c0ee 100644 --- a/web/src/app/assistants/ToolsDisplay.tsx +++ b/web/src/app/assistants/ToolsDisplay.tsx @@ -71,7 +71,7 @@ export function AssistantTools({ w-fit flex items-center - ${hovered ? "bg-background-300" : list ? "bg-background-125" : "bg-background-100"}`} + ${list ? "bg-background-125" : "bg-background-100"}`} >
@@ -91,7 +91,7 @@ export function AssistantTools({ border-border w-fit flex - ${hovered ? "bg-background-300" : list ? "bg-background-125" : "bg-background-100"}`} + ${list ? "bg-background-125" : "bg-background-100"}`} >
( - ( - - ), - p: ({ node, ...props }) => ( -

- ), - }} - remarkPlugins={[remarkGfm]} - > - {settings.enterpriseSettings?.custom_header_content} - - ); return (

@@ -90,7 +66,7 @@ export function ChatBanner() { className="absolute top-0 left-0 invisible w-full" >
diff --git a/web/src/app/chat/ChatIntro.tsx b/web/src/app/chat/ChatIntro.tsx index 27353aa340f..3703655d7f9 100644 --- a/web/src/app/chat/ChatIntro.tsx +++ b/web/src/app/chat/ChatIntro.tsx @@ -1,30 +1,9 @@ -import { getSourceMetadataForSources, listSourceMetadata } from "@/lib/sources"; +import { getSourceMetadataForSources } from "@/lib/sources"; import { ValidSources } from "@/lib/types"; -import Image from "next/image"; import { Persona } from "../admin/assistants/interfaces"; import { Divider } from "@tremor/react"; -import { FiBookmark, FiCpu, FiInfo, FiX, FiZoomIn } from "react-icons/fi"; +import { FiBookmark, FiInfo } from "react-icons/fi"; import { HoverPopup } from "@/components/HoverPopup"; -import { Modal } from "@/components/Modal"; -import { useState } from "react"; -import { Logo } from "@/components/Logo"; - -const MAX_PERSONAS_TO_DISPLAY = 4; - -function HelperItemDisplay({ - title, - description, -}: { - title: string; - description: string; -}) { - return ( -
-
{title}
-
{description}
-
- ); -} export function ChatIntro({ availableSources, diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 3ee22d4d74f..a3239114ada 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -65,7 +65,12 @@ import { FiArrowDown } from "react-icons/fi"; import { ChatIntro } from "./ChatIntro"; import { AIMessage, HumanMessage } from "./message/Messages"; import { StarterMessage } from "./StarterMessage"; -import { AnswerPiecePacket, DanswerDocument } from "@/lib/search/interfaces"; +import { + AnswerPiecePacket, + DanswerDocument, + StreamStopInfo, + StreamStopReason, +} from "@/lib/search/interfaces"; import { buildFilters } from "@/lib/search/utils"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import Dropzone from "react-dropzone"; @@ -86,7 +91,6 @@ import FunctionalHeader from "@/components/chat_search/Header"; import { useSidebarVisibility } from "@/components/chat_search/hooks"; import { SIDEBAR_TOGGLED_COOKIE_NAME } from "@/components/resizable/constants"; import FixedLogo from "./shared_chat_search/FixedLogo"; -import { getSecondsUntilExpiration } from "@/lib/time"; import { SetDefaultModelModal } from "./modal/SetDefaultModelModal"; import { DeleteEntityModal } from "../../components/modals/DeleteEntityModal"; import { MinimalMarkdown } from "@/components/chat_search/MinimalMarkdown"; @@ -94,6 +98,7 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal"; import { SEARCH_TOOL_NAME } from "./tools/constants"; import { useUser } from "@/components/user/UserProvider"; +import { ApiKeyModal } from "@/components/llm/ApiKeyModal"; const TEMP_USER_MESSAGE_ID = -1; const TEMP_ASSISTANT_MESSAGE_ID = -2; @@ -102,12 +107,10 @@ const SYSTEM_MESSAGE_ID = -3; export function ChatPage({ toggle, documentSidebarInitialWidth, - defaultSelectedAssistantId, toggledSidebar, }: { toggle: (toggled?: boolean) => void; documentSidebarInitialWidth?: number; - defaultSelectedAssistantId?: number; toggledSidebar: boolean; }) { const router = useRouter(); @@ -122,9 +125,14 @@ export function ChatPage({ folders, openedFolders, userInputPrompts, + defaultAssistantId, + shouldShowWelcomeModal, + refreshChatSessions, } = useChatContext(); - const { user, refreshUser } = useUser(); + const [showApiKeyModal, setShowApiKeyModal] = useState(true); + + const { user, refreshUser, isLoadingUser } = useUser(); // chat session const existingChatIdRaw = searchParams.get("chatId"); @@ -133,6 +141,7 @@ export function ChatPage({ const existingChatSessionId = existingChatIdRaw ? parseInt(existingChatIdRaw) : null; + const selectedChatSession = chatSessions.find( (chatSession) => chatSession.id === existingChatSessionId ); @@ -157,9 +166,9 @@ export function ChatPage({ ? availableAssistants.find( (assistant) => assistant.id === existingChatSessionAssistantId ) - : defaultSelectedAssistantId !== undefined + : defaultAssistantId !== undefined ? 
availableAssistants.find( - (assistant) => assistant.id === defaultSelectedAssistantId + (assistant) => assistant.id === defaultAssistantId ) : undefined ); @@ -201,6 +210,7 @@ export function ChatPage({ selectedAssistant || filteredAssistants[0] || availableAssistants[0]; + useEffect(() => { if (!loadedIdSessionRef.current && !currentPersonaId) { return; @@ -249,6 +259,7 @@ export function ChatPage({ updateChatState("input", currentSession); }; + // this is for "@"ing assistants // this is used to track which assistant is being used to generate the current message @@ -273,6 +284,7 @@ export function ChatPage({ ); const [isReady, setIsReady] = useState(false); + useEffect(() => { Prism.highlightAll(); setIsReady(true); @@ -319,8 +331,8 @@ export function ChatPage({ async function initialSessionFetch() { if (existingChatSessionId === null) { setIsFetchingChatMessages(false); - if (defaultSelectedAssistantId !== undefined) { - setSelectedAssistantFromId(defaultSelectedAssistantId); + if (defaultAssistantId !== undefined) { + setSelectedAssistantFromId(defaultAssistantId); } else { setSelectedAssistant(undefined); } @@ -337,6 +349,7 @@ export function ChatPage({ } return; } + clearSelectedDocuments(); setIsFetchingChatMessages(true); const response = await fetch( @@ -393,7 +406,7 @@ export function ChatPage({ // force re-name if the chat session doesn't have one if (!chatSession.description) { await nameChatSession(existingChatSessionId, seededMessage); - router.refresh(); // need to refresh to update name on sidebar + refreshChatSessions(); } } } @@ -623,6 +636,24 @@ export function ChatPage({ const currentRegenerationState = (): RegenerationState | null => { return regenerationState.get(currentSessionId()) || null; }; + const [canContinue, setCanContinue] = useState>( + new Map([[null, false]]) + ); + + const updateCanContinue = (newState: boolean, sessionId?: number | null) => { + setCanContinue((prevState) => { + const newCanContinueState = new Map(prevState); + newCanContinueState.set( + sessionId !== undefined ? sessionId : currentSessionId(), + newState + ); + return newCanContinueState; + }); + }; + + const currentCanContinue = (): boolean => { + return canContinue.get(currentSessionId()) || false; + }; const currentSessionChatState = currentChatState(); const currentSessionRegenerationState = currentRegenerationState(); @@ -649,12 +680,10 @@ export function ChatPage({ useEffect(() => { if (messageHistory.length === 0 && chatSessionIdRef.current === null) { setSelectedAssistant( - filteredAssistants.find( - (persona) => persona.id === defaultSelectedAssistantId - ) + filteredAssistants.find((persona) => persona.id === defaultAssistantId) ); } - }, [defaultSelectedAssistantId]); + }, [defaultAssistantId]); const [ selectedDocuments, @@ -751,11 +780,16 @@ export function ChatPage({ const clientScrollToBottom = (fast?: boolean) => { setTimeout(() => { - if (fast) { - endDivRef.current?.scrollIntoView(); - } else { - endDivRef.current?.scrollIntoView({ behavior: "smooth" }); + if (!endDivRef.current) { + return; } + + const rect = endDivRef.current.getBoundingClientRect(); + const isVisible = rect.top >= 0 && rect.bottom <= window.innerHeight; + + if (isVisible) return; + + endDivRef.current.scrollIntoView({ behavior: fast ? 
"auto" : "smooth" }); setHasPerformedInitialScroll(true); }, 50); }; @@ -863,6 +897,13 @@ export function ChatPage({ } }; + const continueGenerating = () => { + onSubmit({ + messageOverride: + "Continue Generating (pick up exactly where you left off)", + }); + }; + const onSubmit = async ({ messageIdToResend, messageOverride, @@ -883,6 +924,7 @@ export function ChatPage({ regenerationRequest?: RegenerationRequest | null; } = {}) => { let frozenSessionId = currentSessionId(); + updateCanContinue(false, frozenSessionId); if (currentChatState() != "input") { setPopup({ @@ -892,13 +934,6 @@ export function ChatPage({ return; } - updateRegenerationState( - regenerationRequest - ? { regenerating: true, finalMessageIndex: messageIdToResend || 0 } - : null - ); - - updateChatState("loading"); setAlternativeGeneratingAssistant(alternativeAssistantOverride); clientScrollToBottom(); @@ -929,6 +964,11 @@ export function ChatPage({ (message) => message.messageId === messageIdToResend ); + updateRegenerationState( + regenerationRequest + ? { regenerating: true, finalMessageIndex: messageIdToResend || 0 } + : null + ); const messageMap = currentMessageMap(completeMessageDetail); const messageToResendParent = messageToResend?.parentMessageId !== null && @@ -955,6 +995,9 @@ export function ChatPage({ } setSubmittedMessage(currMessage); + + updateChatState("loading"); + const currMessageHistory = messageToResendIndex !== null ? messageHistory.slice(0, messageToResendIndex) @@ -977,6 +1020,8 @@ export function ChatPage({ let messageUpdates: Message[] | null = null; let answer = ""; + + let stopReason: StreamStopReason | null = null; let query: string | null = null; let retrievalType: RetrievalType = selectedDocuments.length > 0 @@ -1067,6 +1112,12 @@ export function ChatPage({ console.error( "First packet should contain message response info " ); + if (Object.hasOwn(packet, "error")) { + const error = (packet as StreamingError).error; + setLoadingError(error); + updateChatState("input"); + return; + } continue; } @@ -1173,6 +1224,11 @@ export function ChatPage({ stackTrace = (packet as StreamingError).stack_trace; } else if (Object.hasOwn(packet, "message_id")) { finalMessage = packet as BackendMessage; + } else if (Object.hasOwn(packet, "stop_reason")) { + const stop_reason = (packet as StreamStopInfo).stop_reason; + if (stop_reason === StreamStopReason.CONTEXT_LENGTH) { + updateCanContinue(true, frozenSessionId); + } } // on initial message send, we insert a dummy system message @@ -1236,6 +1292,7 @@ export function ChatPage({ alternateAssistantID: alternativeAssistant?.id, stackTrace: stackTrace, overridden_model: finalMessage?.overridden_model, + stopReason: stopReason, }, ]); } @@ -1280,6 +1337,7 @@ export function ChatPage({ if (!searchParamBasedChatSessionName) { await new Promise((resolve) => setTimeout(resolve, 200)); await nameChatSession(currChatSessionId, currMessage); + refreshChatSessions(); } // NOTE: don't switch pages if the user has navigated away from the chat @@ -1415,6 +1473,7 @@ export function ChatPage({ // Used to maintain a "time out" for history sidebar so our existing refs can have time to process change const [untoggled, setUntoggled] = useState(false); + const [loadingError, setLoadingError] = useState(null); const explicitlyUntoggle = () => { setShowDocSidebar(false); @@ -1518,7 +1577,6 @@ export function ChatPage({ setDocumentSelection((documentSelection) => !documentSelection); setShowDocSidebar(false); }; - const secondsUntilExpiration = getSecondsUntilExpiration(user); 
interface RegenerationRequest { messageId: number; @@ -1538,7 +1596,12 @@ export function ChatPage({ return ( <> - + + + {showApiKeyModal && !shouldShowWelcomeModal && ( + setShowApiKeyModal(false)} /> + )} + {/* ChatPopup is a custom popup that displays a admin-specified message on initial user visit. Only used in the EE version of the app. */} {popup} @@ -1580,10 +1643,14 @@ export function ChatPage({ if (response.ok) { setDeletingChatSession(null); // go back to the main page - router.push("/chat"); + if (deletingChatSession.id === chatSessionIdRef.current) { + router.push("/chat"); + } } else { - alert("Failed to delete chat session"); + const responseJson = await response.json(); + setPopup({ message: responseJson.detail, type: "error" }); } + router.refresh(); }} /> )} @@ -1680,7 +1747,9 @@ export function ChatPage({ /> )} - {documentSidebarInitialWidth !== undefined && isReady ? ( + {documentSidebarInitialWidth !== undefined && + isReady && + !isLoadingUser ? ( {({ getRootProps }) => (
@@ -1705,7 +1774,6 @@ export function ChatPage({ className={`h-full w-full relative flex-auto transition-margin duration-300 overflow-x-auto mobile:pb-12 desktop:pb-[100px]`} {...getRootProps()} > - {/* */}
@@ -1737,9 +1810,14 @@ export function ChatPage({ ? messageMap.get(message.parentMessageId) : null; if ( - currentSessionRegenerationState?.regenerating && - message.messageId >= - currentSessionRegenerationState?.finalMessageIndex! + (currentSessionRegenerationState?.regenerating && + message.messageId > + currentSessionRegenerationState?.finalMessageIndex!) || + (currentSessionChatState == "loading" && + ((i >= messageHistory.length - 2 && + message.type == "user") || + (i >= messageHistory.length - 1 && + !currentSessionRegenerationState?.regenerating))) ) { return <>; } @@ -1816,7 +1894,8 @@ export function ChatPage({ if ( currentSessionRegenerationState?.regenerating && currentSessionChatState == "loading" && - message.messageId == messageHistory.length - 1 + (i == messageHistory.length - 1 || + currentSessionRegenerationState?.regenerating) ) { return <>; } @@ -1830,6 +1909,12 @@ export function ChatPage({ } > - )} + {(currentSessionChatState == "loading" || + (loadingError && + !currentSessionRegenerationState?.regenerating && + messageHistory[messageHistory.length - 1] + ?.type != "user")) && ( + + )} {currentSessionChatState == "loading" && (
)} + {loadingError && ( +
+ + {loadingError} +

+ } + /> +
+ )} {currentPersona && currentPersona.starter_messages && currentPersona.starter_messages.length > 0 && @@ -2090,6 +2190,7 @@ export function ChatPage({ )}
)} + {/* Some padding at the bottom so the search bar has space at the bottom to not cover the last message*/}
@@ -2111,6 +2212,9 @@ export function ChatPage({
)} + setShowApiKeyModal(true) + } chatState={currentSessionChatState} stopGenerating={stopGenerating} openModelSettings={() => setSettingsToggled(true)} @@ -2141,7 +2245,6 @@ export function ChatPage({
{ setCompletedFlow( @@ -20,16 +20,26 @@ export function ChatPopup() { }); const settings = useContext(SettingsContext); - if (!settings?.enterpriseSettings?.custom_popup_content || completedFlow) { + const enterpriseSettings = settings?.enterpriseSettings; + const isConsentScreen = enterpriseSettings?.enable_consent_screen; + if ( + (!enterpriseSettings?.custom_popup_content && !isConsentScreen) || + completedFlow + ) { return null; } - let popupTitle = settings.enterpriseSettings.custom_popup_header; - if (!popupTitle) { - popupTitle = `Welcome to ${ - settings.enterpriseSettings.application_name || "Danswer" - }!`; - } + const popupTitle = + enterpriseSettings?.custom_popup_header || + (isConsentScreen + ? "Terms of Use" + : `Welcome to ${enterpriseSettings?.application_name || "Danswer"}!`); + + const popupContent = + enterpriseSettings?.custom_popup_content || + (isConsentScreen + ? "By clicking 'I Agree', you acknowledge that you agree to the terms of use of this application and consent to proceed." + : ""); return ( @@ -49,12 +59,26 @@ export function ChatPopup() { }} remarkPlugins={[remarkGfm]} > - {settings.enterpriseSettings.custom_popup_content} + {popupContent} -
+ {showConsentError && ( +

+ You need to agree to the terms to access the application. +

+ )} + +
+ {isConsentScreen && ( + + )}
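For reference, the ChatPopup change above now shows the popup whenever either custom popup content or the consent screen is enabled, with consent-specific fallbacks for the title and body. The standalone sketch below restates that resolution logic only for illustration; the simplified settings shape is an assumption.

// Sketch of how the popup title/content resolve from enterprise settings.
interface PopupSettings {
  application_name: string | null;
  custom_popup_header: string | null;
  custom_popup_content: string | null;
  enable_consent_screen: boolean | null;
}

function resolveChatPopup(settings: PopupSettings | undefined) {
  const isConsentScreen = settings?.enable_consent_screen ?? false;

  // Nothing to show: no custom content configured and consent screen disabled.
  if (!settings?.custom_popup_content && !isConsentScreen) {
    return null;
  }

  const title =
    settings?.custom_popup_header ||
    (isConsentScreen
      ? "Terms of Use"
      : `Welcome to ${settings?.application_name || "Danswer"}!`);

  const content =
    settings?.custom_popup_content ||
    (isConsentScreen
      ? "By clicking 'I Agree', you acknowledge that you agree to the terms of use of this application and consent to proceed."
      : "");

  return { title, content, isConsentScreen };
}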
diff --git a/web/src/app/chat/WrappedChat.tsx b/web/src/app/chat/WrappedChat.tsx index cdb8508dfb0..6b48e442175 100644 --- a/web/src/app/chat/WrappedChat.tsx +++ b/web/src/app/chat/WrappedChat.tsx @@ -3,21 +3,15 @@ import { ChatPage } from "./ChatPage"; import FunctionalWrapper from "./shared_chat_search/FunctionalWrapper"; export default function WrappedChat({ - defaultAssistantId, initiallyToggled, }: { - defaultAssistantId?: number; initiallyToggled: boolean; }) { return ( ( - + )} /> ); diff --git a/web/src/app/chat/input/ChatInputBar.tsx b/web/src/app/chat/input/ChatInputBar.tsx index b579abefeed..6ea2ce868a5 100644 --- a/web/src/app/chat/input/ChatInputBar.tsx +++ b/web/src/app/chat/input/ChatInputBar.tsx @@ -33,12 +33,14 @@ import { Tooltip } from "@/components/tooltip/Tooltip"; import { Hoverable } from "@/components/Hoverable"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import { ChatState } from "../types"; +import UnconfiguredProviderText from "@/components/chat_search/UnconfiguredProviderText"; const MAX_INPUT_HEIGHT = 200; export function ChatInputBar({ openModelSettings, showDocs, + showConfigureAPIKey, selectedDocuments, message, setMessage, @@ -62,6 +64,7 @@ export function ChatInputBar({ chatSessionId, inputPrompts, }: { + showConfigureAPIKey: () => void; openModelSettings: () => void; chatState: ChatState; stopGenerating: () => void; @@ -111,6 +114,7 @@ export function ChatInputBar({ } } }; + const settings = useContext(SettingsContext); const { llmProviders } = useChatContext(); @@ -338,10 +342,10 @@ export function ChatInputBar({ updateInputPrompt(currentPrompt); }} > -

{currentPrompt.prompt}

-

+

{currentPrompt.prompt}:

+

{currentPrompt.id == selectedAssistant.id && "(default) "} - {currentPrompt.content} + {currentPrompt.content?.trim()}

))} @@ -364,6 +368,9 @@ export function ChatInputBar({
+ + +
void; +}) { + const [showExplanation, setShowExplanation] = useState(false); + + useEffect(() => { + const timer = setTimeout(() => { + setShowExplanation(true); + }, 1000); + + return () => clearTimeout(timer); + }, []); + + return ( +
+
+ + <> + + Continue Generation + + + {showExplanation && ( +
+ LLM reached its token limit. Click to continue. +
+ )} +
+
+ ); +} diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index 09cacd1b9f1..bbc4ae42601 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -64,7 +64,7 @@ import { SettingsContext } from "@/components/settings/SettingsProvider"; import GeneratingImageDisplay from "../tools/GeneratingImageDisplay"; import RegenerateOption from "../RegenerateOption"; import { LlmOverride } from "@/lib/hooks"; -import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal"; +import { ContinueGenerating } from "./ContinueMessage"; const TOOLS_WITH_CUSTOM_HANDLING = [ SEARCH_TOOL_NAME, @@ -123,6 +123,7 @@ function FileDisplay({ export const AIMessage = ({ regenerate, overriddenModel, + continueGenerating, shared, isActive, toggleDocumentSelection, @@ -150,6 +151,7 @@ export const AIMessage = ({ }: { shared?: boolean; isActive?: boolean; + continueGenerating?: () => void; otherMessagesCanSwitchTo?: number[]; onMessageSelection?: (messageId: number) => void; selectedDocuments?: DanswerDocument[] | null; @@ -283,11 +285,12 @@ export const AIMessage = ({ size="small" assistant={alternativeAssistant || currentPersona} /> +
- {(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && ( + {!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME ? ( <> {query !== undefined && handleShowRetrieved !== undefined && @@ -315,7 +318,8 @@ export const AIMessage = ({
)} - )} + ) : null} + {toolCall && !TOOLS_WITH_CUSTOM_HANDLING.includes( toolCall.tool_name @@ -358,7 +362,7 @@ export const AIMessage = ({ {typeof content === "string" ? ( -
+
+ {(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && + !query && + continueGenerating && ( + + )}
); @@ -706,6 +715,7 @@ export const HumanMessage = ({ // Move the cursor to the end of the text textareaRef.current.selectionStart = textareaRef.current.value.length; textareaRef.current.selectionEnd = textareaRef.current.value.length; + textareaRef.current.style.height = `${textareaRef.current.scrollHeight}px`; } }, [isEditing]); @@ -731,6 +741,7 @@ export const HumanMessage = ({
+
{isEditing ? ( @@ -777,6 +788,7 @@ export const HumanMessage = ({ style={{ scrollbarWidth: "thin" }} onChange={(e) => { setEditedContent(e.target.value); + textareaRef.current!.style.height = "auto"; e.target.style.height = `${e.target.scrollHeight}px`; }} onKeyDown={(e) => { diff --git a/web/src/app/chat/message/SkippedSearch.tsx b/web/src/app/chat/message/SkippedSearch.tsx index 62c47b7d96f..b339ac784ab 100644 --- a/web/src/app/chat/message/SkippedSearch.tsx +++ b/web/src/app/chat/message/SkippedSearch.tsx @@ -27,7 +27,7 @@ export function SkippedSearch({ handleForceSearch: () => void; }) { return ( -
+
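The ChatPage and Messages changes above wire up the "continue generating" flow: a stop-reason packet from the stream marks the current session as continuable, and the last AI message then renders a ContinueGenerating button that resubmits with a fixed instruction. The sketch below compresses that flow into plain TypeScript for illustration; the packet shape and the CONTEXT_LENGTH wire value are assumptions, not the backend's actual definitions.

// Per-session flag: ChatPage keeps a Map of chat session id -> "can continue".
const canContinue = new Map<number | null, boolean>();

// NOTE: the exact enum wire value is an assumption for this sketch.
enum StreamStopReason {
  CONTEXT_LENGTH = "context_length",
}

function handleStreamPacket(
  packet: { stop_reason?: StreamStopReason },
  sessionId: number | null
) {
  // When the backend reports it stopped because the context window was full,
  // mark the session so the last AI message can offer to continue.
  if (packet.stop_reason === StreamStopReason.CONTEXT_LENGTH) {
    canContinue.set(sessionId, true);
  }
}

// The ContinueGenerating button simply resubmits with a fixed instruction to
// pick up where the model left off, reusing the normal onSubmit path.
function continueGenerating(
  onSubmit: (args: { messageOverride: string }) => void
) {
  onSubmit({
    messageOverride:
      "Continue Generating (pick up exactly where you left off)",
  });
}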
diff --git a/web/src/app/chat/page.tsx b/web/src/app/chat/page.tsx index e391b79dae7..870ad963fca 100644 --- a/web/src/app/chat/page.tsx +++ b/web/src/app/chat/page.tsx @@ -2,10 +2,10 @@ import { redirect } from "next/navigation"; import { unstable_noStore as noStore } from "next/cache"; import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh"; import { WelcomeModal } from "@/components/initialSetup/welcome/WelcomeModalWrapper"; -import { ApiKeyModal } from "@/components/llm/ApiKeyModal"; import { ChatProvider } from "@/components/context/ChatContext"; import { fetchChatData } from "@/lib/chat/fetchChatData"; import WrappedChat from "./WrappedChat"; +import { ProviderContextProvider } from "@/components/chat_search/ProviderContext"; export default async function Page({ searchParams, @@ -23,7 +23,6 @@ export default async function Page({ const { user, chatSessions, - ccPairs, availableSources, documentSets, assistants, @@ -33,9 +32,7 @@ export default async function Page({ toggleSidebar, openedFolders, defaultAssistantId, - finalDocumentSidebarInitialWidth, shouldShowWelcomeModal, - shouldDisplaySourcesIncompleteModal, userInputPrompts, } = data; @@ -43,9 +40,7 @@ export default async function Page({ <> {shouldShowWelcomeModal && } - {!shouldShowWelcomeModal && !shouldDisplaySourcesIncompleteModal && ( - - )} + - + + + ); diff --git a/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx b/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx index df7ddee957f..35256ada98a 100644 --- a/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx +++ b/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx @@ -46,6 +46,7 @@ export function ChatSessionDisplay({ showDeleteModal?: (chatSession: ChatSession) => void; }) { const router = useRouter(); + const [isHovering, setIsHovering] = useState(false); const [isRenamingChat, setIsRenamingChat] = useState(false); const [isMoreOptionsDropdownOpen, setIsMoreOptionsDropdownOpen] = useState(false); @@ -97,6 +98,11 @@ export function ChatSessionDisplay({ setIsHovering(true)} + onMouseLeave={() => { + setIsMoreOptionsDropdownOpen(false); + setIsHovering(false); + }} onClick={() => { if (settings?.isMobile && closeSidebar) { closeSidebar(); @@ -145,7 +151,7 @@ export function ChatSessionDisplay({

)} - {isSelected && + {isHovering && (isRenamingChat ? (
-
{ - setIsMoreOptionsDropdownOpen( - !isMoreOptionsDropdownOpen - ); - }} - className={"-my-1"} - > - - setIsMoreOptionsDropdownOpen(open) - } - content={ -
- -
- } - popover={ -
- {showShareModal && ( - showShareModal(chatSession)} - /> - )} - setIsRenamingChat(true)} - /> -
- } - requiresContentPadding - sideOffset={6} - triggerMaxWidth - /> -
+ {search ? ( + showDeleteModal && ( +
{ + e.preventDefault(); + showDeleteModal(chatSession); + }} + className={`p-1 -m-1 rounded ml-1`} + > + +
+ ) + ) : ( +
{ + e.preventDefault(); + // e.stopPropagation(); + setIsMoreOptionsDropdownOpen( + !isMoreOptionsDropdownOpen + ); + }} + className="-my-1" + > + + setIsMoreOptionsDropdownOpen(open) + } + content={ +
+ +
+ } + popover={ +
+ {showShareModal && ( + showShareModal(chatSession)} + /> + )} + {!search && ( + setIsRenamingChat(true)} + /> + )} + {showDeleteModal && ( + + showDeleteModal(chatSession) + } + /> + )} +
+ } + requiresContentPadding + sideOffset={6} + triggerMaxWidth + /> +
+ )}
- {showDeleteModal && ( -
showDeleteModal(chatSession)} - className={`hover:bg-black/10 p-1 -m-1 rounded ml-1`} - > - -
- )}
))}
diff --git a/web/src/app/chat/shared_chat_search/FixedLogo.tsx b/web/src/app/chat/shared_chat_search/FixedLogo.tsx index 0b3e115c3c8..c5114ca3e71 100644 --- a/web/src/app/chat/shared_chat_search/FixedLogo.tsx +++ b/web/src/app/chat/shared_chat_search/FixedLogo.tsx @@ -21,11 +21,11 @@ export default function FixedLogo() { } className="fixed cursor-pointer flex z-40 left-2.5 top-2" > -
+
-
+
{enterpriseSettings && enterpriseSettings.application_name ? (
{enterpriseSettings.application_name} diff --git a/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx b/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx index 4ef22ef4e60..4f8d31d39ee 100644 --- a/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx +++ b/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx @@ -53,10 +53,15 @@ const ToggleSwitch = () => { onClick={() => handleTabChange("search")} > -

+

Search - {commandSymbol}S -

+
+ + {commandSymbol} + + S +
+
); @@ -122,6 +132,8 @@ export default function FunctionalWrapper({ const settings = combinedSettings?.settings; const chatBannerPresent = combinedSettings?.enterpriseSettings?.custom_header_content; + const twoLines = + combinedSettings?.enterpriseSettings?.two_lines_for_chat_header; const [toggledSidebar, setToggledSidebar] = useState(initiallyToggled); @@ -136,7 +148,7 @@ export default function FunctionalWrapper({ {(!settings || (settings.search_page_enabled && settings.chat_page_enabled)) && (
{ formikHelpers.setSubmitting(true); @@ -204,28 +211,62 @@ export function WhitelabelingForm() { disabled={isSubmitting} /> + + + + li > p, +ul > li > p { + margin-top: 0; + margin-bottom: 0; + display: inline; /* Make paragraphs inline to reduce vertical space */ +} diff --git a/web/src/app/layout.tsx b/web/src/app/layout.tsx index 8219dbb9a61..5b9435cbc85 100644 --- a/web/src/app/layout.tsx +++ b/web/src/app/layout.tsx @@ -3,21 +3,19 @@ import "./globals.css"; import { fetchEnterpriseSettingsSS, fetchSettingsSS, - SettingsError, } from "@/components/settings/lib"; import { CUSTOM_ANALYTICS_ENABLED, + EE_ENABLED, SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED, } from "@/lib/constants"; import { SettingsProvider } from "@/components/settings/SettingsProvider"; import { Metadata } from "next"; -import { buildClientUrl } from "@/lib/utilsSS"; +import { buildClientUrl, fetchSS } from "@/lib/utilsSS"; import { Inter } from "next/font/google"; import Head from "next/head"; import { EnterpriseSettings } from "./admin/settings/interfaces"; -import { redirect } from "next/navigation"; -import { Button, Card } from "@tremor/react"; -import LogoType from "@/components/header/LogoType"; +import { Card } from "@tremor/react"; import { HeaderTitle } from "@/components/header/HeaderTitle"; import { Logo } from "@/components/Logo"; import { UserProvider } from "@/components/user/UserProvider"; @@ -56,6 +54,7 @@ export default async function RootLayout({ children: React.ReactNode; }) { const combinedSettings = await fetchSettingsSS(); + if (!combinedSettings) { // Just display a simple full page error if fetching fails. @@ -75,8 +74,34 @@ export default async function RootLayout({

Error

Your Danswer instance was not configured properly and your - settings could not be loaded. Please contact your admin to fix - this error. + settings could not be loaded. This could be due to an admin + configuration issue or an incomplete setup. +

+

+ If you're an admin, please check{" "} + + our docs + {" "} + to see how to configure Danswer properly. If you're a user, + please contact your admin to fix this error. +

+

+ For additional support and guidance, you can reach out to our + community on{" "} + + Slack + + .

@@ -107,7 +132,7 @@ export default async function RootLayout({
( )} diff --git a/web/src/app/search/page.tsx b/web/src/app/search/page.tsx index 6f6cef8c4f0..e317d271b0d 100644 --- a/web/src/app/search/page.tsx +++ b/web/src/app/search/page.tsx @@ -3,10 +3,8 @@ import { getAuthTypeMetadataSS, getCurrentUserSS, } from "@/lib/userSS"; -import { getSecondsUntilExpiration } from "@/lib/time"; import { redirect } from "next/navigation"; import { HealthCheckBanner } from "@/components/health/healthcheck"; -import { ApiKeyModal } from "@/components/llm/ApiKeyModal"; import { fetchSS } from "@/lib/utilsSS"; import { CCPairBasicInfo, DocumentSet, Tag, User } from "@/lib/types"; import { cookies } from "next/headers"; @@ -35,6 +33,8 @@ import { DISABLE_LLM_DOC_RELEVANCE, } from "@/lib/constants"; import WrappedSearch from "./WrappedSearch"; +import { SearchProvider } from "@/components/context/SearchContext"; +import { ProviderContextProvider } from "@/components/chat_search/ProviderContext"; export default async function Home() { // Disable caching so we always get the up to date connector / document set / persona info @@ -179,18 +179,13 @@ export default async function Home() { const agenticSearchEnabled = agenticSearchToggle ? agenticSearchToggle.value.toLocaleLowerCase() == "true" || false : false; - const secondsUntilExpiration = getSecondsUntilExpiration(user); return ( <> - + {shouldShowWelcomeModal && } - {!shouldShowWelcomeModal && - !shouldDisplayNoSourcesModal && - !shouldDisplaySourcesIncompleteModal && } - {shouldDisplayNoSourcesModal && } {shouldDisplaySourcesIncompleteModal && ( @@ -201,18 +196,27 @@ export default async function Home() { Only used in the EE version of the app. */} - + + + + + ); } diff --git a/web/src/components/IsPublicGroupSelector.tsx b/web/src/components/IsPublicGroupSelector.tsx index 63d47e506a6..c478d29b177 100644 --- a/web/src/components/IsPublicGroupSelector.tsx +++ b/web/src/components/IsPublicGroupSelector.tsx @@ -15,14 +15,18 @@ export type IsPublicGroupSelectorFormType = { export const IsPublicGroupSelector = ({ formikProps, objectName, + publicToWhom = "Users", + removeIndent = false, enforceGroupSelection = true, }: { formikProps: FormikProps; objectName: string; + publicToWhom?: string; + removeIndent?: boolean; enforceGroupSelection?: boolean; }) => { const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups(); - const { isAdmin, user, isLoadingUser } = useUser(); + const { isAdmin, user, isLoadingUser, isCurator } = useUser(); const [shouldHideContent, setShouldHideContent] = useState(false); useEffect(() => { @@ -72,57 +76,67 @@ export const IsPublicGroupSelector = ({ <> - If set, then this {objectName} will be visible to{" "} - all users. If turned off, then only users who explicitly - have been given access to this {objectName} (e.g. through a User - Group) will have access. + If set, then this {objectName} will be usable by{" "} + All {publicToWhom}. Otherwise, only Admins and{" "} + {publicToWhom} who have explicitly been given access to + this {objectName} (e.g. via a User Group) will have access. } /> )} - {(!formikProps.values.is_public || - !isAdmin || - formikProps.values.groups.length > 0) && ( - <> -
-
- Assign group access for this {objectName} + {(!formikProps.values.is_public || isCurator) && + formikProps.values.groups.length > 0 && ( + <> +
+
+ Assign group access for this {objectName} +
-
- - {isAdmin || !enforceGroupSelection ? ( - <> - This {objectName} will be visible/accessible by the groups - selected below - + {userGroupsIsLoading ? ( +
) : ( - <> - Curators must select one or more groups to give access to this{" "} - {objectName} - - )} -
- ( -
- {userGroupsIsLoading ? ( -
+ + {isAdmin || !enforceGroupSelection ? ( + <> + This {objectName} will be visible/accessible by the groups + selected below + ) : ( - userGroups && - userGroups.map((userGroup: UserGroup) => { - const ind = formikProps.values.groups.indexOf(userGroup.id); - let isSelected = ind !== -1; - return ( -
+ Curators must select one or more groups to give access to + this {objectName} + + )} + + )} + ( +
+ {userGroupsIsLoading ? ( +
+ ) : ( + userGroups && + userGroups.map((userGroup: UserGroup) => { + const ind = formikProps.values.groups.indexOf( + userGroup.id + ); + let isSelected = ind !== -1; + return ( +
({ cursor-pointer ${isSelected ? "bg-background-strong" : "hover:bg-hover"} `} - onClick={() => { - if (isSelected) { - arrayHelpers.remove(ind); - } else { - arrayHelpers.push(userGroup.id); - } - }} - > -
- {userGroup.name} + onClick={() => { + if (isSelected) { + arrayHelpers.remove(ind); + } else { + arrayHelpers.push(userGroup.id); + } + }} + > +
+ {" "} + {userGroup.name} +
-
- ); - }) - )} -
- )} - /> - - - )} + ); + }) + )} +
+ )} + /> + + + )}
); }; diff --git a/web/src/components/UserDropdown.tsx b/web/src/components/UserDropdown.tsx index 2ca71c577c2..3bc18b5241b 100644 --- a/web/src/components/UserDropdown.tsx +++ b/web/src/components/UserDropdown.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useRef, useContext } from "react"; +import { useState, useRef, useContext, useEffect } from "react"; import { FiLogOut } from "react-icons/fi"; import Link from "next/link"; import { useRouter } from "next/navigation"; @@ -15,6 +15,35 @@ import { UsersIcon, } from "./icons/icons"; import { pageType } from "@/app/chat/sessionSidebar/types"; +import { NavigationItem } from "@/app/admin/settings/interfaces"; +import DynamicFaIcon, { preloadIcons } from "./icons/DynamicFaIcon"; + +interface DropdownOptionProps { + href?: string; + onClick?: () => void; + icon: React.ReactNode; + label: string; +} + +const DropdownOption: React.FC = ({ + href, + onClick, + icon, + label, +}) => { + const content = ( +
+ {icon} + {label} +
+ ); + + return href ? ( + {content} + ) : ( +
{content}
+ ); +}; export function UserDropdown({ user, @@ -28,10 +57,17 @@ export function UserDropdown({ const router = useRouter(); const combinedSettings = useContext(SettingsContext); + const customNavItems: NavigationItem[] = + combinedSettings?.enterpriseSettings?.custom_nav_items || []; + + useEffect(() => { + const iconNames = customNavItems.map((item) => item.icon); + preloadIcons(iconNames); + }, [customNavItems]); + if (!combinedSettings) { return null; } - const settings = combinedSettings.settings; const handleLogout = () => { logout().then((isSuccess) => { @@ -100,44 +136,49 @@ export function UserDropdown({ overscroll-contain `} > - {showAdminPanel && ( - <> - - - Admin Panel - - - )} - {showCuratorPanel && ( - <> - ( + + } + label={item.title} + /> + ))} + + {showAdminPanel ? ( + } + label="Admin Panel" + /> + ) : ( + showCuratorPanel && ( + - - Curator Panel - - + icon={} + label="Curator Panel" + /> + ) )} + {showLogout && + (showCuratorPanel || + showAdminPanel || + customNavItems.length > 0) && ( +
+ )} + {showLogout && ( - <> - {(!(page == "search" || page == "chat") || showAdminPanel) && ( -
- )} -
- - Log out -
- + } + label="Log out" + /> )}
} diff --git a/web/src/components/admin/connectors/AdminSidebar.tsx b/web/src/components/admin/connectors/AdminSidebar.tsx index 60bbd3c7e93..5b1a8adc831 100644 --- a/web/src/components/admin/connectors/AdminSidebar.tsx +++ b/web/src/components/admin/connectors/AdminSidebar.tsx @@ -48,13 +48,13 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) { : "/search" } > -
+
-
+
{enterpriseSettings && enterpriseSettings.application_name ? ( -
+
{enterpriseSettings.application_name} @@ -74,9 +74,9 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
- ); }; diff --git a/web/src/components/admin/users/SignedUpUserTable.tsx b/web/src/components/admin/users/SignedUpUserTable.tsx index 324c5266117..b747229ea39 100644 --- a/web/src/components/admin/users/SignedUpUserTable.tsx +++ b/web/src/components/admin/users/SignedUpUserTable.tsx @@ -19,6 +19,7 @@ import { import { GenericConfirmModal } from "@/components/modals/GenericConfirmModal"; import { useState } from "react"; import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; +import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal"; const USER_ROLE_LABELS: Record = { [UserRole.BASIC]: "Basic", @@ -157,6 +158,59 @@ const DeactivaterButton = ({ ); }; +const DeleteUserButton = ({ + user, + setPopup, + mutate, +}: { + user: User; + setPopup: (spec: PopupSpec) => void; + mutate: () => void; +}) => { + const { trigger, isMutating } = useSWRMutation( + "/api/manage/admin/delete-user", + userMutationFetcher, + { + onSuccess: () => { + mutate(); + setPopup({ + message: "User deleted successfully!", + type: "success", + }); + }, + onError: (errorMsg) => + setPopup({ + message: `Unable to delete user - ${errorMsg}`, + type: "error", + }), + } + ); + + const [showDeleteModal, setShowDeleteModal] = useState(false); + return ( + <> + {showDeleteModal && ( + setShowDeleteModal(false)} + onSubmit={() => trigger({ user_email: user.email, method: "DELETE" })} + /> + )} + + + + ); +}; + const SignedUpUserTable = ({ users, setPopup, @@ -215,13 +269,20 @@ const SignedUpUserTable = ({ {user.status === "live" ? "Active" : "Inactive"} -
+
+ {user.status == UserStatus.deactivated && ( + + )}
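The new DeleteUserButton above reuses userMutationFetcher, which (as updated later in this diff) pulls an optional method out of the mutation argument, defaults it to PATCH, and serializes the remaining fields as the JSON body. A rough usage sketch, assuming the fetcher is the module's default export; the hook name useDeleteUser is hypothetical.

import useSWRMutation from "swr/mutation";
import userMutationFetcher from "@/lib/admin/users/userMutationFetcher";

// Triggering with method: "DELETE" issues
//   DELETE /api/manage/admin/delete-user  with body {"user_email": "..."}
// because the updated fetcher strips `method` out of the argument before
// JSON-encoding what is left.
function useDeleteUser(onDone: () => void) {
  const { trigger, isMutating } = useSWRMutation(
    "/api/manage/admin/delete-user",
    userMutationFetcher,
    { onSuccess: onDone }
  );
  return {
    deleteUser: (email: string) =>
      trigger({ user_email: email, method: "DELETE" }),
    isMutating,
  };
}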
diff --git a/web/src/components/chat_search/Header.tsx b/web/src/components/chat_search/Header.tsx index 0480d463736..640393641d0 100644 --- a/web/src/components/chat_search/Header.tsx +++ b/web/src/components/chat_search/Header.tsx @@ -7,7 +7,6 @@ import { NewChatIcon } from "../icons/icons"; import { NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA } from "@/lib/constants"; import { ChatSession } from "@/app/chat/interfaces"; import Link from "next/link"; -import { SettingsContext } from "../settings/SettingsProvider"; import { pageType } from "@/app/chat/sessionSidebar/types"; import { useRouter } from "next/navigation"; import { ChatBanner } from "@/app/chat/ChatBanner"; @@ -65,7 +64,7 @@ export default function FunctionalHeader({ router.push(newChatUrl); }; return ( -
+
{page != "assistants" && ( -
+
)}
); diff --git a/web/src/components/chat_search/MinimalMarkdown.tsx b/web/src/components/chat_search/MinimalMarkdown.tsx index 6fadc979c48..9df0260f4d5 100644 --- a/web/src/components/chat_search/MinimalMarkdown.tsx +++ b/web/src/components/chat_search/MinimalMarkdown.tsx @@ -1,3 +1,4 @@ +import { CodeBlock } from "@/app/chat/message/CodeBlock"; import React from "react"; import ReactMarkdown from "react-markdown"; import remarkGfm from "remark-gfm"; @@ -5,11 +6,13 @@ import remarkGfm from "remark-gfm"; interface MinimalMarkdownProps { content: string; className?: string; + useCodeBlock?: boolean; } export const MinimalMarkdown: React.FC = ({ content, className = "", + useCodeBlock = false, }) => { return ( = ({ p: ({ node, ...props }) => (

), + code: useCodeBlock + ? (props) => ( + + ) + : (props) => , }} remarkPlugins={[remarkGfm]} > diff --git a/web/src/components/chat_search/ProviderContext.tsx b/web/src/components/chat_search/ProviderContext.tsx new file mode 100644 index 00000000000..3907b98a68b --- /dev/null +++ b/web/src/components/chat_search/ProviderContext.tsx @@ -0,0 +1,70 @@ +"use client"; +import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces"; +import React, { createContext, useContext, useState, useEffect } from "react"; +import { useUser } from "../user/UserProvider"; +import { useRouter } from "next/navigation"; +import { checkLlmProvider } from "../initialSetup/welcome/lib"; + +interface ProviderContextType { + shouldShowConfigurationNeeded: boolean; + providerOptions: WellKnownLLMProviderDescriptor[]; + refreshProviderInfo: () => Promise; // Add this line +} + +const ProviderContext = createContext( + undefined +); + +export function ProviderContextProvider({ + children, +}: { + children: React.ReactNode; +}) { + const { user } = useUser(); + const router = useRouter(); + + const [validProviderExists, setValidProviderExists] = useState(true); + const [providerOptions, setProviderOptions] = useState< + WellKnownLLMProviderDescriptor[] + >([]); + + const fetchProviderInfo = async () => { + const { providers, options, defaultCheckSuccessful } = + await checkLlmProvider(user); + setValidProviderExists(providers.length > 0 && defaultCheckSuccessful); + setProviderOptions(options); + }; + + useEffect(() => { + fetchProviderInfo(); + }, [router, user]); + + const shouldShowConfigurationNeeded = + !validProviderExists && providerOptions.length > 0; + + const refreshProviderInfo = async () => { + await fetchProviderInfo(); + }; + + return ( + + {children} + + ); +} + +export function useProviderStatus() { + const context = useContext(ProviderContext); + if (context === undefined) { + throw new Error( + "useProviderStatus must be used within a ProviderContextProvider" + ); + } + return context; +} diff --git a/web/src/components/chat_search/UnconfiguredProviderText.tsx b/web/src/components/chat_search/UnconfiguredProviderText.tsx new file mode 100644 index 00000000000..e990eecd610 --- /dev/null +++ b/web/src/components/chat_search/UnconfiguredProviderText.tsx @@ -0,0 +1,27 @@ +import { useProviderStatus } from "./ProviderContext"; + +export default function CredentialNotConfigured({ + showConfigureAPIKey, +}: { + showConfigureAPIKey: () => void; +}) { + const { shouldShowConfigurationNeeded } = useProviderStatus(); + + if (!shouldShowConfigurationNeeded) { + return null; + } + + return ( +

+ Please note that you have not yet configured an LLM provider. You can + configure one{" "} + + . +

+ ); +} diff --git a/web/src/components/chat_search/hooks.ts b/web/src/components/chat_search/hooks.ts index 3377d9172a0..314810b55ee 100644 --- a/web/src/components/chat_search/hooks.ts +++ b/web/src/components/chat_search/hooks.ts @@ -1,4 +1,4 @@ -import { Dispatch, SetStateAction, useEffect, useRef, useState } from "react"; +import { Dispatch, SetStateAction, useEffect, useRef } from "react"; interface UseSidebarVisibilityProps { toggledSidebar: boolean; diff --git a/web/src/components/context/ChatContext.tsx b/web/src/components/context/ChatContext.tsx index 9ab0c1e2eac..06a63c903cc 100644 --- a/web/src/components/context/ChatContext.tsx +++ b/web/src/components/context/ChatContext.tsx @@ -1,6 +1,6 @@ "use client"; -import React, { createContext, useContext } from "react"; +import React, { createContext, useContext, useState } from "react"; import { DocumentSet, Tag, User, ValidSources } from "@/lib/types"; import { ChatSession } from "@/app/chat/interfaces"; import { Persona } from "@/app/admin/assistants/interfaces"; @@ -18,15 +18,40 @@ interface ChatContextProps { folders: Folder[]; openedFolders: Record; userInputPrompts: InputPrompt[]; + shouldShowWelcomeModal?: boolean; + shouldDisplaySourcesIncompleteModal?: boolean; + defaultAssistantId?: number; + refreshChatSessions: () => Promise; } const ChatContext = createContext(undefined); +// We use Omit to exclude 'refreshChatSessions' from the value prop type +// because we're defining it within the component export const ChatProvider: React.FC<{ - value: ChatContextProps; + value: Omit; children: React.ReactNode; }> = ({ value, children }) => { - return {children}; + const [chatSessions, setChatSessions] = useState(value?.chatSessions || []); + + const refreshChatSessions = async () => { + try { + const response = await fetch("/api/chat/get-user-chat-sessions"); + if (!response.ok) throw new Error("Failed to fetch chat sessions"); + const { sessions } = await response.json(); + setChatSessions(sessions); + } catch (error) { + console.error("Error refreshing chat sessions:", error); + } + }; + + return ( + + {children} + + ); }; export const useChatContext = (): ChatContextProps => { diff --git a/web/src/components/context/SearchContext.tsx b/web/src/components/context/SearchContext.tsx new file mode 100644 index 00000000000..a46fbed24ca --- /dev/null +++ b/web/src/components/context/SearchContext.tsx @@ -0,0 +1,38 @@ +"use client"; + +import React, { createContext, useContext } from "react"; +import { CCPairBasicInfo, DocumentSet, Tag } from "@/lib/types"; +import { Persona } from "@/app/admin/assistants/interfaces"; +import { ChatSession } from "@/app/chat/interfaces"; + +interface SearchContextProps { + querySessions: ChatSession[]; + ccPairs: CCPairBasicInfo[]; + documentSets: DocumentSet[]; + assistants: Persona[]; + tags: Tag[]; + agenticSearchEnabled: boolean; + disabledAgentic: boolean; + initiallyToggled: boolean; + shouldShowWelcomeModal: boolean; + shouldDisplayNoSources: boolean; +} + +const SearchContext = createContext(undefined); + +export const SearchProvider: React.FC<{ + value: SearchContextProps; + children: React.ReactNode; +}> = ({ value, children }) => { + return ( + {children} + ); +}; + +export const useSearchContext = (): SearchContextProps => { + const context = useContext(SearchContext); + if (!context) { + throw new Error("useSearchContext must be used within a SearchProvider"); + } + return context; +}; diff --git a/web/src/components/credentials/actions/CreateCredential.tsx 
b/web/src/components/credentials/actions/CreateCredential.tsx index 613e02d34b6..36c204d94ce 100644 --- a/web/src/components/credentials/actions/CreateCredential.tsx +++ b/web/src/components/credentials/actions/CreateCredential.tsx @@ -236,6 +236,7 @@ export default function CreateCredential({ )}
diff --git a/web/src/components/embedding/CustomModelForm.tsx b/web/src/components/embedding/CustomModelForm.tsx index e726921d390..6ea8fbc57cc 100644 --- a/web/src/components/embedding/CustomModelForm.tsx +++ b/web/src/components/embedding/CustomModelForm.tsx @@ -41,6 +41,7 @@ export function CustomModelForm({ api_key: null, provider_type: null, index_name: null, + api_url: null, }); }} > @@ -106,20 +107,19 @@ export function CustomModelForm({ /> -
- -
+ )} diff --git a/web/src/components/embedding/LiteLLMModelForm.tsx b/web/src/components/embedding/LiteLLMModelForm.tsx new file mode 100644 index 00000000000..b84db4f9067 --- /dev/null +++ b/web/src/components/embedding/LiteLLMModelForm.tsx @@ -0,0 +1,116 @@ +import { CloudEmbeddingModel, CloudEmbeddingProvider } from "./interfaces"; +import { Formik, Form } from "formik"; +import * as Yup from "yup"; +import { TextFormField, BooleanFormField } from "../admin/connectors/Field"; +import { Dispatch, SetStateAction } from "react"; +import { Button, Text } from "@tremor/react"; +import { EmbeddingDetails } from "@/app/admin/embeddings/EmbeddingModelSelectionForm"; + +export function LiteLLMModelForm({ + setShowTentativeModel, + currentValues, + provider, +}: { + setShowTentativeModel: Dispatch>; + currentValues: CloudEmbeddingModel | null; + provider: EmbeddingDetails; +}) { + return ( +
+ { + setShowTentativeModel(values as CloudEmbeddingModel); + }} + > + {({ isSubmitting }) => ( +
+ + Add a new model to LiteLLM proxy at {provider.api_url} + + + + + + + + + + + + + + )} +
+
+ ); +} diff --git a/web/src/components/embedding/ReindexingProgressTable.tsx b/web/src/components/embedding/ReindexingProgressTable.tsx index b1f91d24bb3..882b2591003 100644 --- a/web/src/components/embedding/ReindexingProgressTable.tsx +++ b/web/src/components/embedding/ReindexingProgressTable.tsx @@ -27,10 +27,16 @@ export function ReindexingProgressTable({ - Connector Name - Status - Docs Re-Indexed - Error Message + + Connector Name + + Status + + Docs Re-Indexed + + + Error Message + diff --git a/web/src/components/embedding/interfaces.tsx b/web/src/components/embedding/interfaces.tsx index c719b7dc7bf..daa56128c3d 100644 --- a/web/src/components/embedding/interfaces.tsx +++ b/web/src/components/embedding/interfaces.tsx @@ -2,6 +2,7 @@ import { CohereIcon, GoogleIcon, IconProps, + LiteLLMIcon, MicrosoftIcon, NomicIcon, OpenAIIcon, @@ -14,11 +15,13 @@ export enum EmbeddingProvider { COHERE = "Cohere", VOYAGE = "Voyage", GOOGLE = "Google", + LITELLM = "LiteLLM", } export interface CloudEmbeddingProvider { provider_type: EmbeddingProvider; api_key?: string; + api_url?: string; custom_config?: Record; docsLink?: string; @@ -36,6 +39,7 @@ export interface CloudEmbeddingProvider { // Embedding Models export interface EmbeddingModelDescriptor { + id?: number; model_name: string; model_dim: number; normalize: boolean; @@ -44,6 +48,7 @@ export interface EmbeddingModelDescriptor { provider_type: string | null; description: string; api_key: string | null; + api_url: string | null; index_name: string | null; } @@ -70,7 +75,7 @@ export interface FullEmbeddingModelResponse { } export interface CloudEmbeddingProviderFull extends CloudEmbeddingProvider { - configured: boolean; + configured?: boolean; } export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ @@ -87,6 +92,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ index_name: "", provider_type: null, api_key: null, + api_url: null, }, { model_name: "intfloat/e5-base-v2", @@ -99,6 +105,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ passage_prefix: "passage: ", index_name: "", provider_type: null, + api_url: null, api_key: null, }, { @@ -113,6 +120,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ index_name: "", provider_type: null, api_key: null, + api_url: null, }, { model_name: "intfloat/multilingual-e5-base", @@ -126,6 +134,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ index_name: "", provider_type: null, api_key: null, + api_url: null, }, { model_name: "intfloat/multilingual-e5-small", @@ -139,9 +148,19 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ index_name: "", provider_type: null, api_key: null, + api_url: null, }, ]; +export const LITELLM_CLOUD_PROVIDER: CloudEmbeddingProvider = { + provider_type: EmbeddingProvider.LITELLM, + website: "https://github.com/BerriAI/litellm", + icon: LiteLLMIcon, + description: "Open-source library to call LLM APIs using OpenAI format", + apiLink: "https://docs.litellm.ai/docs/proxy/quick_start", + embedding_models: [], // No default embedding models +}; + export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ { provider_type: EmbeddingProvider.COHERE, @@ -169,6 +188,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, { model_name: "embed-english-light-v3.0", @@ -185,6 +205,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, ], }, @@ 
-213,6 +234,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ enabled: false, index_name: "", api_key: null, + api_url: null, }, { provider_type: EmbeddingProvider.OPENAI, @@ -229,6 +251,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ maxContext: 8191, index_name: "", api_key: null, + api_url: null, }, ], }, @@ -258,6 +281,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, { provider_type: EmbeddingProvider.GOOGLE, @@ -273,6 +297,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, ], }, @@ -301,6 +326,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, { provider_type: EmbeddingProvider.VOYAGE, @@ -317,6 +343,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, ], }, diff --git a/web/src/components/header/HeaderTitle.tsx b/web/src/components/header/HeaderTitle.tsx index 2ec1d3cbff2..f63a8769c44 100644 --- a/web/src/components/header/HeaderTitle.tsx +++ b/web/src/components/header/HeaderTitle.tsx @@ -7,7 +7,9 @@ export function HeaderTitle({ children }: { children: JSX.Element | string }) { const textSize = isString && children.length > 10 ? "text-xl" : "text-2xl"; return ( -

+

{children}

); diff --git a/web/src/components/header/LogoType.tsx b/web/src/components/header/LogoType.tsx index 04ddc54d758..29a4bfedea5 100644 --- a/web/src/components/header/LogoType.tsx +++ b/web/src/components/header/LogoType.tsx @@ -56,7 +56,7 @@ export default function LogoType({ >
{enterpriseSettings && enterpriseSettings.application_name ? ( -
+
{enterpriseSettings.application_name} {!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (

Powered by Danswer

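Going back to the embedding interface changes earlier in this diff: with api_url added to EmbeddingModelDescriptor and the new LITELLM_CLOUD_PROVIDER, a model reached through a LiteLLM proxy can be described roughly as below. This is a sketch against the field names shown in the diff; the model name, dimension, and proxy URL are placeholder values, not defaults from this PR.

import {
  EmbeddingModelDescriptor,
  EmbeddingProvider,
} from "@/components/embedding/interfaces";

// Hypothetical descriptor for a model served behind a LiteLLM proxy: the
// proxy endpoint goes in the new api_url field and no hosted-provider key
// is required.
const liteLLMExample: EmbeddingModelDescriptor = {
  model_name: "text-embedding-3-small", // whatever name the proxy routes
  model_dim: 1536,
  normalize: false,
  query_prefix: "",
  passage_prefix: "",
  provider_type: EmbeddingProvider.LITELLM,
  description: "Embedding model served via a LiteLLM proxy",
  api_key: null,
  api_url: "http://localhost:4000", // placeholder proxy endpoint
  index_name: null,
};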
diff --git a/web/src/components/health/healthcheck.tsx b/web/src/components/health/healthcheck.tsx index a8110ba8c55..2cba8be8278 100644 --- a/web/src/components/health/healthcheck.tsx +++ b/web/src/components/health/healthcheck.tsx @@ -3,29 +3,95 @@ import { errorHandlingFetcher, RedirectError } from "@/lib/fetcher"; import useSWR from "swr"; import { Modal } from "../Modal"; -import { useState } from "react"; +import { useEffect, useState } from "react"; +import { getSecondsUntilExpiration } from "@/lib/time"; +import { User } from "@/lib/types"; -export const HealthCheckBanner = ({ - secondsUntilExpiration, -}: { - secondsUntilExpiration?: number | null; -}) => { +export const HealthCheckBanner = () => { const { error } = useSWR("/api/health", errorHandlingFetcher); const [expired, setExpired] = useState(false); + const [secondsUntilExpiration, setSecondsUntilExpiration] = useState< + number | null + >(null); + const { data: user, mutate: mutateUser } = useSWR( + "/api/me", + errorHandlingFetcher + ); - if (secondsUntilExpiration !== null && secondsUntilExpiration !== undefined) { - setTimeout( - () => { - setExpired(true); - }, - secondsUntilExpiration * 1000 - 200 - ); - } + const updateExpirationTime = async () => { + const updatedUser = await mutateUser(); + + if (updatedUser) { + const seconds = getSecondsUntilExpiration(updatedUser); + setSecondsUntilExpiration(seconds); + console.debug(`Updated seconds until expiration:! ${seconds}`); + } + }; + + useEffect(() => { + updateExpirationTime(); + }, [user]); + + useEffect(() => { + if (true) { + let refreshTimeoutId: NodeJS.Timeout; + let expireTimeoutId: NodeJS.Timeout; + + const refreshToken = async () => { + try { + const response = await fetch( + "/api/enterprise-settings/refresh-token", + { + method: "GET", + headers: { + "Content-Type": "application/json", + }, + } + ); + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + + console.debug("Token refresh successful"); + // Force revalidation of user data + + await mutateUser(undefined, { revalidate: true }); + updateExpirationTime(); + } catch (error) { + console.error("Error refreshing token:", error); + } + }; + + const scheduleRefreshAndExpire = () => { + if (secondsUntilExpiration !== null) { + const timeUntilRefresh = (secondsUntilExpiration + 0.5) * 1000; + refreshTimeoutId = setTimeout(refreshToken, timeUntilRefresh); + + const timeUntilExpire = (secondsUntilExpiration + 10) * 1000; + expireTimeoutId = setTimeout(() => { + console.debug("Session expired. Setting expired state to true."); + setExpired(true); + }, timeUntilExpire); + } + }; + + scheduleRefreshAndExpire(); + + return () => { + clearTimeout(refreshTimeoutId); + clearTimeout(expireTimeoutId); + }; + } + }, [secondsUntilExpiration, user]); if (!error && !expired) { return null; } + console.debug( + `Rendering HealthCheckBanner. Error: ${error}, Expired: ${expired}` + ); + if (error instanceof RedirectError || expired) { return ( = ({ name, ...props }) => { + const IconComponent = getPreloadedIcon(name); + return IconComponent ? 
( + + ) : ( + + ); +}; + +// Cache for storing preloaded icons +const iconCache: Record = {}; + +// Preloads icons asynchronously and stores them in the cache +export async function preloadIcons(iconNames: string[]): Promise { + const promises = iconNames.map(async (name) => { + try { + const iconModule = await import("react-icons/fa"); + const iconName = + `Fa${name.charAt(0).toUpperCase() + name.slice(1)}` as keyof typeof iconModule; + iconCache[name] = (iconModule[iconName] as IconType) || FaQuestion; + } catch (error) { + console.error(`Failed to load icon: ${name}`, error); + iconCache[name] = FaQuestion; + } + }); + + await Promise.all(promises); +} + +// Retrieves a preloaded icon from the cache +export function getPreloadedIcon(name: string): IconType | undefined { + return iconCache[name] || FaQuestion; +} + +export default DynamicFaIcon; diff --git a/web/src/components/icons/icons.tsx b/web/src/components/icons/icons.tsx index fe33d4e9c15..b5e735b0e65 100644 --- a/web/src/components/icons/icons.tsx +++ b/web/src/components/icons/icons.tsx @@ -48,6 +48,7 @@ import jiraSVG from "../../../public/Jira.svg"; import confluenceSVG from "../../../public/Confluence.svg"; import openAISVG from "../../../public/Openai.svg"; import openSourceIcon from "../../../public/OpenSource.png"; +import litellmIcon from "../../../public/LiteLLM.jpg"; import awsWEBP from "../../../public/Amazon.webp"; import azureIcon from "../../../public/Azure.png"; @@ -267,6 +268,20 @@ export const ColorSlackIcon = ({ ); }; +export const LiteLLMIcon = ({ + size = 16, + className = defaultTailwindCSS, +}: IconProps) => { + return ( +
+ +
+ ); +}; + export const OpenSourceIcon = ({ size = 16, className = defaultTailwindCSS, @@ -2754,3 +2769,45 @@ export const CameraIcon = ({ ); }; + +export const MacIcon = ({ + size = 16, + className = "my-auto flex flex-shrink-0 ", +}: IconProps) => { + return ( + + + + ); +}; + +export const WindowsIcon = ({ + size = 16, + className = "my-auto flex flex-shrink-0 ", +}: IconProps) => { + return ( + + + + ); +}; diff --git a/web/src/components/initialSetup/welcome/WelcomeModal.tsx b/web/src/components/initialSetup/welcome/WelcomeModal.tsx index c9472992f73..ec71c5e6de4 100644 --- a/web/src/components/initialSetup/welcome/WelcomeModal.tsx +++ b/web/src/components/initialSetup/welcome/WelcomeModal.tsx @@ -27,13 +27,11 @@ function UsageTypeSection({ title, description, callToAction, - icon, onClick, }: { title: string; description: string | JSX.Element; callToAction: string; - icon?: React.ElementType; onClick: () => void; }) { return ( @@ -243,7 +241,6 @@ export function _WelcomeModal({ user }: { user: User | null }) { this is the option for you! } - icon={FiMessageSquare} callToAction="Get Started" onClick={() => { setSelectedFlow("chat"); diff --git a/web/src/components/llm/ApiKeyForm.tsx b/web/src/components/llm/ApiKeyForm.tsx index 0ebe38dc3d7..1a1f24d9183 100644 --- a/web/src/components/llm/ApiKeyForm.tsx +++ b/web/src/components/llm/ApiKeyForm.tsx @@ -55,6 +55,7 @@ export const ApiKeyForm = ({ return ( onSuccess()} shouldMarkAsDefault diff --git a/web/src/components/llm/ApiKeyModal.tsx b/web/src/components/llm/ApiKeyModal.tsx index 8c522fdc398..706413b4eef 100644 --- a/web/src/components/llm/ApiKeyModal.tsx +++ b/web/src/components/llm/ApiKeyModal.tsx @@ -1,60 +1,38 @@ "use client"; -import { useState, useEffect } from "react"; import { ApiKeyForm } from "./ApiKeyForm"; import { Modal } from "../Modal"; -import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces"; -import { checkLlmProvider } from "../initialSetup/welcome/lib"; -import { User } from "@/lib/types"; import { useRouter } from "next/navigation"; +import { useProviderStatus } from "../chat_search/ProviderContext"; -export const ApiKeyModal = ({ user }: { user: User | null }) => { +export const ApiKeyModal = ({ hide }: { hide: () => void }) => { const router = useRouter(); - const [forceHidden, setForceHidden] = useState(false); - const [validProviderExists, setValidProviderExists] = useState(true); - const [providerOptions, setProviderOptions] = useState< - WellKnownLLMProviderDescriptor[] - >([]); + const { + shouldShowConfigurationNeeded, + providerOptions, + refreshProviderInfo, + } = useProviderStatus(); - useEffect(() => { - async function fetchProviderInfo() { - const { providers, options, defaultCheckSuccessful } = - await checkLlmProvider(user); - setValidProviderExists(providers.length > 0 && defaultCheckSuccessful); - setProviderOptions(options); - } - - fetchProviderInfo(); - }, []); - - // don't show if - // (1) a valid provider has been setup or - // (2) there are no provider options (e.g. user isn't an admin) - // (3) user explicitly hides the modal - if (validProviderExists || !providerOptions.length || forceHidden) { + if (!shouldShowConfigurationNeeded) { return null; } return ( setForceHidden(true)} + title="Set an API Key!" + className="max-w-3xl" + onOutsideClick={() => hide()} >
- Please setup an LLM below in order to start using Danswer Search or - Danswer Chat. Don't worry, you can always change this later in - the Admin Panel. -
+ Please provide an API Key below in order to start using + Danswer – you can always change this later.
- Or if you'd rather look around first,{" "} - setForceHidden(true)} - className="text-link cursor-pointer" - > + If you'd rather look around first, you can + hide()} className="text-link cursor-pointer"> + {" "} skip this step . @@ -63,7 +41,8 @@ export const ApiKeyModal = ({ user }: { user: User | null }) => { { router.refresh(); - setForceHidden(true); + refreshProviderInfo(); + hide(); }} providerOptions={providerOptions} /> diff --git a/web/src/components/llm/LLMList.tsx b/web/src/components/llm/LLMList.tsx index c8eef1189b9..191a8948580 100644 --- a/web/src/components/llm/LLMList.tsx +++ b/web/src/components/llm/LLMList.tsx @@ -1,7 +1,10 @@ import React from "react"; import { getDisplayNameForModel } from "@/lib/hooks"; import { structureValue } from "@/lib/llm/utils"; -import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces"; +import { + getProviderIcon, + LLMProviderDescriptor, +} from "@/app/admin/configuration/llm/interfaces"; interface LlmListProps { llmProviders: LLMProviderDescriptor[]; @@ -9,6 +12,7 @@ interface LlmListProps { onSelect: (value: string | null) => void; userDefault?: string | null; scrollable?: boolean; + hideProviderIcon?: boolean; } export const LlmList: React.FC = ({ @@ -19,7 +23,11 @@ export const LlmList: React.FC = ({ scrollable, }) => { const llmOptionsByProvider: { - [provider: string]: { name: string; value: string }[]; + [provider: string]: { + name: string; + value: string; + icon: React.FC<{ size?: number; className?: string }>; + }[]; } = {}; const uniqueModelNames = new Set(); @@ -39,6 +47,7 @@ export const LlmList: React.FC = ({ llmProvider.provider, modelName ), + icon: getProviderIcon(llmProvider.provider), }); } } @@ -55,6 +64,7 @@ export const LlmList: React.FC = ({ > {userDefault && ( )} - {llmOptions.map(({ name, value }, index) => ( + {llmOptions.map(({ name, icon, value }, index) => ( ))} diff --git a/web/src/components/search/DocumentDisplay.tsx b/web/src/components/search/DocumentDisplay.tsx index 5fb1c255b4e..2b2258bf6c3 100644 --- a/web/src/components/search/DocumentDisplay.tsx +++ b/web/src/components/search/DocumentDisplay.tsx @@ -19,6 +19,7 @@ import { FiTag } from "react-icons/fi"; import { DISABLE_LLM_DOC_RELEVANCE } from "@/lib/constants"; import { SettingsContext } from "../settings/SettingsProvider"; import { CustomTooltip, TooltipGroup } from "../tooltip/CustomTooltip"; +import { WarningCircle } from "@phosphor-icons/react"; export const buildDocumentSummaryDisplay = ( matchHighlights: string[], @@ -230,7 +231,7 @@ export const DocumentDisplay = ({ {document.semantic_identifier || document.document_id}

-
+
{isHovered && messageId && ( @@ -326,31 +327,32 @@ export const AgenticDocumentDisplay = ({

-
- {isHovered && messageId && ( - - )} - - {(contentEnriched || additional_relevance) && - relevance_explanation && - (isHovered || alternativeToggled) && ( - +
+ + {isHovered && messageId && ( + )} + + {(contentEnriched || additional_relevance) && + (isHovered || alternativeToggled) && ( + + )} +
@@ -367,7 +369,13 @@ export const AgenticDocumentDisplay = ({ document.match_highlights, document.blurb ) - : relevance_explanation} + : relevance_explanation || ( + + {" "} + + Model failed to produce an analysis of the document + + )}

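A note on the AgenticDocumentDisplay change above: when the model returns no relevance_explanation, the UI now falls back to an inline warning instead of rendering an empty block. A simplified sketch of that fallback, with the surrounding markup and the alternativeToggled/contentEnriched branch omitted; the component name here is illustrative.

import { WarningCircle } from "@phosphor-icons/react";

// Render the explanation if present, otherwise the warning added in the diff.
function RelevanceExplanation({
  explanation,
}: {
  explanation?: string | null;
}) {
  return (
    <p>
      {explanation || (
        <span>
          <WarningCircle /> Model failed to produce an analysis of the document
        </span>
      )}
    </p>
  );
}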
diff --git a/web/src/components/search/SearchAnswer.tsx b/web/src/components/search/SearchAnswer.tsx index 359879831ec..e8645dfcea0 100644 --- a/web/src/components/search/SearchAnswer.tsx +++ b/web/src/components/search/SearchAnswer.tsx @@ -78,13 +78,13 @@ export default function SearchAnswer({ {searchState == "generating" && (
- Generating response... + Generating Response...
)} {searchState == "citing" && (
- Creating citations... + Extracting Quotes...
)} diff --git a/web/src/components/search/SearchBar.tsx b/web/src/components/search/SearchBar.tsx index 98533ffeaaf..7d62b099206 100644 --- a/web/src/components/search/SearchBar.tsx +++ b/web/src/components/search/SearchBar.tsx @@ -71,7 +71,7 @@ export const AnimatedToggle = ({ Get quality results immediately, best suited for instant access to your documents.

-

Shortcut: ({commandSymbol}/)

+

Shortcut: ({commandSymbol}/)

} > diff --git a/web/src/components/search/SearchSection.tsx b/web/src/components/search/SearchSection.tsx index 5711854217f..cb33d961831 100644 --- a/web/src/components/search/SearchSection.tsx +++ b/web/src/components/search/SearchSection.tsx @@ -34,8 +34,13 @@ import FixedLogo from "@/app/chat/shared_chat_search/FixedLogo"; import { usePopup } from "../admin/connectors/Popup"; import { FeedbackType } from "@/app/chat/types"; import { FeedbackModal } from "@/app/chat/modal/FeedbackModal"; -import { handleChatFeedback } from "@/app/chat/lib"; +import { deleteChatSession, handleChatFeedback } from "@/app/chat/lib"; import SearchAnswer from "./SearchAnswer"; +import { DeleteEntityModal } from "../modals/DeleteEntityModal"; +import { ApiKeyModal } from "../llm/ApiKeyModal"; +import { useSearchContext } from "../context/SearchContext"; +import { useUser } from "../user/UserProvider"; +import UnconfiguredProviderText from "../chat_search/UnconfiguredProviderText"; export type searchState = | "input" @@ -57,33 +62,28 @@ const VALID_QUESTION_RESPONSE_DEFAULT: ValidQuestionResponse = { }; interface SearchSectionProps { - disabledAgentic: boolean; - ccPairs: CCPairBasicInfo[]; - documentSets: DocumentSet[]; - personas: Persona[]; - tags: Tag[]; toggle: () => void; - querySessions: ChatSession[]; defaultSearchType: SearchType; - user: User | null; toggledSidebar: boolean; - agenticSearchEnabled: boolean; } export const SearchSection = ({ - ccPairs, toggle, - disabledAgentic, - documentSets, - agenticSearchEnabled, - personas, - user, - tags, - querySessions, toggledSidebar, defaultSearchType, }: SearchSectionProps) => { - // Search Bar + const { + querySessions, + ccPairs, + documentSets, + assistants, + tags, + shouldShowWelcomeModal, + agenticSearchEnabled, + disabledAgentic, + shouldDisplayNoSources, + } = useSearchContext(); + const [query, setQuery] = useState(""); const [comments, setComments] = useState(null); const [contentEnriched, setContentEnriched] = useState(false); @@ -99,6 +99,8 @@ export const SearchSection = ({ messageId: null, }); + const [showApiKeyModal, setShowApiKeyModal] = useState(true); + const [agentic, setAgentic] = useState(agenticSearchEnabled); const toggleAgentic = () => { @@ -146,7 +148,7 @@ export const SearchSection = ({ useState(defaultSearchType); const [selectedPersona, setSelectedPersona] = useState( - personas[0]?.id || 0 + assistants[0]?.id || 0 ); // Used for search state display @@ -157,8 +159,8 @@ export const SearchSection = ({ const availableSources = ccPairs.map((ccPair) => ccPair.source); const [finalAvailableSources, finalAvailableDocumentSets] = computeAvailableFilters({ - selectedPersona: personas.find( - (persona) => persona.id === selectedPersona + selectedPersona: assistants.find( + (assistant) => assistant.id === selectedPersona ), availableSources: availableSources, availableDocumentSets: documentSets, @@ -267,7 +269,7 @@ export const SearchSection = ({ ...(prevState || initialSearchResponse), quotes, })); - setSearchState((searchState) => "input"); + setSearchState((searchState) => "citing"); }; const updateDocs = (documents: SearchDanswerDocument[]) => { @@ -294,7 +296,7 @@ export const SearchSection = ({ })); if (disabledAgentic) { setIsFetching(false); - setSearchState("input"); + setSearchState((searchState) => "citing"); } if (documents.length == 0) { setSearchState("input"); @@ -332,11 +334,8 @@ export const SearchSection = ({ messageId, })); router.refresh(); - // setSearchState("input"); setIsFetching(false); 
setSearchState((searchState) => "input"); - - // router.replace(`/search?searchId=${chat_session_id}`); }; const updateDocumentRelevance = (relevance: Relevance) => { @@ -364,6 +363,7 @@ export const SearchSection = ({ setSearchState("input"); } }; + const { user } = useUser(); const [searchAnswerExpanded, setSearchAnswerExpanded] = useState(false); const resetInput = (finalized?: boolean) => { @@ -405,8 +405,8 @@ export const SearchSection = ({ documentSets: filterManager.selectedDocumentSets, timeRange: filterManager.timeRange, tags: filterManager.selectedTags, - persona: personas.find( - (persona) => persona.id === selectedPersona + persona: assistants.find( + (assistant) => assistant.id === selectedPersona ) as Persona, updateCurrentAnswer: cancellable({ cancellationToken: lastSearchCancellationToken.current, @@ -511,7 +511,12 @@ export const SearchSection = ({ }; const [firstSearch, setFirstSearch] = useState(true); const [searchState, setSearchState] = useState("input"); + const [deletingChatSession, setDeletingChatSession] = + useState(); + const showDeleteModal = (chatSession: ChatSession) => { + setDeletingChatSession(chatSession); + }; // Used to maintain a "time out" for history sidebar so our existing refs can have time to process change const [untoggled, setUntoggled] = useState(false); @@ -579,10 +584,44 @@ export const SearchSection = ({ const { popup, setPopup } = usePopup(); + const shouldUseAgenticDisplay = + agenticResults && + (searchResponse.documents || []).some( + (document) => + searchResponse.additional_relevance && + searchResponse.additional_relevance[document.document_id] !== undefined + ); + return ( <>
{popup} + + {!shouldDisplayNoSources && + showApiKeyModal && + !shouldShowWelcomeModal && ( + setShowApiKeyModal(false)} /> + )} + + {deletingChatSession && ( + setDeletingChatSession(null)} + onSubmit={async () => { + const response = await deleteChatSession(deletingChatSession.id); + if (response.ok) { + setDeletingChatSession(null); + // go back to the main page + router.push("/search"); + } else { + const responseJson = await response.json(); + setPopup({ message: responseJson.detail, type: "error" }); + } + router.refresh(); + }} + /> + )} {currentFeedback && (
setQuery("")} page="search" @@ -664,7 +704,7 @@ export const SearchSection = ({ (ccPairs.length > 0 || documentSets.length > 0) && (
+ + setShowApiKeyModal(true)} + /> +
+ {availableTags.length > 0 && ( + <> +
+ Tags +
+ + + )} + {existingSources.length > 0 && (
@@ -191,19 +204,6 @@ export function SourceSelector({
)} - - {availableTags.length > 0 && ( - <> -
- Tags -
- - - )}
); } diff --git a/web/src/components/search/filtering/TagFilter.tsx b/web/src/components/search/filtering/TagFilter.tsx index ec3a7f38bb4..80a6ee78922 100644 --- a/web/src/components/search/filtering/TagFilter.tsx +++ b/web/src/components/search/filtering/TagFilter.tsx @@ -115,7 +115,7 @@ export function TagFilter({ Tags
-
+
{filteredTags.length > 0 ? ( filteredTags.map((tag) => (
{ header = <>; body = ( - - {replaceNewlines(props.answer || "")} - + ); // error while building answer (NOTE: if error occurs during quote generation @@ -63,12 +62,7 @@ export const AnswerSection = (props: AnswerSectionProps) => { status = "success"; header = <>; body = ( - - {replaceNewlines(props.answer)} - + ); } diff --git a/web/src/components/user/UserProvider.tsx b/web/src/components/user/UserProvider.tsx index d2cf0f2c94f..67777277c27 100644 --- a/web/src/components/user/UserProvider.tsx +++ b/web/src/components/user/UserProvider.tsx @@ -8,6 +8,7 @@ interface UserContextType { user: User | null; isLoadingUser: boolean; isAdmin: boolean; + isCurator: boolean; refreshUser: () => Promise; } @@ -17,12 +18,16 @@ export function UserProvider({ children }: { children: React.ReactNode }) { const [user, setUser] = useState(null); const [isLoadingUser, setIsLoadingUser] = useState(true); const [isAdmin, setIsAdmin] = useState(false); + const [isCurator, setIsCurator] = useState(false); const fetchUser = async () => { try { const user = await getCurrentUser(); setUser(user); setIsAdmin(user?.role === UserRole.ADMIN); + setIsCurator( + user?.role === UserRole.CURATOR || user?.role == UserRole.GLOBAL_CURATOR + ); } catch (error) { console.error("Error fetching current user:", error); } finally { @@ -40,7 +45,9 @@ export function UserProvider({ children }: { children: React.ReactNode }) { }; return ( - + {children} ); diff --git a/web/src/lib/admin/users/userMutationFetcher.ts b/web/src/lib/admin/users/userMutationFetcher.ts index ee3e201e096..d0c090c89ce 100644 --- a/web/src/lib/admin/users/userMutationFetcher.ts +++ b/web/src/lib/admin/users/userMutationFetcher.ts @@ -1,13 +1,14 @@ const userMutationFetcher = async ( url: string, - { arg }: { arg: { user_email: string; new_role?: string } } + { arg }: { arg: { user_email: string; new_role?: string; method?: string } } ) => { + const { method = "PATCH", ...body } = arg; return fetch(url, { - method: "PATCH", + method, headers: { "Content-Type": "application/json", }, - body: JSON.stringify(arg), + body: JSON.stringify(body), }).then(async (res) => { if (res.ok) return res.json(); const errorDetail = (await res.json()).detail; diff --git a/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts b/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts index f779b2ed613..9873e0a256b 100644 --- a/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts +++ b/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts @@ -2,7 +2,10 @@ import { Persona } from "@/app/admin/assistants/interfaces"; import { CCPairBasicInfo, DocumentSet, User } from "../types"; import { getCurrentUserSS } from "../userSS"; import { fetchSS } from "../utilsSS"; -import { FullLLMProvider } from "@/app/admin/configuration/llm/interfaces"; +import { + FullLLMProvider, + getProviderIcon, +} from "@/app/admin/configuration/llm/interfaces"; import { ToolSnapshot } from "../tools/interfaces"; import { fetchToolsSS } from "../tools/fetchTools"; import { @@ -94,17 +97,7 @@ export async function fetchAssistantEditorInfoSS( } for (const provider of llmProviders) { - if (provider.provider == "openai") { - provider.icon = OpenAIIcon; - } else if (provider.provider == "anthropic") { - provider.icon = AnthropicIcon; - } else if (provider.provider == "bedrock") { - provider.icon = AWSIcon; - } else if (provider.provider == "azure") { - provider.icon = AzureIcon; - } else { - provider.icon = OpenSourceIcon; - } + provider.icon = getProviderIcon(provider.provider); } const existingPersona = personaResponse diff 
--git a/web/src/lib/browserUtilities.tsx b/web/src/lib/browserUtilities.tsx index 445ed649507..6646db9b5cc 100644 --- a/web/src/lib/browserUtilities.tsx +++ b/web/src/lib/browserUtilities.tsx @@ -1,5 +1,6 @@ "use client"; +import { MacIcon, WindowsIcon } from "@/components/icons/icons"; import { useState, useEffect } from "react"; type OperatingSystem = "Windows" | "Mac" | "Other"; @@ -23,9 +24,9 @@ const KeyboardSymbol = () => { const os = useOperatingSystem(); if (os === "Windows") { - return "⊞"; + return ; } else { - return "⌘"; + return ; } }; diff --git a/web/src/lib/chat/fetchChatData.ts b/web/src/lib/chat/fetchChatData.ts index e90cc6a1a24..d6a92410f04 100644 --- a/web/src/lib/chat/fetchChatData.ts +++ b/web/src/lib/chat/fetchChatData.ts @@ -44,7 +44,6 @@ interface FetchChatDataResult { toggleSidebar: boolean; finalDocumentSidebarInitialWidth?: number; shouldShowWelcomeModal: boolean; - shouldDisplaySourcesIncompleteModal: boolean; userInputPrompts: InputPrompt[]; } @@ -242,7 +241,6 @@ export async function fetchChatData(searchParams: { finalDocumentSidebarInitialWidth, toggleSidebar, shouldShowWelcomeModal, - shouldDisplaySourcesIncompleteModal, userInputPrompts, }; } diff --git a/web/src/lib/connectors/connectors.ts b/web/src/lib/connectors/connectors.ts index c66dbd2453b..e7b07ca8d1d 100644 --- a/web/src/lib/connectors/connectors.ts +++ b/web/src/lib/connectors/connectors.ts @@ -1,4 +1,4 @@ -import { ValidInputTypes, ValidSources } from "../types"; +import { ConfigurableSources, ValidInputTypes, ValidSources } from "../types"; export type InputType = | "list" @@ -26,7 +26,6 @@ export interface Option { export interface SelectOption extends Option { type: "select"; - default?: number; options?: StringWithDescription[]; } @@ -76,7 +75,10 @@ export interface ConnectionConfiguration { overrideDefaultFreq?: number; } -export const connectorConfigs: Record = { +export const connectorConfigs: Record< + ConfigurableSources, + ConnectionConfiguration +> = { web: { description: "Configure Web connector", values: [ @@ -93,7 +95,6 @@ export const connectorConfigs: Record = { query: "Select the web connector type:", label: "Scrape Method", name: "web_connector_type", - optional: true, options: [ { name: "recursive", value: "recursive" }, { name: "single", value: "single" }, @@ -219,28 +220,57 @@ export const connectorConfigs: Record = { }, confluence: { description: "Configure Confluence connector", - subtext: `Specify any link to a Confluence page below and click "Index" to Index. If the provided link is for an entire space, we will index the entire space. However, if you want to index a specific page, you can do so by entering the page's URL. - -For example, entering https://danswer.atlassian.net/wiki/spaces/Engineering/overview and clicking the Index button will index the whole Engineering Confluence space, but entering https://danswer.atlassian.net/wiki/spaces/Engineering/pages/164331/example+page will index that page (and optionally the page's children). + subtext: `Specify the base URL of your Confluence instance, the space name, and optionally a specific page ID to index. If no page ID is provided, the entire space will be indexed. -Selecting the "Index Recursively" checkbox will index the single page's children in addition to itself.`, +For example, entering "https://pablosfsanchez.atlassian.net/wiki" as the Wiki Base URL, "KB" as the Space, and "164331" as the Page ID will index the specific page at https://pablosfsanchez.atlassian.net/wiki/spaces/KB/pages/164331/Page. 
If you leave the Page ID empty, it will index the entire KB space. + +Selecting the "Index Recursively" checkbox will index the specified page and all of its children.`, values: [ { type: "text", - query: "Enter the wiki page URL:", - label: "Wiki Page URL", - name: "wiki_page_url", + query: "Enter the wiki base URL:", + label: "Wiki Base URL", + name: "wiki_base", + optional: false, + description: + "The base URL of your Confluence instance (e.g., https://your-domain.atlassian.net/wiki)", + }, + { + type: "text", + query: "Enter the space:", + label: "Space", + name: "space", optional: false, - description: "Enter any link to a Confluence space or Page", + description: "The Confluence space name to index (e.g. `KB`)", + }, + { + type: "text", + query: "Enter the page ID (optional):", + label: "Page ID", + name: "page_id", + optional: true, + description: + "Specific page ID to index - leave empty to index the entire space (e.g. `131368`)", }, { type: "checkbox", query: "Should index pages recursively?", - label: - "Index Recursively (if this is set and the Wiki Page URL leads to a page, we will index the page and all of its children instead of just the page)", + label: "Index Recursively", name: "index_recursively", + description: + "If this is set and the Wiki Page URL leads to a page, we will index the page and all of its children instead of just the page. This is set by default for Confluence connectors without a page ID specified.", optional: false, }, + { + type: "checkbox", + query: "Is this a Confluence Cloud instance?", + label: "Is Cloud", + name: "is_cloud", + optional: false, + default: true, + description: + "Check if this is a Confluence Cloud instance, uncheck for Confluence Server/Data Center", + }, ], }, jira: { @@ -556,7 +586,20 @@ For example, specifying .*-support.* as a "channel" will cause the connector to }, zendesk: { description: "Configure Zendesk connector", - values: [], + values: [ + { + type: "select", + query: "Select the what content this connector will index:", + label: "Content Type", + name: "content_type", + optional: false, + options: [ + { name: "articles", value: "articles" }, + { name: "tickets", value: "tickets" }, + ], + default: 0, + }, + ], }, linear: { description: "Configure Dropbox connector", @@ -817,7 +860,10 @@ export interface GmailConfig {} export interface BookstackConfig {} export interface ConfluenceConfig { - wiki_page_url: string; + wiki_base: string; + space: string; + page_id?: string; + is_cloud?: boolean; index_recursively?: boolean; } diff --git a/web/src/lib/connectors/credentials.ts b/web/src/lib/connectors/credentials.ts index 1babf7de12f..424a07c82fe 100644 --- a/web/src/lib/connectors/credentials.ts +++ b/web/src/lib/connectors/credentials.ts @@ -288,6 +288,7 @@ export const credentialTemplates: Record = { mediawiki: null, web: null, not_applicable: null, + ingestion_api: null, // NOTE: These are Special Cases google_drive: { google_drive_tokens: "" } as GoogleDriveCredentialJson, diff --git a/web/src/lib/credential.ts b/web/src/lib/credential.ts index 0552e73cc9e..03f6c6e75da 100644 --- a/web/src/lib/credential.ts +++ b/web/src/lib/credential.ts @@ -73,7 +73,7 @@ export function updateCredential(credentialId: number, newDetails: any) { ([key, value]) => key !== "name" && value !== "" ) ); - return fetch(`/api/manage/admin/credentials/${credentialId}`, { + return fetch(`/api/manage/admin/credential/${credentialId}`, { method: "PUT", headers: { "Content-Type": "application/json", @@ -86,7 +86,7 @@ export function 
updateCredential(credentialId: number, newDetails: any) { } export function swapCredential(newCredentialId: number, connectorId: number) { - return fetch(`/api/manage/admin/credentials/swap`, { + return fetch(`/api/manage/admin/credential/swap`, { method: "PUT", headers: { "Content-Type": "application/json", diff --git a/web/src/lib/search/interfaces.ts b/web/src/lib/search/interfaces.ts index b33879055cc..6983bd3367f 100644 --- a/web/src/lib/search/interfaces.ts +++ b/web/src/lib/search/interfaces.ts @@ -19,6 +19,15 @@ export interface AnswerPiecePacket { answer_piece: string; } +export enum StreamStopReason { + CONTEXT_LENGTH = "CONTEXT_LENGTH", + CANCELLED = "CANCELLED", +} + +export interface StreamStopInfo { + stop_reason: StreamStopReason; +} + export interface ErrorMessagePacket { error: string; } diff --git a/web/src/lib/search/streamingQa.ts b/web/src/lib/search/streamingQa.ts index 1f9b595c125..e30063e2fb9 100644 --- a/web/src/lib/search/streamingQa.ts +++ b/web/src/lib/search/streamingQa.ts @@ -98,7 +98,7 @@ export const searchRequestStreamed = async ({ } previousPartialChunk = partialChunk as string | null; completedChunks.forEach((chunk) => { - // check for answer peice / end of answer + // check for answer piece / end of answer if (Object.hasOwn(chunk, "relevance_summaries")) { const relevanceChunk = chunk as RelevanceChunk; diff --git a/web/src/lib/sources.ts b/web/src/lib/sources.ts index 83c5c174438..bbc63847adb 100644 --- a/web/src/lib/sources.ts +++ b/web/src/lib/sources.ts @@ -272,6 +272,11 @@ const SOURCE_METADATA_MAP: SourceMap = { category: SourceCategory.Storage, docs: "https://docs.danswer.dev/connectors/google_storage", }, + ingestion_api: { + icon: GlobeIcon, + displayName: "Ingestion", + category: SourceCategory.Other, + }, // currently used for the Internet Search tool docs, which is why // a globe is used not_applicable: { @@ -302,8 +307,12 @@ export function getSourceMetadata(sourceType: ValidSources): SourceMetadata { } export function listSourceMetadata(): SourceMetadata[] { + /* This gives back all the viewable / common sources, primarily for + display in the Add Connector page */ const entries = Object.entries(SOURCE_METADATA_MAP) - .filter(([source, _]) => source !== "not_applicable") + .filter( + ([source, _]) => source !== "not_applicable" && source != "ingestion_api" + ) .map(([source, metadata]) => { return fillSourceMetadata(metadata, source as ValidSources); }); diff --git a/web/src/lib/ss/ccPair.ts b/web/src/lib/ss/ccPair.ts deleted file mode 100644 index 847321d1103..00000000000 --- a/web/src/lib/ss/ccPair.ts +++ /dev/null @@ -1,5 +0,0 @@ -import { fetchSS } from "../utilsSS"; - -export async function getCCPairSS(ccPairId: number) { - return fetchSS(`/manage/admin/cc-pair/${ccPairId}`); -} diff --git a/web/src/lib/time.ts b/web/src/lib/time.ts index 114dbee4c31..418436b3f45 100644 --- a/web/src/lib/time.ts +++ b/web/src/lib/time.ts @@ -101,6 +101,7 @@ export function getSecondsUntilExpiration( if (!userInfo) { return null; } + const { oidc_expiry, current_token_created_at, current_token_expiry_length } = userInfo; diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 93370c1daa5..c178fa5992f 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -237,6 +237,12 @@ const validSources = [ "google_cloud_storage", "oci_storage", "not_applicable", -]; + "ingestion_api", +] as const; export type ValidSources = (typeof validSources)[number]; +// The valid sources that are actually valid to select in the UI +export type 
ConfigurableSources = Exclude< + ValidSources, + "not_applicable" | "ingestion_api" +>; diff --git a/web/src/middleware.ts b/web/src/middleware.ts index 714e70b4323..706e6ee0f4b 100644 --- a/web/src/middleware.ts +++ b/web/src/middleware.ts @@ -10,6 +10,7 @@ const eePaths = [ "/admin/whitelabeling", "/admin/performance/custom-analytics", ]; + const eePathsForMatcher = eePaths.map((path) => `${path}/:path*`); export async function middleware(request: NextRequest) { diff --git a/web/tailwind-themes/tailwind.config.js b/web/tailwind-themes/tailwind.config.js index f2d7601fd6f..1cac0c877a1 100644 --- a/web/tailwind-themes/tailwind.config.js +++ b/web/tailwind-themes/tailwind.config.js @@ -65,6 +65,7 @@ module.exports = { maxWidth: { "document-sidebar": "1000px", "message-max": "850px", + "content-max": "725px", "searchbar-max": "800px", }, colors: {
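
Note on the ConfigurableSources change in web/src/lib/types.ts above: the TypeScript sketch below is illustrative only and is not part of the diff. It shows how excluding "not_applicable" and "ingestion_api" from ValidSources lets connectorConfigs be keyed exhaustively. The union is trimmed to a few members for brevity, and getConnectorConfig is a hypothetical helper used only for demonstration.

// Illustrative sketch, not part of the changeset above.
type ValidSources =
  | "web"
  | "confluence"
  | "zendesk"
  | "not_applicable"
  | "ingestion_api"; // trimmed to a handful of members for brevity

// Sources that can actually be configured from the admin UI.
type ConfigurableSources = Exclude<
  ValidSources,
  "not_applicable" | "ingestion_api"
>;

interface ConnectionConfiguration {
  description: string;
  values: unknown[];
}

// Keying the record by ConfigurableSources means entries for "not_applicable"
// or "ingestion_api" are rejected, and a missing configurable source becomes a
// compile-time error rather than a runtime undefined lookup.
const connectorConfigs: Record<ConfigurableSources, ConnectionConfiguration> = {
  web: { description: "Configure Web connector", values: [] },
  confluence: { description: "Configure Confluence connector", values: [] },
  zendesk: { description: "Configure Zendesk connector", values: [] },
};

// Hypothetical helper: the lookup is total over ConfigurableSources.
function getConnectorConfig(source: ConfigurableSources): ConnectionConfiguration {
  return connectorConfigs[source];
}

With this keying, adding a new entry to validSources that should be user-configurable forces a corresponding entry in connectorConfigs before the code compiles, which is presumably the motivation for narrowing the key type in this change.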