From ab2da202a5adb4faab8f4c7c4aec48c3ce27497a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 16 May 2024 21:27:27 +0200 Subject: [PATCH 01/29] serve web UI with FastAPI / uvicorn (#414) --- .github/actions/setup-env/action.yml | 11 +- .github/workflows/test.yml | 4 +- environment-dev.yml | 2 +- pyproject.toml | 13 +- ragna/deploy/_api/core.py | 10 +- ragna/deploy/_ui/api_wrapper.py | 43 ------- ragna/deploy/_ui/app.py | 127 ++++++++++--------- ragna/deploy/_ui/auth_page.py | 91 ------------- ragna/deploy/_ui/central_view.py | 2 +- ragna/deploy/_ui/components/file_uploader.py | 5 +- ragna/deploy/_ui/css/auth/button.css | 5 - ragna/deploy/_ui/css/auth/column.css | 23 ---- ragna/deploy/_ui/css/auth/html.css | 24 ---- ragna/deploy/_ui/css/auth/textinput.css | 18 --- ragna/deploy/_ui/js_utils.py | 62 --------- ragna/deploy/_ui/logout_page.py | 21 --- ragna/deploy/_ui/modal_configuration.py | 1 - ragna/deploy/_ui/resources/upload.js | 20 +-- ragna/deploy/_ui/styles.py | 1 - tests/deploy/api/utils.py | 1 + tests/test_js_utils.py | 47 ------- 21 files changed, 105 insertions(+), 426 deletions(-) delete mode 100644 ragna/deploy/_ui/auth_page.py delete mode 100644 ragna/deploy/_ui/css/auth/button.css delete mode 100644 ragna/deploy/_ui/css/auth/column.css delete mode 100644 ragna/deploy/_ui/css/auth/html.css delete mode 100644 ragna/deploy/_ui/css/auth/textinput.css delete mode 100644 ragna/deploy/_ui/js_utils.py delete mode 100644 ragna/deploy/_ui/logout_page.py delete mode 100644 tests/test_js_utils.py diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index 7deb55c2..46b826af 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -18,7 +18,7 @@ runs: with: miniforge-variant: Mambaforge miniforge-version: latest - activate-environment: ragna-dev + activate-environment: ragna-deploy-dev - name: Display conda info shell: bash -el {0} @@ -58,6 +58,13 @@ runs: shell: bash -el {0} run: mamba 
install --yes --channel conda-forge redis-server + - name: Install dev dependencies + shell: bash -el {0} + run: | + pip install \ + git+https://github.com/bokeh/bokeh-fastapi.git@main \ + git+https://github.com/holoviz/panel.git@main + - name: Install ragna shell: bash -el {0} run: | @@ -67,7 +74,7 @@ runs: else PROJECT_PATH='.' fi - pip install --editable "${PROJECT_PATH}" + pip install --verbose --editable "${PROJECT_PATH}" - name: Display development environment shell: bash -el {0} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b74f3907..e8462d71 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,7 +38,9 @@ jobs: matrix: os: - ubuntu-latest - - windows-latest + # FIXME + # Building panel from source on Windows does not work through pip + # - windows-latest - macos-latest python-version: ["3.9"] include: diff --git a/environment-dev.yml b/environment-dev.yml index 2a7b6a03..ce5b1658 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -1,4 +1,4 @@ -name: ragna-dev +name: ragna-deploy-dev channels: - conda-forge dependencies: diff --git a/pyproject.toml b/pyproject.toml index 9ceb1569..f85a71a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [build-system] requires = [ - "setuptools>=45", - "setuptools_scm[toml]>=6.2", + "setuptools>=64", + "setuptools_scm[toml]>=8", ] build-backend = "setuptools.build_meta" @@ -23,11 +23,14 @@ requires-python = ">=3.9" dependencies = [ "aiofiles", "emoji", + "eval_type_backport; python_version<'3.10'", "fastapi", "httpx", "importlib_metadata>=4.6; python_version<'3.10'", "packaging", - "panel==1.4.2", + # FIXME: pin them to released versions + "bokeh-fastapi", + "panel", "pydantic>=2", "pydantic-core", "pydantic-settings>=2", @@ -142,6 +145,9 @@ disallow_incomplete_defs = false [[tool.mypy.overrides]] module = [ + # FIXME: the package should be typed + "bokeh_fastapi", + "bokeh_fastapi.handler", "docx", "fitz", "ijson", @@ -150,6 +156,7 @@ module 
= [ "pptx", "pyarrow", "sentence_transformers", + "traitlets", ] ignore_missing_imports = true diff --git a/ragna/deploy/_api/core.py b/ragna/deploy/_api/core.py index 5346b048..0f04802e 100644 --- a/ragna/deploy/_api/core.py +++ b/ragna/deploy/_api/core.py @@ -23,6 +23,7 @@ from ragna._utils import handle_localhost_origins from ragna.core import Assistant, Component, Rag, RagnaException, SourceStorage from ragna.core._rag import SpecialChatParams +from ragna.core._utils import default_user from ragna.deploy import Config from . import database, schemas @@ -98,13 +99,10 @@ async def ragna_exception_handler( async def version() -> str: return ragna.__version__ - authentication = config.authentication() + def get_user() -> str: + return default_user() - @app.post("/token") - async def create_token(request: Request) -> str: - return await authentication.create_token(request) - - UserDependency = Annotated[str, Depends(authentication.get_user)] + UserDependency = Annotated[str, Depends(get_user)] def _get_component_json_schema( component: Type[Component], diff --git a/ragna/deploy/_ui/api_wrapper.py b/ragna/deploy/_ui/api_wrapper.py index f96375de..a57f88e4 100644 --- a/ragna/deploy/_ui/api_wrapper.py +++ b/ragna/deploy/_ui/api_wrapper.py @@ -6,56 +6,13 @@ import param -class RagnaAuthTokenExpiredException(Exception): - """Just a wrapper around Exception""" - - pass - - # The goal is this class is to provide ready-to-use functions to interact with the API class ApiWrapper(param.Parameterized): - auth_token = param.String(default=None) - def __init__(self, api_url, **params): self.client = httpx.AsyncClient(base_url=api_url, timeout=60) super().__init__(**params) - try: - # If no auth token is provided, we use the API base URL and only test the API is up. - # else, we test the API is up *and* the token is valid. 
- endpoint = ( - api_url + "/components" if self.auth_token is not None else api_url - ) - httpx.get( - endpoint, headers={"Authorization": f"Bearer {self.auth_token}"} - ).raise_for_status() - - except httpx.HTTPStatusError as e: - # unauthorized - the token is invalid - if e.response.status_code == 401: - raise RagnaAuthTokenExpiredException("Unauthorized") - else: - raise e - - async def auth(self, username, password): - self.auth_token = ( - ( - await self.client.post( - "/token", - data={"username": username, "password": password}, - ) - ) - .raise_for_status() - .json() - ) - - return True - - @param.depends("auth_token", watch=True, on_init=True) - def update_auth_header(self): - self.client.headers["Authorization"] = f"Bearer {self.auth_token}" - async def get_chats(self): json_data = (await self.client.get("/chats")).raise_for_status().json() for chat in json_data: diff --git a/ragna/deploy/_ui/app.py b/ragna/deploy/_ui/app.py index ed418271..dd566e37 100644 --- a/ragna/deploy/_ui/app.py +++ b/ragna/deploy/_ui/app.py @@ -1,18 +1,17 @@ from pathlib import Path -from urllib.parse import urlsplit +from typing import cast import panel as pn import param +from fastapi import FastAPI +from fastapi.staticfiles import StaticFiles from ragna._utils import handle_localhost_origins from ragna.deploy import Config from . import js from . 
import styles as ui -from .api_wrapper import ApiWrapper, RagnaAuthTokenExpiredException -from .auth_page import AuthPage -from .js_utils import redirect_script -from .logout_page import LogoutPage +from .api_wrapper import ApiWrapper from .main_page import MainPage pn.extension( @@ -79,74 +78,76 @@ def get_template(self): return template def index_page(self): - if "auth_token" not in pn.state.cookies: - return redirect_script(remove="", append="auth") - - try: - api_wrapper = ApiWrapper( - api_url=self.api_url, auth_token=pn.state.cookies["auth_token"] - ) - except RagnaAuthTokenExpiredException: - # If the token has expired / is invalid, we redirect to the logout page. - # The logout page will delete the cookie and redirect to the auth page. - return redirect_script(remove="", append="logout") + api_wrapper = ApiWrapper(api_url=self.api_url) template = self.get_template() main_page = MainPage(api_wrapper=api_wrapper, template=template) template.main.append(main_page) return template - def auth_page(self): - # If the user is already authenticated, we receive the auth token in the cookie. - # in that case, redirect to the index page. - if "auth_token" in pn.state.cookies: - # Usually, we do a redirect this way : - # >>> pn.state.location.param.update(reload=True, pathname="/") - # But it only works once the page is fully loaded. - # So we render a javascript redirect instead. - return redirect_script(remove="auth") - - template = self.get_template() - auth_page = AuthPage(api_wrapper=ApiWrapper(api_url=self.api_url)) - template.main.append(auth_page) - return template - - def logout_page(self): - template = self.get_template() - logout_page = LogoutPage(api_wrapper=ApiWrapper(api_url=self.api_url)) - template.main.append(logout_page) - return template - def health_page(self): return pn.pane.HTML("

Ok

") + def add_panel_app(self, server, panel_app_fn): + # FIXME: this code will ultimately be distributed as part of panel + from functools import partial + + import panel as pn + from bokeh.application import Application + from bokeh.application.handlers.function import FunctionHandler + from bokeh_fastapi import BokehFastAPI + from bokeh_fastapi.handler import WSHandler + from fastapi.responses import FileResponse + from panel.io.document import extra_socket_handlers + from panel.io.resources import COMPONENT_PATH + from panel.io.server import ComponentResourceHandler + from panel.io.state import set_curdoc + + def dispatch_fastapi(conn, events=None, msg=None): + if msg is None: + msg = conn.protocol.create("PATCH-DOC", events) + return [conn._socket.send_message(msg)] + + extra_socket_handlers[WSHandler] = dispatch_fastapi + + def panel_app(doc): + doc.on_event("document_ready", partial(pn.state._schedule_on_load, doc)) + + with set_curdoc(doc): + panel_app = panel_app_fn() + panel_app.server_doc(doc) + + handler = FunctionHandler(panel_app) + application = Application(handler) + + BokehFastAPI(application, server=server) + + @server.get(f"/{COMPONENT_PATH.rstrip('/')}" + "/{path:path}") + def get_component_resource(path: str): + # ComponentResourceHandler.parse_url_path only ever accesses + # self._resource_attrs, which fortunately is a class attribute. 
Thus, we can + # get away with using the method without actually instantiating the class + self_ = cast(ComponentResourceHandler, ComponentResourceHandler) + resolved_path = ComponentResourceHandler.parse_url_path(self_, path) + return FileResponse(resolved_path) + + def make_app(self): + app = FastAPI() + self.add_panel_app(app, self.index_page) + + for dir in ["css", "imgs", "resources"]: + app.mount( + f"/{dir}", + StaticFiles(directory=str(Path(__file__).parent / dir)), + name=dir, + ) + + return app + def serve(self): - all_pages = { - "/": self.index_page, - "/auth": self.auth_page, - "/logout": self.logout_page, - "/health": self.health_page, - } - titles = {"/": "Home"} - - pn.serve( - all_pages, - titles=titles, - address=self.hostname, - port=self.port, - admin=True, - start=True, - location=True, - show=self.open_browser, - keep_alive=30 * 1000, # 30s - autoreload=True, - profiler="pyinstrument", - allow_websocket_origin=[urlsplit(origin).netloc for origin in self.origins], - static_dirs={ - dir: str(Path(__file__).parent / dir) - for dir in ["css", "imgs", "resources"] - }, - ) + import uvicorn + + uvicorn.run(self.make_app, factory=True, host=self.hostname, port=self.port) def app(*, config: Config, open_browser: bool) -> App: diff --git a/ragna/deploy/_ui/auth_page.py b/ragna/deploy/_ui/auth_page.py deleted file mode 100644 index 4df5098c..00000000 --- a/ragna/deploy/_ui/auth_page.py +++ /dev/null @@ -1,91 +0,0 @@ -import panel as pn -import param - - -class AuthPage(pn.viewable.Viewer, param.Parameterized): - feedback_message = param.String(default=None) - - custom_js = param.String(default="") - - def __init__(self, api_wrapper, **params): - super().__init__(**params) - self.api_wrapper = api_wrapper - - self.main_layout = None - - self.login_input = pn.widgets.TextInput( - name="Email", - css_classes=["auth_login_input"], - ) - self.password_input = pn.widgets.PasswordInput( - name="Password", - css_classes=["auth_password_input"], - ) - - async 
def perform_login(self, event=None): - self.main_layout.loading = True - - home_path = pn.state.location.pathname.rstrip("/").rstrip("auth") - try: - authed = await self.api_wrapper.auth( - self.login_input.value, self.password_input.value - ) - - if authed: - # Sets the cookie on the JS side - self.custom_js = f""" document.cookie = "auth_token={self.api_wrapper.auth_token}; path:{home_path}"; """ - - except Exception: - authed = False - - if authed: - # perform redirect - pn.state.location.param.update(reload=True, pathname=home_path) - else: - self.feedback_message = "Authentication failed. Please retry." - - self.main_layout.loading = False - - @pn.depends("feedback_message") - def display_error_message(self): - if self.feedback_message is None: - return None - else: - return pn.pane.HTML( - f"""
{self.feedback_message}
""", - css_classes=["auth_error"], - ) - - @pn.depends("custom_js") - def wrapped_custom_js(self): - return pn.pane.HTML( - f""" - Log In", - css_classes=["auth_title"], - ), - self.display_error_message, - self.login_input, - self.password_input, - pn.pane.HTML("
"), - login_button, - css_classes=["auth_page_main_layout"], - ) - - return self.main_layout diff --git a/ragna/deploy/_ui/central_view.py b/ragna/deploy/_ui/central_view.py index 173c6d55..ae0cefc3 100644 --- a/ragna/deploy/_ui/central_view.py +++ b/ragna/deploy/_ui/central_view.py @@ -153,7 +153,7 @@ def _update_placeholder(self): show_timestamp=False, ) - def _build_message(self, *args, **kwargs) -> RagnaChatMessage | None: + def _build_message(self, *args, **kwargs) -> Optional[RagnaChatMessage]: message = super()._build_message(*args, **kwargs) if message is None: return None diff --git a/ragna/deploy/_ui/components/file_uploader.py b/ragna/deploy/_ui/components/file_uploader.py index c4fcbaae..f25fe973 100644 --- a/ragna/deploy/_ui/components/file_uploader.py +++ b/ragna/deploy/_ui/components/file_uploader.py @@ -17,10 +17,9 @@ class FileUploader(ReactiveHTML, Widget): # type: ignore[misc] title = param.String(default="") - def __init__(self, allowed_documents, token, informations_endpoint, **params): + def __init__(self, allowed_documents, informations_endpoint, **params): super().__init__(**params) - self.token = token self.informations_endpoint = informations_endpoint self.after_upload_callback = None @@ -56,7 +55,7 @@ def perform_upload(self, event=None, after_upload_callback=None): self.custom_js = ( final_callback_js + random_id - + f"""upload( self.get_upload_files(), '{self.token}', '{self.informations_endpoint}', final_callback) """ + + f"""upload( self.get_upload_files(), '{self.informations_endpoint}', final_callback) """ ) _child_config = { diff --git a/ragna/deploy/_ui/css/auth/button.css b/ragna/deploy/_ui/css/auth/button.css deleted file mode 100644 index e8a6ad3d..00000000 --- a/ragna/deploy/_ui/css/auth/button.css +++ /dev/null @@ -1,5 +0,0 @@ -:host(.auth_login_button) { - width: 100%; - margin-left: 0px; - margin-right: 0px; -} diff --git a/ragna/deploy/_ui/css/auth/column.css b/ragna/deploy/_ui/css/auth/column.css deleted file mode 100644 
index d7bb60aa..00000000 --- a/ragna/deploy/_ui/css/auth/column.css +++ /dev/null @@ -1,23 +0,0 @@ -:host(.auth_page_main_layout) { - background-color: white; - border-radius: 5px; - box-shadow: lightgray 0px 0px 10px; - padding: 0 25px 0 25px; - - width: 30%; - min-width: 360px; - max-width: 430px; - - margin-left: auto; - margin-right: auto; - margin-top: 10%; -} - -:host(.auth_page_main_layout) > div { - margin-bottom: 10px; - margin-top: 10px; -} - -:host(.auth_page_main_layout) .bk-panel-models-layout-Column { - width: 100%; -} diff --git a/ragna/deploy/_ui/css/auth/html.css b/ragna/deploy/_ui/css/auth/html.css deleted file mode 100644 index 16c4490f..00000000 --- a/ragna/deploy/_ui/css/auth/html.css +++ /dev/null @@ -1,24 +0,0 @@ -:host(.auth_error) { - width: 100%; - margin-left: 0px; - margin-right: 0px; -} - -:host(.auth_error) div.auth_error { - width: 100%; - color: red; - text-align: center; - font-weight: 600; - font-size: 16px; -} - -:host(.auth_title) { - width: 100%; - margin-left: 0px; - margin-right: 0px; - text-align: center; -} -:host(.auth_title) h1 { - font-weight: 600; - font-size: 24px; -} diff --git a/ragna/deploy/_ui/css/auth/textinput.css b/ragna/deploy/_ui/css/auth/textinput.css deleted file mode 100644 index b6c16ce9..00000000 --- a/ragna/deploy/_ui/css/auth/textinput.css +++ /dev/null @@ -1,18 +0,0 @@ -:host(.auth_login_input), -:host(.auth_password_input) { - width: 100%; - margin-left: 0px; - margin-right: 0px; -} - -:host(.auth_login_input) label, -:host(.auth_password_input) label { - font-weight: 600; - font-size: 16px; -} - -:host(.auth_login_input) input, -:host(.auth_password_input) input { - background-color: white !important; - height: 38px; -} diff --git a/ragna/deploy/_ui/js_utils.py b/ragna/deploy/_ui/js_utils.py deleted file mode 100644 index 4497de2d..00000000 --- a/ragna/deploy/_ui/js_utils.py +++ /dev/null @@ -1,62 +0,0 @@ -import panel as pn - - -def preformat(text): - """Allows {{key}} to be used for formatting in 
textcthat already uses - curly braces. First switch this into something else, replace curlies - with double curlies, and then switch back to regular braces - """ - text = text.replace("{{", "<<<").replace("}}", ">>>") - text = text.replace("{", "{{").replace("}", "}}") - text = text.replace("<<<", "{").replace(">>>", "}") - return text - - -def redirect_script(remove, append="/", remove_auth_cookie=False): - """ - This function returns a js script to redirect to correct url. - :param remove: string to remove from the end of the url - :param append: string to append at the end of the url - :param remove_auth_cookie: boolean, will clear auth_token cookie when true. - :return: string javascript script - - Examples: - ========= - - # This will remove nothing from the end of the url and will - # add auth to it, so /foo/bar/car/ becomes /foo/bar/car/auth - >>> redirect_script(remove="", append="auth") - - # This will remove nothing from the end of the url and will - # add auth to it, so /foo/bar/car/ becomes /foo/bar/car/logout - >>> redirect_script(remove="", append="logout") - - # This will remove "auth" from the end of the url and will add / to it - # so /foo/bar/car/auth becomes /foo/bar/car/ - >>> redirect_script(remove="auth", append="/") - """ - js_script = preformat( - r""" - - """ - ) - - return pn.pane.HTML( - js_script.format( - remove=remove, - append=append, - remove_auth_cookie=str(remove_auth_cookie).lower(), - ) - ) diff --git a/ragna/deploy/_ui/logout_page.py b/ragna/deploy/_ui/logout_page.py deleted file mode 100644 index d86b77ab..00000000 --- a/ragna/deploy/_ui/logout_page.py +++ /dev/null @@ -1,21 +0,0 @@ -import panel as pn -import param - -from ragna.deploy._ui.js_utils import redirect_script - - -class LogoutPage(pn.viewable.Viewer, param.Parameterized): - def __init__(self, api_wrapper, **params): - super().__init__(**params) - self.api_wrapper = api_wrapper - - self.api_wrapper.auth_token = None - - def __panel__(self): - # Usually, we do a 
redirect this way : - # >>> pn.state.location.param.update(reload=True, pathname="/") - # But it only works once the page is fully loaded. - # So we render a javascript redirect instead. - - # To remove the token from the cookie, we have to force its expiry date to the past. - return redirect_script(remove="logout", append="/", remove_auth_cookie=True) diff --git a/ragna/deploy/_ui/modal_configuration.py b/ragna/deploy/_ui/modal_configuration.py index 70b02731..51ec02fb 100644 --- a/ragna/deploy/_ui/modal_configuration.py +++ b/ragna/deploy/_ui/modal_configuration.py @@ -89,7 +89,6 @@ def __init__(self, api_wrapper, **params): ) self.document_uploader = FileUploader( [], # the allowed documents are set in the model_section function - self.api_wrapper.auth_token, upload_endpoints["informations_endpoint"], ) diff --git a/ragna/deploy/_ui/resources/upload.js b/ragna/deploy/_ui/resources/upload.js index 1ecd54b2..905da833 100644 --- a/ragna/deploy/_ui/resources/upload.js +++ b/ragna/deploy/_ui/resources/upload.js @@ -1,8 +1,8 @@ -function upload(files, token, informationEndpoint, final_callback) { - uploadBatches(files, token, informationEndpoint).then(final_callback); +function upload(files, informationEndpoint, final_callback) { + uploadBatches(files, informationEndpoint).then(final_callback); } -async function uploadBatches(files, token, informationEndpoint) { +async function uploadBatches(files, informationEndpoint) { const batchSize = 500; const queue = Array.from(files); @@ -10,20 +10,20 @@ async function uploadBatches(files, token, informationEndpoint) { while (queue.length) { const batch = queue.splice(0, batchSize); - await Promise.all( - batch.map((file) => uploadFile(file, token, informationEndpoint)), - ).then((results) => { - uploaded.push(...results); - }); + await Promise.all(batch.map((file) => uploadFile(file, informationEndpoint))).then( + (results) => { + uploaded.push(...results); + }, + ); } return uploaded; } -async function uploadFile(file, 
token, informationEndpoint) { +async function uploadFile(file, informationEndpoint) { const response = await fetch(informationEndpoint, { method: "POST", - headers: { "Content-Type": "application/json", Authorization: `Bearer ${token}` }, + headers: { "Content-Type": "application/json" }, body: JSON.stringify({ name: file.name }), }); const documentUpload = await response.json(); diff --git a/ragna/deploy/_ui/styles.py b/ragna/deploy/_ui/styles.py index 7f994eeb..b70894e5 100644 --- a/ragna/deploy/_ui/styles.py +++ b/ragna/deploy/_ui/styles.py @@ -34,7 +34,6 @@ pn.pane.Markdown, ], "chat_info": [pn.pane.Markdown, pn.widgets.Button], - "auth": [pn.widgets.TextInput, pn.pane.HTML, pn.widgets.Button, pn.Column], "central_view": [pn.Column, pn.Row, pn.pane.HTML], "chat_interface": [ pn.widgets.TextInput, diff --git a/tests/deploy/api/utils.py b/tests/deploy/api/utils.py index abcf1411..fee4feff 100644 --- a/tests/deploy/api/utils.py +++ b/tests/deploy/api/utils.py @@ -6,6 +6,7 @@ def authenticate(client: TestClient) -> None: + return username = default_user() token = ( client.post( diff --git a/tests/test_js_utils.py b/tests/test_js_utils.py deleted file mode 100644 index 202af9e7..00000000 --- a/tests/test_js_utils.py +++ /dev/null @@ -1,47 +0,0 @@ -from textwrap import dedent - -from ragna.deploy._ui.js_utils import preformat, redirect_script - - -def test_preformat_basic(): - output = preformat("{ This is awesome {{var}} }") - assert output == "{{ This is awesome {var} }}" - - -def test_preformat_basic_fmt(): - output = preformat("{ This is awesome {{var}} }").format(var="test") - assert output == "{ This is awesome test }" - - -def test_preformat_multivars(): - output = preformat("{ {{var1}} This is awesome {{var2}} }").format( - var1="test1", var2="test2" - ) - assert output == "{ test1 This is awesome test2 }" - - -def test_preformat_unsubs(): - output = preformat("{ This is {Hello} awesome {{var}} }").format(var="test") - assert output == "{ This is {Hello} 
awesome test }" - - -def test_redirect_script(): - output = redirect_script(remove="foo", append="bar") - expected = dedent( - r""" - - """ - ) - assert dedent(output.object) == expected From 528b953ff962601b881b65401770d6b997bae125 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 12 Jun 2024 10:36:37 +0200 Subject: [PATCH 02/29] unify API and UI servers (#418) --- .github/actions/setup-env/action.yml | 2 +- ragna/__main__.py | 2 +- ragna/{deploy => }/_cli/__init__.py | 0 ragna/{deploy => }/_cli/config.py | 58 +++---- ragna/_cli/core.py | 124 +++++++++++++ ragna/_utils.py | 34 +--- ragna/core/_document.py | 2 +- ragna/deploy/_api/__init__.py | 2 +- ragna/deploy/_api/core.py | 65 ++----- ragna/deploy/_cli/core.py | 173 ------------------- ragna/deploy/_config.py | 110 +++++------- ragna/deploy/_core.py | 76 ++++++++ ragna/deploy/_ui/api_wrapper.py | 3 +- ragna/deploy/_ui/app.py | 37 ++-- ragna/deploy/_ui/components/file_uploader.py | 3 +- ragna/deploy/_utils.py | 57 ++++++ tests/deploy/api/test_batch_endpoints.py | 12 +- tests/deploy/api/test_components.py | 17 +- tests/deploy/api/test_e2e.py | 35 ++-- tests/deploy/api/utils.py | 10 ++ tests/deploy/test_config.py | 27 +-- 21 files changed, 404 insertions(+), 445 deletions(-) rename ragna/{deploy => }/_cli/__init__.py (100%) rename ragna/{deploy => }/_cli/config.py (88%) create mode 100644 ragna/_cli/core.py delete mode 100644 ragna/deploy/_cli/core.py create mode 100644 ragna/deploy/_core.py create mode 100644 ragna/deploy/_utils.py diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index 46b826af..d0392096 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -63,7 +63,7 @@ runs: run: | pip install \ git+https://github.com/bokeh/bokeh-fastapi.git@main \ - git+https://github.com/holoviz/panel.git@main + git+https://github.com/holoviz/panel@7377c9e99bef0d32cbc65e94e908e365211f4421 - name: Install ragna shell: bash -el {0} diff --git 
a/ragna/__main__.py b/ragna/__main__.py index 1435bf2a..4acdf3a3 100644 --- a/ragna/__main__.py +++ b/ragna/__main__.py @@ -1,4 +1,4 @@ -from ragna.deploy._cli import app +from ragna._cli import app if __name__ == "__main__": app() diff --git a/ragna/deploy/_cli/__init__.py b/ragna/_cli/__init__.py similarity index 100% rename from ragna/deploy/_cli/__init__.py rename to ragna/_cli/__init__.py diff --git a/ragna/deploy/_cli/config.py b/ragna/_cli/config.py similarity index 88% rename from ragna/deploy/_cli/config.py rename to ragna/_cli/config.py index 621d1ef9..73044ed9 100644 --- a/ragna/deploy/_cli/config.py +++ b/ragna/_cli/config.py @@ -197,7 +197,7 @@ def _handle_unmet_requirements(components: Iterable[Type[Component]]) -> None: return rich.print( - "You have selected components, which have additional requirements that are" + "You have selected components, which have additional requirements that are " "currently not met." ) unmet_requirements_by_type = _split_requirements(unmet_requirements) @@ -251,51 +251,37 @@ def _wizard_common() -> Config: ).unsafe_ask() ) - for sub_config, title in [(config.api, "REST API"), (config.ui, "web UI")]: - sub_config.hostname = questionary.text( # type: ignore[attr-defined] - f"What hostname do you want to bind the the Ragna {title} to?", - default=sub_config.hostname, # type: ignore[attr-defined] - qmark=QMARK, - ).unsafe_ask() - - sub_config.port = int( # type: ignore[attr-defined] - questionary.text( - f"What port do you want to bind the the Ragna {title} to?", - default=str(sub_config.port), # type: ignore[attr-defined] - qmark=QMARK, - ).unsafe_ask() - ) - - config.api.database_url = questionary.text( - "What is the URL of the SQL database?", - default=Config(local_root=config.local_root).api.database_url, + config.hostname = questionary.text( + "What hostname do you want to bind the the Ragna server to?", + default=config.hostname, qmark=QMARK, ).unsafe_ask() - config.api.url = questionary.text( - "At which URL will the 
Ragna REST API be served?", - default=Config( - api=dict( # type: ignore[arg-type] - hostname=config.api.hostname, - port=config.api.port, - ) - ).api.url, - qmark=QMARK, - ).unsafe_ask() + config.port = int( + questionary.text( + "What port do you want to bind the the Ragna server to?", + default=str(config.port), + qmark=QMARK, + ).unsafe_ask() + ) - config.api.origins = config.ui.origins = [ + config.origins = [ questionary.text( - "At which URL will the Ragna web UI be served?", + "At which URL will Ragna be served?", default=Config( - ui=dict( # type: ignore[arg-type] - hostname=config.ui.hostname, - port=config.ui.port, - ) - ).api.origins[0], + hostname=config.hostname, + port=config.port, + ).origins[0], qmark=QMARK, ).unsafe_ask() ] + config.database_url = questionary.text( + "What is the URL of the SQL database?", + default=Config(local_root=config.local_root).database_url, + qmark=QMARK, + ).unsafe_ask() + return config diff --git a/ragna/_cli/core.py b/ragna/_cli/core.py new file mode 100644 index 00000000..6961b87b --- /dev/null +++ b/ragna/_cli/core.py @@ -0,0 +1,124 @@ +from pathlib import Path +from typing import Annotated, Optional + +import httpx +import rich +import typer +import uvicorn + +import ragna +from ragna.deploy._core import make_app + +from .config import ConfigOption, check_config, init_config + +app = typer.Typer( + name="Ragna", + invoke_without_command=True, + no_args_is_help=True, + add_completion=False, + pretty_exceptions_enable=False, +) + + +def version_callback(value: bool) -> None: + if value: + rich.print(f"ragna {ragna.__version__} from {ragna.__path__[0]}") + raise typer.Exit() + + +@app.callback() +def _main( + version: Annotated[ + Optional[bool], + typer.Option( + "--version", callback=version_callback, help="Show version and exit." 
+ ), + ] = None, +) -> None: + pass + + +@app.command(help="Start a wizard to build a Ragna configuration interactively.") +def init( + *, + output_path: Annotated[ + Path, + typer.Option( + "-o", + "--output-file", + metavar="OUTPUT_PATH", + default_factory=lambda: Path.cwd() / "ragna.toml", + show_default="./ragna.toml", + help="Write configuration to .", + ), + ], + force: Annotated[ + bool, + typer.Option( + "-f", "--force", help="Overwrite an existing file at ." + ), + ] = False, +) -> None: + config, output_path, force = init_config(output_path=output_path, force=force) + config.to_file(output_path, force=force) + + +@app.command(help="Check the availability of components.") +def check(config: ConfigOption = "./ragna.toml") -> None: # type: ignore[assignment] + is_available = check_config(config) + raise typer.Exit(int(not is_available)) + + +@app.command(help="Deploy Ragna REST API and web UI.") +def deploy( + *, + config: ConfigOption = "./ragna.toml", # type: ignore[assignment] + api: Annotated[ + Optional[bool], + typer.Option( + "--api/--no-api", + help="Deploy the Ragna REST API.", + show_default="True if UI is not deployed and otherwise check availability", + ), + ] = None, + ui: Annotated[ + bool, + typer.Option( + help="Deploy the Ragna web UI.", + ), + ] = True, + ignore_unavailable_components: Annotated[ + bool, + typer.Option( + help=( + "Ignore components that are not available, " + "i.e. their requirements are not met. 
" + ) + ), + ] = False, +) -> None: + def api_available() -> bool: + try: + return httpx.get(f"{config._url}/health").is_success + except httpx.ConnectError: + return False + + if api is None: + api = not api_available() if ui else True + + if not (api or ui): + raise Exception + elif ui and not api and not api_available(): + raise Exception + + uvicorn.run( + lambda: make_app( + config, + ui=ui, + api=api, + ignore_unavailable_components=ignore_unavailable_components, + ), + factory=True, + host=config.hostname, + port=config.port, + ) diff --git a/ragna/_utils.py b/ragna/_utils.py index a71d11a5..efa6f95c 100644 --- a/ragna/_utils.py +++ b/ragna/_utils.py @@ -5,7 +5,6 @@ import threading from pathlib import Path from typing import Any, Callable, Optional, Union -from urllib.parse import SplitResult, urlsplit, urlunsplit _LOCAL_ROOT = ( Path(os.environ.get("RAGNA_LOCAL_ROOT", "~/.cache/ragna")).expanduser().resolve() @@ -28,7 +27,7 @@ def local_root(path: Optional[Union[str, Path]] = None) -> Path: path: If passed, this is set as new local root directory. Returns: - Ragnas local root directory. + Ragna's local root directory. """ global _LOCAL_ROOT if path is not None: @@ -59,37 +58,6 @@ def fix_module(globals: dict[str, Any]) -> None: obj.__module__ = globals["__package__"] -def _replace_hostname(split_result: SplitResult, hostname: str) -> SplitResult: - # This is a separate function, since hostname is not an element of the SplitResult - # namedtuple, but only a property. Thus, we need to replace the netloc item, from - # which the hostname is generated. - if split_result.port is None: - netloc = hostname - else: - netloc = f"{hostname}:{split_result.port}" - return split_result._replace(netloc=netloc) - - -def handle_localhost_origins(origins: list[str]) -> list[str]: - # Since localhost is an alias for 127.0.0.1, we allow both so users and developers - # don't need to worry about it. 
- localhost_origins = { - components.hostname: components - for url in origins - if (components := urlsplit(url)).hostname in {"127.0.0.1", "localhost"} - } - if "127.0.0.1" in localhost_origins: - origins.append( - urlunsplit(_replace_hostname(localhost_origins["127.0.0.1"], "localhost")) - ) - elif "localhost" in localhost_origins: - origins.append( - urlunsplit(_replace_hostname(localhost_origins["localhost"], "127.0.0.1")) - ) - - return origins - - def timeout_after( seconds: float = 30, *, message: str = "" ) -> Callable[[Callable], Callable]: diff --git a/ragna/core/_document.py b/ragna/core/_document.py index 878742eb..436344b4 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -142,7 +142,7 @@ def read(self) -> bytes: async def get_upload_info( cls, *, config: Config, user: str, id: uuid.UUID, name: str ) -> tuple[dict[str, Any], DocumentUploadParameters]: - url = f"{config.api.url}/document" + url = f"{config._url}/api/document" data = { "token": jwt.encode( payload={ diff --git a/ragna/deploy/_api/__init__.py b/ragna/deploy/_api/__init__.py index 93eefb4d..f99fb828 100644 --- a/ragna/deploy/_api/__init__.py +++ b/ragna/deploy/_api/__init__.py @@ -1 +1 @@ -from .core import app +from .core import make_router diff --git a/ragna/deploy/_api/core.py b/ragna/deploy/_api/core.py index 0f04802e..9bb9c682 100644 --- a/ragna/deploy/_api/core.py +++ b/ragna/deploy/_api/core.py @@ -4,23 +4,20 @@ import aiofiles from fastapi import ( + APIRouter, Body, Depends, - FastAPI, Form, HTTPException, - Request, UploadFile, status, ) -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.responses import StreamingResponse from pydantic import BaseModel import ragna import ragna.core from ragna._compat import aiter, anext -from ragna._utils import handle_localhost_origins from ragna.core import Assistant, Component, Rag, RagnaException, SourceStorage from ragna.core._rag import 
SpecialChatParams from ragna.core._utils import default_user @@ -29,8 +26,8 @@ from . import database, schemas -def app(*, config: Config, ignore_unavailable_components: bool) -> FastAPI: - ragna.local_root(config.local_root) +def make_router(config: Config, ignore_unavailable_components: bool) -> APIRouter: + router = APIRouter(tags=["API"]) rag = Rag() # type: ignore[var-annotated] components_map: dict[str, Component] = {} @@ -67,35 +64,7 @@ def get_component(display_name: str) -> Component: return component - app = FastAPI( - title="ragna", - version=ragna.__version__, - root_path=config.api.root_path, - ) - app.add_middleware( - CORSMiddleware, - allow_origins=handle_localhost_origins(config.api.origins), - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - @app.exception_handler(RagnaException) - async def ragna_exception_handler( - request: Request, exc: RagnaException - ) -> JSONResponse: - if exc.http_detail is RagnaException.EVENT: - detail = exc.event - elif exc.http_detail is RagnaException.MESSAGE: - detail = str(exc) - else: - detail = cast(str, exc.http_detail) - return JSONResponse( - status_code=exc.http_status_code, - content={"error": {"message": detail}}, - ) - - @app.get("/") + @router.get("/") async def version() -> str: return ragna.__version__ @@ -120,7 +89,7 @@ def _get_component_json_schema( json_schema["required"].remove(special_param) return json_schema - @app.get("/components") + @router.get("/components") async def get_components(_: UserDependency) -> schemas.Components: return schemas.Components( documents=sorted(config.document.supported_suffixes()), @@ -136,14 +105,14 @@ async def get_components(_: UserDependency) -> schemas.Components: ], ) - make_session = database.get_sessionmaker(config.api.database_url) + make_session = database.get_sessionmaker(config.database_url) @contextlib.contextmanager def get_session() -> Iterator[database.Session]: with make_session() as session: # type: ignore[attr-defined] 
yield session - @app.post("/document") + @router.post("/document") async def create_document_upload_info( user: UserDependency, name: Annotated[str, Body(..., embed=True)], @@ -159,7 +128,7 @@ async def create_document_upload_info( return schemas.DocumentUpload(parameters=parameters, document=document) # TODO: Add UI support and documentation for this endpoint (#406) - @app.post("/documents") + @router.post("/documents") async def create_documents_upload_info( user: UserDependency, names: Annotated[list[str], Body(..., embed=True)], @@ -185,7 +154,7 @@ async def create_documents_upload_info( return document_upload_collection # TODO: Add new endpoint for batch uploading documents (#407) - @app.put("/document") + @router.put("/document") async def upload_document( token: Annotated[str, Form()], file: UploadFile ) -> schemas.Document: @@ -240,7 +209,7 @@ def schema_to_core_chat( return core_chat - @app.post("/chats") + @router.post("/chats") async def create_chat( user: UserDependency, chat_metadata: schemas.ChatMetadata, @@ -255,17 +224,17 @@ async def create_chat( database.add_chat(session, user=user, chat=chat) return chat - @app.get("/chats") + @router.get("/chats") async def get_chats(user: UserDependency) -> list[schemas.Chat]: with get_session() as session: return database.get_chats(session, user=user) - @app.get("/chats/{id}") + @router.get("/chats/{id}") async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat: with get_session() as session: return database.get_chat(session, user=user, id=id) - @app.post("/chats/{id}/prepare") + @router.post("/chats/{id}/prepare") async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message: with get_session() as session: chat = database.get_chat(session, user=user, id=id) @@ -280,7 +249,7 @@ async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message: return welcome - @app.post("/chats/{id}/answer") + @router.post("/chats/{id}/answer") async def answer( user: 
UserDependency, id: uuid.UUID, @@ -341,9 +310,9 @@ async def to_jsonl(models: AsyncIterator[Any]) -> AsyncIterator[str]: return answer - @app.delete("/chats/{id}") + @router.delete("/chats/{id}") async def delete_chat(user: UserDependency, id: uuid.UUID) -> None: with get_session() as session: database.delete_chat(session, user=user, id=id) - return app + return router diff --git a/ragna/deploy/_cli/core.py b/ragna/deploy/_cli/core.py deleted file mode 100644 index fa493c21..00000000 --- a/ragna/deploy/_cli/core.py +++ /dev/null @@ -1,173 +0,0 @@ -import subprocess -import sys -import time -from pathlib import Path -from typing import Annotated, Optional - -import httpx -import rich -import typer -import uvicorn - -import ragna -from ragna._utils import timeout_after -from ragna.deploy._api import app as api_app -from ragna.deploy._ui import app as ui_app - -from .config import ConfigOption, check_config, init_config - -app = typer.Typer( - name="ragna", - invoke_without_command=True, - no_args_is_help=True, - add_completion=False, - pretty_exceptions_enable=False, -) - - -def version_callback(value: bool) -> None: - if value: - rich.print(f"ragna {ragna.__version__} from {ragna.__path__[0]}") - raise typer.Exit() - - -@app.callback() -def _main( - version: Annotated[ - Optional[bool], - typer.Option( - "--version", callback=version_callback, help="Show version and exit." - ), - ] = None, -) -> None: - pass - - -@app.command(help="Start a wizard to build a Ragna configuration interactively.") -def init( - *, - output_path: Annotated[ - Path, - typer.Option( - "-o", - "--output-file", - metavar="OUTPUT_PATH", - default_factory=lambda: Path.cwd() / "ragna.toml", - show_default="./ragna.toml", - help="Write configuration to .", - ), - ], - force: Annotated[ - bool, - typer.Option( - "-f", "--force", help="Overwrite an existing file at ." 
- ), - ] = False, -) -> None: - config, output_path, force = init_config(output_path=output_path, force=force) - config.to_file(output_path, force=force) - - -@app.command(help="Check the availability of components.") -def check(config: ConfigOption = "./ragna.toml") -> None: # type: ignore[assignment] - is_available = check_config(config) - raise typer.Exit(int(not is_available)) - - -@app.command(help="Start the REST API.") -def api( - *, - config: ConfigOption = "./ragna.toml", # type: ignore[assignment] - ignore_unavailable_components: Annotated[ - bool, - typer.Option( - help=( - "Ignore components that are not available, " - "i.e. their requirements are not met. " - ) - ), - ] = False, -) -> None: - uvicorn.run( - api_app( - config=config, ignore_unavailable_components=ignore_unavailable_components - ), - host=config.api.hostname, - port=config.api.port, - ) - - -@app.command(help="Start the web UI.") -def ui( - *, - config: ConfigOption = "./ragna.toml", # type: ignore[assignment] - start_api: Annotated[ - Optional[bool], - typer.Option( - help="Start the ragna REST API alongside the web UI in a subprocess.", - show_default="Start if the API is not served at the configured URL.", - ), - ] = None, - ignore_unavailable_components: Annotated[ - bool, - typer.Option( - help=( - "Ignore components that are not available, " - "i.e. their requirements are not met. " - "This option as no effect if --no-start-api is used." 
- ) - ), - ] = False, - open_browser: Annotated[ - bool, - typer.Option(help="Open the web UI in the browser when it is started."), - ] = True, -) -> None: - def check_api_available() -> bool: - try: - return httpx.get(config.api.url).is_success - except httpx.ConnectError: - return False - - if start_api is None: - start_api = not check_api_available() - - if start_api: - process = subprocess.Popen( - [ - sys.executable, - "-m", - "ragna", - "api", - "--config", - config.__ragna_cli_config_path__, # type: ignore[attr-defined] - f"--{'' if ignore_unavailable_components else 'no-'}ignore-unavailable-components", - ], - stdout=sys.stdout, - stderr=sys.stderr, - ) - else: - process = None - - try: - if process is not None: - - @timeout_after(60) - def wait_for_api() -> None: - while not check_api_available(): - time.sleep(0.5) - - try: - wait_for_api() - except TimeoutError: - rich.print( - "Failed to start the API in 60 seconds. " - "Please start it manually with [bold]ragna api[/bold]." - ) - raise typer.Exit(1) - - ui_app(config=config, open_browser=open_browser).serve() # type: ignore[no-untyped-call] - finally: - if process is not None: - process.kill() - process.communicate() diff --git a/ragna/deploy/_config.py b/ragna/deploy/_config.py index fa7a01f9..e960f831 100644 --- a/ragna/deploy/_config.py +++ b/ragna/deploy/_config.py @@ -7,12 +7,7 @@ import tomlkit import tomlkit.container import tomlkit.items -from pydantic import ( - AfterValidator, - Field, - ImportString, - model_validator, -) +from pydantic import AfterValidator, Field, ImportString, model_validator from pydantic_settings import ( BaseSettings, PydanticBaseSettingsSource, @@ -52,7 +47,11 @@ def make(cls, make_default: Callable[[Config], T]) -> Any: return Field(default=cls(make_default), validate_default=False) -class ConfigBase(BaseSettings): +class Config(BaseSettings): + """Ragna configuration""" + + model_config = SettingsConfigDict(env_prefix="ragna_") + @classmethod def 
settings_customise_sources( cls, @@ -76,14 +75,48 @@ def settings_customise_sources( # 4. Default return env_settings, init_settings + local_root: Annotated[Path, AfterValidator(make_directory)] = Field( + default_factory=ragna.local_root + ) + + authentication: ImportString[type[Authentication]] = ( + "ragna.deploy.RagnaDemoAuthentication" # type: ignore[assignment] + ) + + document: ImportString[type[Document]] = "ragna.core.LocalDocument" # type: ignore[assignment] + source_storages: list[ImportString[type[SourceStorage]]] = [ + "ragna.source_storages.RagnaDemoSourceStorage" # type: ignore[list-item] + ] + assistants: list[ImportString[type[Assistant]]] = [ + "ragna.assistants.RagnaDemoAssistant" # type: ignore[list-item] + ] + + hostname: str = "127.0.0.1" + port: int = 31476 + root_path: str = "" + origins: list[str] = AfterConfigValidateDefault.make( + lambda config: [f"http://{config.hostname}:{config.port}"] + ) + + database_url: str = AfterConfigValidateDefault.make( + lambda config: f"sqlite:///{config.local_root}/ragna.db", + ) + + @model_validator(mode="after") + def _validate_model(self) -> Config: + self._resolve_default_sentinels(self) + return self + def _resolve_default_sentinels(self, config: Config) -> None: for name, info in self.model_fields.items(): value = getattr(self, name) - if isinstance(value, ConfigBase): - value._resolve_default_sentinels(config) - elif isinstance(value, AfterConfigValidateDefault): + if isinstance(value, AfterConfigValidateDefault): setattr(self, name, value.make_default(config)) + @property + def _url(self) -> str: + return f"http://{self.hostname}:{self.port}{self.root_path}" + def __str__(self) -> str: toml = tomlkit.item(self.model_dump(mode="json")) self._set_multiline_array(toml) @@ -102,63 +135,6 @@ def _set_multiline_array(self, item: tomlkit.items.Item) -> None: ): self._set_multiline_array(child) - -def make_default_origins(config: Config) -> list[str]: - return 
[f"http://{config.ui.hostname}:{config.ui.port}"] - - -class ApiConfig(ConfigBase): - model_config = SettingsConfigDict(env_prefix="ragna_api_") - - hostname: str = "127.0.0.1" - port: int = 31476 - root_path: str = "" - url: str = AfterConfigValidateDefault.make( - lambda config: f"http://{config.api.hostname}:{config.api.port}{config.api.root_path}", - ) - database_url: str = AfterConfigValidateDefault.make( - lambda config: f"sqlite:///{config.local_root}/ragna.db", - ) - origins: list[str] = AfterConfigValidateDefault.make(make_default_origins) - - -class UiConfig(ConfigBase): - model_config = SettingsConfigDict(env_prefix="ragna_ui_") - - hostname: str = "127.0.0.1" - port: int = 31477 - origins: list[str] = AfterConfigValidateDefault.make(make_default_origins) - - -class Config(ConfigBase): - """Ragna configuration""" - - model_config = SettingsConfigDict(env_prefix="ragna_") - - local_root: Annotated[Path, AfterValidator(make_directory)] = Field( - default_factory=ragna.local_root - ) - - authentication: ImportString[type[Authentication]] = ( - "ragna.deploy.RagnaDemoAuthentication" # type: ignore[assignment] - ) - - document: ImportString[type[Document]] = "ragna.core.LocalDocument" # type: ignore[assignment] - source_storages: list[ImportString[type[SourceStorage]]] = [ - "ragna.source_storages.RagnaDemoSourceStorage" # type: ignore[list-item] - ] - assistants: list[ImportString[type[Assistant]]] = [ - "ragna.assistants.RagnaDemoAssistant" # type: ignore[list-item] - ] - - api: ApiConfig = Field(default_factory=ApiConfig) - ui: UiConfig = Field(default_factory=UiConfig) - - @model_validator(mode="after") - def _validate_model(self) -> Config: - self._resolve_default_sentinels(self) - return self - @classmethod def from_file(cls, path: Union[str, Path]) -> Config: path = Path(path).expanduser().resolve() diff --git a/ragna/deploy/_core.py b/ragna/deploy/_core.py new file mode 100644 index 00000000..fa516e13 --- /dev/null +++ b/ragna/deploy/_core.py @@ -0,0 
+1,76 @@ +from typing import cast + +from fastapi import FastAPI, Request, status +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response + +import ragna +from ragna.core import RagnaException + +from ._api import make_router as make_api_router +from ._config import Config +from ._ui import app as make_ui_app +from ._utils import handle_localhost_origins, redirect, set_redirect_root_path + + +def make_app( + config: Config, + *, + api: bool, + ui: bool, + ignore_unavailable_components: bool, +) -> FastAPI: + ragna.local_root(config.local_root) + set_redirect_root_path(config.root_path) + + app = FastAPI(title="Ragna", version=ragna.__version__) + + app.add_middleware( + CORSMiddleware, + allow_origins=handle_localhost_origins(config.origins), + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + if api: + app.include_router( + make_api_router( + config, + ignore_unavailable_components=ignore_unavailable_components, + ), + prefix="/api", + ) + + if ui: + panel_app = make_ui_app(config=config) + panel_app.serve_with_fastapi(app, endpoint="/ui") + + @app.get("/", include_in_schema=False) + async def base_redirect() -> Response: + return redirect("/ui" if ui else "/docs") + + @app.get("/health") + async def health() -> Response: + return Response(b"", status_code=status.HTTP_200_OK) + + @app.get("/version") + async def version() -> str: + return ragna.__version__ + + @app.exception_handler(RagnaException) + async def ragna_exception_handler( + request: Request, exc: RagnaException + ) -> JSONResponse: + if exc.http_detail is RagnaException.EVENT: + detail = exc.event + elif exc.http_detail is RagnaException.MESSAGE: + detail = str(exc) + else: + detail = cast(str, exc.http_detail) + return JSONResponse( + status_code=exc.http_status_code, + content={"error": {"message": detail}}, + ) + + return app diff --git a/ragna/deploy/_ui/api_wrapper.py b/ragna/deploy/_ui/api_wrapper.py index 
a57f88e4..5fb7b42d 100644 --- a/ragna/deploy/_ui/api_wrapper.py +++ b/ragna/deploy/_ui/api_wrapper.py @@ -9,6 +9,7 @@ # The goal is this class is to provide ready-to-use functions to interact with the API class ApiWrapper(param.Parameterized): def __init__(self, api_url, **params): + self.api_url = api_url self.client = httpx.AsyncClient(base_url=api_url, timeout=60) super().__init__(**params) @@ -34,7 +35,7 @@ async def get_components(self): # Upload and related functions def upload_endpoints(self): return { - "informations_endpoint": f"{self.client.base_url}/document", + "informations_endpoint": f"{self.api_url}/document", } async def start_and_prepare( diff --git a/ragna/deploy/_ui/app.py b/ragna/deploy/_ui/app.py index dd566e37..49b8d628 100644 --- a/ragna/deploy/_ui/app.py +++ b/ragna/deploy/_ui/app.py @@ -6,7 +6,6 @@ from fastapi import FastAPI from fastapi.staticfiles import StaticFiles -from ragna._utils import handle_localhost_origins from ragna.deploy import Config from . import js @@ -23,17 +22,13 @@ class App(param.Parameterized): - def __init__(self, *, hostname, port, api_url, origins, open_browser): + def __init__(self, *, api_url): super().__init__() # Apply the design modifiers to the panel components # It returns all the CSS files of the modifiers self.css_filepaths = ui.apply_design_modifiers() - self.hostname = hostname - self.port = port self.api_url = api_url - self.origins = origins - self.open_browser = open_browser def get_template(self): # A bit hacky, but works. @@ -88,7 +83,7 @@ def index_page(self): def health_page(self): return pn.pane.HTML("

Ok

") - def add_panel_app(self, server, panel_app_fn): + def add_panel_app(self, server, panel_app_fn, endpoint): # FIXME: this code will ultimately be distributed as part of panel from functools import partial @@ -120,9 +115,11 @@ def panel_app(doc): handler = FunctionHandler(panel_app) application = Application(handler) - BokehFastAPI(application, server=server) + BokehFastAPI({endpoint: application}, server=server) - @server.get(f"/{COMPONENT_PATH.rstrip('/')}" + "/{path:path}") + @server.get( + f"/{COMPONENT_PATH.rstrip('/')}" + "/{path:path}", include_in_schema=False + ) def get_component_resource(path: str): # ComponentResourceHandler.parse_url_path only ever accesses # self._resource_attrs, which fortunately is a class attribute. Thus, we can @@ -131,9 +128,8 @@ def get_component_resource(path: str): resolved_path = ComponentResourceHandler.parse_url_path(self_, path) return FileResponse(resolved_path) - def make_app(self): - app = FastAPI() - self.add_panel_app(app, self.index_page) + def serve_with_fastapi(self, app: FastAPI, endpoint: str): + self.add_panel_app(app, self.index_page, endpoint) for dir in ["css", "imgs", "resources"]: app.mount( @@ -142,19 +138,6 @@ def make_app(self): name=dir, ) - return app - - def serve(self): - import uvicorn - - uvicorn.run(self.make_app, factory=True, host=self.hostname, port=self.port) - -def app(*, config: Config, open_browser: bool) -> App: - return App( - hostname=config.ui.hostname, - port=config.ui.port, - api_url=config.api.url, - origins=handle_localhost_origins(config.ui.origins), - open_browser=open_browser, - ) +def app(*, config: Config) -> App: + return App(api_url=f"{config._url}/api") diff --git a/ragna/deploy/_ui/components/file_uploader.py b/ragna/deploy/_ui/components/file_uploader.py index f25fe973..1568bc61 100644 --- a/ragna/deploy/_ui/components/file_uploader.py +++ b/ragna/deploy/_ui/components/file_uploader.py @@ -133,9 +133,8 @@ def perform_upload(self, event=None, after_upload_callback=None):
diff --git a/ragna/deploy/_utils.py b/ragna/deploy/_utils.py new file mode 100644 index 00000000..4f369a52 --- /dev/null +++ b/ragna/deploy/_utils.py @@ -0,0 +1,57 @@ +from typing import Optional +from urllib.parse import SplitResult, urlsplit, urlunsplit + +from fastapi import status +from fastapi.responses import RedirectResponse + +from ragna.core import RagnaException + +_REDIRECT_ROOT_PATH: Optional[str] = None + + +def set_redirect_root_path(root_path: str) -> None: + global _REDIRECT_ROOT_PATH + _REDIRECT_ROOT_PATH = root_path + + +def redirect( + url: str, *, status_code: int = status.HTTP_303_SEE_OTHER +) -> RedirectResponse: + if _REDIRECT_ROOT_PATH is None: + raise RagnaException + + if url.startswith("/"): + url = _REDIRECT_ROOT_PATH + url + + return RedirectResponse(url, status_code=status_code) + + +def handle_localhost_origins(origins: list[str]) -> list[str]: + # Since localhost is an alias for 127.0.0.1, we allow both so users and developers + # don't need to worry about it. + localhost_origins = { + components.hostname: components + for url in origins + if (components := urlsplit(url)).hostname in {"127.0.0.1", "localhost"} + } + if "127.0.0.1" in localhost_origins and "localhost" not in localhost_origins: + origins.append( + urlunsplit(_replace_hostname(localhost_origins["127.0.0.1"], "localhost")) + ) + elif "localhost" in localhost_origins and "127.0.0.1" not in localhost_origins: + origins.append( + urlunsplit(_replace_hostname(localhost_origins["localhost"], "127.0.0.1")) + ) + + return origins + + +def _replace_hostname(split_result: SplitResult, hostname: str) -> SplitResult: + # This is a separate function, since hostname is not an element of the SplitResult + # namedtuple, but only a property. Thus, we need to replace the netloc item, from + # which the hostname is generated. 
+ if split_result.port is None: + netloc = hostname + else: + netloc = f"{hostname}:{split_result.port}" + return split_result._replace(netloc=netloc) diff --git a/tests/deploy/api/test_batch_endpoints.py b/tests/deploy/api/test_batch_endpoints.py index 94740750..2736df24 100644 --- a/tests/deploy/api/test_batch_endpoints.py +++ b/tests/deploy/api/test_batch_endpoints.py @@ -2,9 +2,8 @@ from fastapi.testclient import TestClient from ragna.deploy import Config -from ragna.deploy._api import app -from .utils import authenticate +from .utils import authenticate, make_api_app def test_batch_sequential_upload_equivalence(tmp_local_root): @@ -21,24 +20,25 @@ def test_batch_sequential_upload_equivalence(tmp_local_root): file.write("?\n") with TestClient( - app(config=Config(), ignore_unavailable_components=False) + make_api_app(config=Config(), ignore_unavailable_components=False) ) as client: authenticate(client) document1_upload = ( - client.post("/document", json={"name": document_path1.name}) + client.post("/api/document", json={"name": document_path1.name}) .raise_for_status() .json() ) document2_upload = ( - client.post("/document", json={"name": document_path2.name}) + client.post("/api/document", json={"name": document_path2.name}) .raise_for_status() .json() ) documents_upload = ( client.post( - "/documents", json={"names": [document_path1.name, document_path2.name]} + "/api/documents", + json={"names": [document_path1.name, document_path2.name]}, ) .raise_for_status() .json() diff --git a/tests/deploy/api/test_components.py b/tests/deploy/api/test_components.py index 65f02209..b459f12e 100644 --- a/tests/deploy/api/test_components.py +++ b/tests/deploy/api/test_components.py @@ -5,9 +5,8 @@ from ragna import assistants from ragna.core import RagnaException from ragna.deploy import Config -from ragna.deploy._api import app -from .utils import authenticate +from .utils import authenticate, make_api_app @pytest.mark.parametrize("ignore_unavailable_components", 
[True, False]) @@ -22,20 +21,20 @@ def test_ignore_unavailable_components(ignore_unavailable_components): if ignore_unavailable_components: with TestClient( - app( + make_api_app( config=config, ignore_unavailable_components=ignore_unavailable_components, ) ) as client: authenticate(client) - components = client.get("/components").raise_for_status().json() + components = client.get("/api/components").raise_for_status().json() assert [assistant["title"] for assistant in components["assistants"]] == [ available_assistant.display_name() ] else: with pytest.raises(RagnaException, match="not available"): - app( + make_api_app( config=config, ignore_unavailable_components=ignore_unavailable_components, ) @@ -48,7 +47,7 @@ def test_ignore_unavailable_components_at_least_one(): config = Config(assistants=[unavailable_assistant]) with pytest.raises(RagnaException, match="No component available"): - app( + make_api_app( config=config, ignore_unavailable_components=True, ) @@ -64,12 +63,12 @@ def test_unknown_component(tmp_local_root): file.write("!\n") with TestClient( - app(config=Config(), ignore_unavailable_components=False) + make_api_app(config=Config(), ignore_unavailable_components=False) ) as client: authenticate(client) document_upload = ( - client.post("/document", json={"name": document_path.name}) + client.post("/api/document", json={"name": document_path.name}) .raise_for_status() .json() ) @@ -86,7 +85,7 @@ def test_unknown_component(tmp_local_root): ) response = client.post( - "/chats", + "/api/chats", json={ "name": "test-chat", "source_storage": "unknown_source_storage", diff --git a/tests/deploy/api/test_e2e.py b/tests/deploy/api/test_e2e.py index 41b154db..c1a80ad5 100644 --- a/tests/deploy/api/test_e2e.py +++ b/tests/deploy/api/test_e2e.py @@ -6,9 +6,8 @@ from ragna.assistants import RagnaDemoAssistant from ragna.deploy import Config -from ragna.deploy._api import app -from .utils import authenticate +from .utils import authenticate, make_api_app class 
TestAssistant(RagnaDemoAssistant): @@ -37,13 +36,15 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): with open(document_path, "w") as file: file.write("!\n") - with TestClient(app(config=config, ignore_unavailable_components=False)) as client: + with TestClient( + make_api_app(config=config, ignore_unavailable_components=False) + ) as client: authenticate(client) - assert client.get("/chats").raise_for_status().json() == [] + assert client.get("/api/chats").raise_for_status().json() == [] document_upload = ( - client.post("/document", json={"name": document_path.name}) + client.post("/api/document", json={"name": document_path.name}) .raise_for_status() .json() ) @@ -59,7 +60,7 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): files={"file": file}, ) - components = client.get("/components").raise_for_status().json() + components = client.get("/api/components").raise_for_status().json() documents = components["documents"] assert set(documents) == config.document.supported_suffixes() source_storages = [ @@ -83,19 +84,21 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): "params": {"multiple_answer_chunks": multiple_answer_chunks}, "documents": [document], } - chat = client.post("/chats", json=chat_metadata).raise_for_status().json() + chat = client.post("/api/chats", json=chat_metadata).raise_for_status().json() assert chat["metadata"] == chat_metadata assert not chat["prepared"] assert chat["messages"] == [] - assert client.get("/chats").raise_for_status().json() == [chat] - assert client.get(f"/chats/{chat['id']}").raise_for_status().json() == chat + assert client.get("/api/chats").raise_for_status().json() == [chat] + assert client.get(f"/api/chats/{chat['id']}").raise_for_status().json() == chat - message = client.post(f"/chats/{chat['id']}/prepare").raise_for_status().json() + message = ( + client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status().json() + ) assert message["role"] == 
"system" assert message["sources"] == [] - chat = client.get(f"/chats/{chat['id']}").raise_for_status().json() + chat = client.get(f"/api/chats/{chat['id']}").raise_for_status().json() assert chat["prepared"] assert len(chat["messages"]) == 1 assert chat["messages"][-1] == message @@ -104,7 +107,7 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): if stream_answer: with client.stream( "POST", - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": prompt, "stream": True}, ) as response: chunks = [json.loads(chunk) for chunk in response.iter_lines()] @@ -113,7 +116,7 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): message["content"] = "".join(chunk["content"] for chunk in chunks) else: message = ( - client.post(f"/chats/{chat['id']}/answer", json={"prompt": prompt}) + client.post(f"/api/chats/{chat['id']}/answer", json={"prompt": prompt}) .raise_for_status() .json() ) @@ -123,7 +126,7 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): document_path.name } - chat = client.get(f"/chats/{chat['id']}").raise_for_status().json() + chat = client.get(f"/api/chats/{chat['id']}").raise_for_status().json() assert len(chat["messages"]) == 3 assert ( chat["messages"][-2]["role"] == "user" @@ -132,5 +135,5 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): ) assert chat["messages"][-1] == message - client.delete(f"/chats/{chat['id']}").raise_for_status() - assert client.get("/chats").raise_for_status().json() == [] + client.delete(f"/api/chats/{chat['id']}").raise_for_status() + assert client.get("/api/chats").raise_for_status().json() == [] diff --git a/tests/deploy/api/utils.py b/tests/deploy/api/utils.py index fee4feff..86ce34a8 100644 --- a/tests/deploy/api/utils.py +++ b/tests/deploy/api/utils.py @@ -3,6 +3,16 @@ from fastapi.testclient import TestClient from ragna.core._utils import default_user +from ragna.deploy._core import make_app + + +def 
make_api_app(*, config, ignore_unavailable_components): + return make_app( + config, + api=True, + ui=False, + ignore_unavailable_components=ignore_unavailable_components, + ) def authenticate(client: TestClient) -> None: diff --git a/tests/deploy/test_config.py b/tests/deploy/test_config.py index a8250acd..be403cd7 100644 --- a/tests/deploy/test_config.py +++ b/tests/deploy/test_config.py @@ -17,24 +17,6 @@ def test_env_var_prefix(mocker, tmp_path): assert config.local_root == env_var -def test_env_var_api_prefix(mocker): - env_var = "hostname" - mocker.patch.dict(os.environ, values={"RAGNA_API_HOSTNAME": env_var}) - - config = Config() - - assert config.api.hostname == env_var - - -def test_env_var_ui_prefix(mocker): - env_var = "hostname" - mocker.patch.dict(os.environ, values={"RAGNA_UI_HOSTNAME": env_var}) - - config = Config() - - assert config.ui.hostname == env_var - - @pytest.mark.xfail() def test_explicit_gt_env_var(mocker, tmp_path): explicit = tmp_path / "explicit" @@ -65,15 +47,14 @@ def test_env_var_gt_config_file(mocker, tmp_path): def test_api_database_url_default_path(tmp_path): config = Config(local_root=tmp_path) - assert Path(urlsplit(config.api.database_url).path[1:]).parent == tmp_path + assert Path(urlsplit(config.database_url).path[1:]).parent == tmp_path -@pytest.mark.parametrize("config_subsection", ["api", "ui"]) -def test_origins_default(config_subsection): +def test_origins_default(): hostname, port = "0.0.0.0", "80" - config = Config(ui=dict(hostname=hostname, port=port)) + config = Config(hostname=hostname, port=port) - assert getattr(config, config_subsection).origins == [f"http://{hostname}:{port}"] + assert config.origins == [f"http://{hostname}:{port}"] def test_from_file_path_not_exists(tmp_path): From 425b8b8b7cb152f3c554adf22c93258ba609f897 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 19 Jun 2024 12:32:41 +0200 Subject: [PATCH 03/29] re-enable browser opening functionality (#431) --- ragna/_cli/core.py | 8 ++++++++ 
ragna/deploy/_core.py | 41 +++++++++++++++++++++++++++++++++++++-- tests/deploy/api/utils.py | 1 + 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/ragna/_cli/core.py b/ragna/_cli/core.py index 6961b87b..2e9030df 100644 --- a/ragna/_cli/core.py +++ b/ragna/_cli/core.py @@ -96,6 +96,10 @@ def deploy( ) ), ] = False, + open_browser: Annotated[ + Optional[bool], + typer.Option(help="Open a browser when Ragna is deployed."), + ] = None, ) -> None: def api_available() -> bool: try: @@ -111,12 +115,16 @@ def api_available() -> bool: elif ui and not api and not api_available(): raise Exception + if open_browser is None: + open_browser = ui + uvicorn.run( lambda: make_app( config, ui=ui, api=api, ignore_unavailable_components=ignore_unavailable_components, + open_browser=open_browser, ), factory=True, host=config.hostname, diff --git a/ragna/deploy/_core.py b/ragna/deploy/_core.py index fa516e13..4cdb14d4 100644 --- a/ragna/deploy/_core.py +++ b/ragna/deploy/_core.py @@ -1,5 +1,10 @@ -from typing import cast +import contextlib +import threading +import time +import webbrowser +from typing import AsyncContextManager, AsyncIterator, Callable, Optional, cast +import httpx from fastapi import FastAPI, Request, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response @@ -19,11 +24,43 @@ def make_app( api: bool, ui: bool, ignore_unavailable_components: bool, + open_browser: bool, ) -> FastAPI: ragna.local_root(config.local_root) set_redirect_root_path(config.root_path) - app = FastAPI(title="Ragna", version=ragna.__version__) + lifespan: Optional[Callable[[FastAPI], AsyncContextManager]] + if open_browser: + + @contextlib.asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncIterator[None]: + def target() -> None: + client = httpx.Client(base_url=config._url) + + def server_available(): + try: + return client.get("/health").is_success + except httpx.ConnectError: + return False + + while not 
server_available(): + time.sleep(0.1) + + webbrowser.open(config._url) + + # We are starting the browser on a thread, because the server can only + # become available _after_ the yield below. By setting daemon=True, the + # thread will automatically terminated together with the main thread. This + # is only relevant when the server never becomes available, e.g. if an error + # occurs. In this case our thread would be stuck in an endless loop. + thread = threading.Thread(target=target, daemon=True) + thread.start() + yield + + else: + lifespan = None + + app = FastAPI(title="Ragna", version=ragna.__version__, lifespan=lifespan) app.add_middleware( CORSMiddleware, diff --git a/tests/deploy/api/utils.py b/tests/deploy/api/utils.py index 86ce34a8..e72e69f6 100644 --- a/tests/deploy/api/utils.py +++ b/tests/deploy/api/utils.py @@ -12,6 +12,7 @@ def make_api_app(*, config, ignore_unavailable_components): api=True, ui=False, ignore_unavailable_components=ignore_unavailable_components, + open_browser=False, ) From 4d37e33859f56c0a7b4c1b8e9bcae9b1d5eb2867 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 26 Jun 2024 15:40:01 +0200 Subject: [PATCH 04/29] introduce engine for API (#434) --- pyproject.toml | 2 +- ragna/core/_components.py | 12 + ragna/core/_rag.py | 87 ++++- ragna/deploy/_api.py | 163 +++++++++ ragna/deploy/_api/__init__.py | 1 - ragna/deploy/_api/core.py | 318 ------------------ ragna/deploy/_api/database.py | 270 --------------- ragna/deploy/_core.py | 14 +- ragna/deploy/_database.py | 274 +++++++++++++++ ragna/deploy/_engine.py | 205 +++++++++++ ragna/deploy/{_api/orm.py => _orm.py} | 0 ragna/deploy/{_api/schemas.py => _schemas.py} | 26 +- tests/deploy/api/conftest.py | 41 +++ 13 files changed, 774 insertions(+), 639 deletions(-) create mode 100644 ragna/deploy/_api.py delete mode 100644 ragna/deploy/_api/__init__.py delete mode 100644 ragna/deploy/_api/core.py delete mode 100644 ragna/deploy/_api/database.py create mode 100644 
ragna/deploy/_database.py create mode 100644 ragna/deploy/_engine.py rename ragna/deploy/{_api/orm.py => _orm.py} (100%) rename ragna/deploy/{_api/schemas.py => _schemas.py} (62%) create mode 100644 tests/deploy/api/conftest.py diff --git a/pyproject.toml b/pyproject.toml index f85a71a4..6df079a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -162,7 +162,7 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = [ - "ragna.deploy._api.orm", + "ragna.deploy._orm", ] # Our ORM schema doesn't really work with mypy. There are some other ways to define it # to play ball. We should do that in the future. diff --git a/ragna/core/_components.py b/ragna/core/_components.py index d237c1b8..2f987910 100644 --- a/ragna/core/_components.py +++ b/ragna/core/_components.py @@ -1,9 +1,11 @@ from __future__ import annotations import abc +import datetime import enum import functools import inspect +import uuid from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union import pydantic @@ -157,6 +159,8 @@ def __init__( *, role: MessageRole = MessageRole.SYSTEM, sources: Optional[list[Source]] = None, + id: Optional[uuid.UUID] = None, + timestamp: Optional[datetime.datetime] = None, ) -> None: if isinstance(content, str): self._content: str = content @@ -166,6 +170,14 @@ def __init__( self.role = role self.sources = sources or [] + if id is None: + id = uuid.uuid4() + self.id = id + + if timestamp is None: + timestamp = datetime.datetime.utcnow() + self.timestamp = timestamp + async def __aiter__(self) -> AsyncIterator[str]: if hasattr(self, "_content"): yield self._content diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index 6cdff127..98c32a42 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -4,6 +4,7 @@ import inspect import uuid from typing import ( + TYPE_CHECKING, Any, AsyncIterator, Awaitable, @@ -12,21 +13,24 @@ Iterable, Iterator, Optional, - Type, TypeVar, Union, cast, ) import pydantic +from fastapi import status from 
starlette.concurrency import iterate_in_threadpool, run_in_threadpool from ._components import Assistant, Component, Message, MessageRole, SourceStorage from ._document import Document, LocalDocument from ._utils import RagnaException, default_user, merge_models +if TYPE_CHECKING: + from ragna.deploy import Config + T = TypeVar("T") -C = TypeVar("C", bound=Component) +C = TypeVar("C", bound=Component, covariant=True) class Rag(Generic[C]): @@ -41,13 +45,49 @@ class Rag(Generic[C]): ``` """ - def __init__(self) -> None: - self._components: dict[Type[C], C] = {} + def __init__( + self, + *, + config: Optional[Config] = None, + ignore_unavailable_components: bool = False, + ) -> None: + self._components: dict[type[C], C] = {} + self._display_name_map: dict[str, type[C]] = {} + + if config is not None: + self._preload_components( + config=config, + ignore_unavailable_components=ignore_unavailable_components, + ) + + def _preload_components( + self, *, config: Config, ignore_unavailable_components: bool + ) -> None: + for components in [config.source_storages, config.assistants]: + components = cast(list[type[Component]], components) + at_least_one = False + for component in components: + loaded_component = self._load_component( + component, # type: ignore[arg-type] + ignore_unavailable=ignore_unavailable_components, + ) + if loaded_component is None: + print( + f"Ignoring {component.display_name()}, because it is not available." 
+ ) + else: + at_least_one = True + + if not at_least_one: + raise RagnaException( + "No component available", + components=[component.display_name() for component in components], + ) def _load_component( - self, component: Union[Type[C], C], *, ignore_unavailable: bool = False + self, component: Union[C, type[C], str], *, ignore_unavailable: bool = False ) -> Optional[C]: - cls: Type[C] + cls: type[C] instance: Optional[C] if isinstance(component, Component): @@ -55,6 +95,19 @@ def _load_component( cls = type(instance) elif isinstance(component, type) and issubclass(component, Component): cls = component + instance = None + elif isinstance(component, str): + try: + cls = self._display_name_map[component] + except KeyError: + raise RagnaException( + "Unknown component", + display_name=component, + help="Did you forget to create the Rag() instance with a config?", + http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + http_detail=f"Unknown component '{component}'", + ) from None + instance = None else: raise RagnaException @@ -71,6 +124,7 @@ def _load_component( instance = cls() self._components[cls] = instance + self._display_name_map[cls.display_name()] = cls return self._components[cls] @@ -78,8 +132,8 @@ def chat( self, *, documents: Iterable[Any], - source_storage: Union[Type[SourceStorage], SourceStorage], - assistant: Union[Type[Assistant], Assistant], + source_storage: Union[SourceStorage, type[SourceStorage], str], + assistant: Union[Assistant, type[Assistant], str], **params: Any, ) -> Chat: """Create a new [ragna.core.Chat][]. @@ -87,6 +141,7 @@ def chat( Args: documents: Documents to use. If any item is not a [ragna.core.Document][], [ragna.core.LocalDocument.from_path][] is invoked on it. + FIXME source_storage: Source storage to use. assistant: Assistant to use. **params: Additional parameters passed to the source storage and assistant. 
@@ -94,8 +149,8 @@ def chat( return Chat( self, documents=documents, - source_storage=source_storage, - assistant=assistant, + source_storage=cast(SourceStorage, self._load_component(source_storage)), # type: ignore[arg-type] + assistant=cast(Assistant, self._load_component(assistant)), # type: ignore[arg-type] **params, ) @@ -146,17 +201,15 @@ def __init__( rag: Rag, *, documents: Iterable[Any], - source_storage: Union[Type[SourceStorage], SourceStorage], - assistant: Union[Type[Assistant], Assistant], + source_storage: SourceStorage, + assistant: Assistant, **params: Any, ) -> None: self._rag = rag self.documents = self._parse_documents(documents) - self.source_storage = cast( - SourceStorage, self._rag._load_component(source_storage) - ) - self.assistant = cast(Assistant, self._rag._load_component(assistant)) + self.source_storage = source_storage + self.assistant = assistant special_params = SpecialChatParams().model_dump() special_params.update(params) @@ -306,6 +359,6 @@ async def __aenter__(self) -> Chat: return self async def __aexit__( - self, exc_type: Type[Exception], exc: Exception, traceback: str + self, exc_type: type[Exception], exc: Exception, traceback: str ) -> None: pass diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py new file mode 100644 index 00000000..d3194064 --- /dev/null +++ b/ragna/deploy/_api.py @@ -0,0 +1,163 @@ +import uuid +from typing import Annotated, AsyncIterator, cast + +import aiofiles +import pydantic +from fastapi import ( + APIRouter, + Body, + Depends, + Form, + HTTPException, + UploadFile, +) +from fastapi.responses import StreamingResponse + +import ragna +import ragna.core +from ragna._compat import anext +from ragna.core._utils import default_user +from ragna.deploy import Config + +from . 
import _schemas as schemas +from ._engine import Engine + + +def make_router(config: Config, engine: Engine) -> APIRouter: + router = APIRouter(tags=["API"]) + + def get_user() -> str: + return default_user() + + UserDependency = Annotated[str, Depends(get_user)] + + # TODO: the document endpoints do not go through the engine, because they'll change + # quite drastically when the UI no longer depends on the API + + _database = engine._database + + @router.post("/document") + async def create_document_upload_info( + user: UserDependency, + name: Annotated[str, Body(..., embed=True)], + ) -> schemas.DocumentUpload: + with _database.get_session() as session: + document = schemas.Document(name=name) + metadata, parameters = await config.document.get_upload_info( + config=config, user=user, id=document.id, name=document.name + ) + document.metadata = metadata + _database.add_document( + session, user=user, document=document, metadata=metadata + ) + return schemas.DocumentUpload(parameters=parameters, document=document) + + # TODO: Add UI support and documentation for this endpoint (#406) + @router.post("/documents") + async def create_documents_upload_info( + user: UserDependency, + names: Annotated[list[str], Body(..., embed=True)], + ) -> list[schemas.DocumentUpload]: + with _database.get_session() as session: + document_metadata_collection = [] + document_upload_collection = [] + for name in names: + document = schemas.Document(name=name) + metadata, parameters = await config.document.get_upload_info( + config=config, user=user, id=document.id, name=document.name + ) + document.metadata = metadata + document_metadata_collection.append((document, metadata)) + document_upload_collection.append( + schemas.DocumentUpload(parameters=parameters, document=document) + ) + + _database.add_documents( + session, + user=user, + document_metadata_collection=document_metadata_collection, + ) + return document_upload_collection + + # TODO: Add new endpoint for batch uploading 
documents (#407) + @router.put("/document") + async def upload_document( + token: Annotated[str, Form()], file: UploadFile + ) -> schemas.Document: + if not issubclass(config.document, ragna.core.LocalDocument): + raise HTTPException( + status_code=400, + detail="Ragna configuration does not support local upload", + ) + with _database.get_session() as session: + user, id = ragna.core.LocalDocument.decode_upload_token(token) + document = _database.get_document(session, user=user, id=id) + + core_document = cast( + ragna.core.LocalDocument, engine._to_core.document(document) + ) + core_document.path.parent.mkdir(parents=True, exist_ok=True) + async with aiofiles.open(core_document.path, "wb") as document_file: + while content := await file.read(1024): + await document_file.write(content) + + return document + + @router.get("/components") + def get_components(_: UserDependency) -> schemas.Components: + return engine.get_components() + + @router.post("/chats") + async def create_chat( + user: UserDependency, + chat_metadata: schemas.ChatMetadata, + ) -> schemas.Chat: + return engine.create_chat(user=user, chat_metadata=chat_metadata) + + @router.get("/chats") + async def get_chats(user: UserDependency) -> list[schemas.Chat]: + return engine.get_chats(user=user) + + @router.get("/chats/{id}") + async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat: + return engine.get_chat(user=user, id=id) + + @router.post("/chats/{id}/prepare") + async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message: + return await engine.prepare_chat(user=user, id=id) + + @router.post("/chats/{id}/answer") + async def answer( + user: UserDependency, + id: uuid.UUID, + prompt: Annotated[str, Body(..., embed=True)], + stream: Annotated[bool, Body(..., embed=True)] = False, + ) -> schemas.Message: + message_stream = engine.answer_stream(user=user, chat_id=id, prompt=prompt) + answer = await anext(message_stream) + + if not stream: + content_chunks = 
[chunk.content async for chunk in message_stream] + answer.content += "".join(content_chunks) + return answer + + async def message_chunks() -> AsyncIterator[schemas.Message]: + yield answer + async for chunk in message_stream: + yield chunk + + async def to_jsonl( + models: AsyncIterator[pydantic.BaseModel], + ) -> AsyncIterator[str]: + async for model in models: + yield f"{model.model_dump_json()}\n" + + return StreamingResponse( # type: ignore[return-value] + to_jsonl(message_chunks()) + ) + + @router.delete("/chats/{id}") + async def delete_chat(user: UserDependency, id: uuid.UUID) -> None: + engine.delete_chat(user=user, id=id) + + return router diff --git a/ragna/deploy/_api/__init__.py b/ragna/deploy/_api/__init__.py deleted file mode 100644 index f99fb828..00000000 --- a/ragna/deploy/_api/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .core import make_router diff --git a/ragna/deploy/_api/core.py b/ragna/deploy/_api/core.py deleted file mode 100644 index 9bb9c682..00000000 --- a/ragna/deploy/_api/core.py +++ /dev/null @@ -1,318 +0,0 @@ -import contextlib -import uuid -from typing import Annotated, Any, AsyncIterator, Iterator, Type, cast - -import aiofiles -from fastapi import ( - APIRouter, - Body, - Depends, - Form, - HTTPException, - UploadFile, - status, -) -from fastapi.responses import StreamingResponse -from pydantic import BaseModel - -import ragna -import ragna.core -from ragna._compat import aiter, anext -from ragna.core import Assistant, Component, Rag, RagnaException, SourceStorage -from ragna.core._rag import SpecialChatParams -from ragna.core._utils import default_user -from ragna.deploy import Config - -from . 
import database, schemas - - -def make_router(config: Config, ignore_unavailable_components: bool) -> APIRouter: - router = APIRouter(tags=["API"]) - - rag = Rag() # type: ignore[var-annotated] - components_map: dict[str, Component] = {} - for components in [config.source_storages, config.assistants]: - components = cast(list[Type[Component]], components) - at_least_one = False - for component in components: - loaded_component = rag._load_component( - component, ignore_unavailable=ignore_unavailable_components - ) - if loaded_component is None: - print( - f"Ignoring {component.display_name()}, because it is not available." - ) - else: - at_least_one = True - components_map[component.display_name()] = loaded_component - - if not at_least_one: - raise RagnaException( - "No component available", - components=[component.display_name() for component in components], - ) - - def get_component(display_name: str) -> Component: - component = components_map.get(display_name) - if component is None: - raise RagnaException( - "Unknown component", - display_name=display_name, - http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - http_detail=RagnaException.MESSAGE, - ) - - return component - - @router.get("/") - async def version() -> str: - return ragna.__version__ - - def get_user() -> str: - return default_user() - - UserDependency = Annotated[str, Depends(get_user)] - - def _get_component_json_schema( - component: Type[Component], - ) -> dict[str, dict[str, Any]]: - json_schema = component._protocol_model().model_json_schema() - # FIXME: there is likely a better way to exclude certain fields builtin in - # pydantic - for special_param in SpecialChatParams.model_fields: - if ( - "properties" in json_schema - and special_param in json_schema["properties"] - ): - del json_schema["properties"][special_param] - if "required" in json_schema and special_param in json_schema["required"]: - json_schema["required"].remove(special_param) - return json_schema - - 
@router.get("/components") - async def get_components(_: UserDependency) -> schemas.Components: - return schemas.Components( - documents=sorted(config.document.supported_suffixes()), - source_storages=[ - _get_component_json_schema(type(source_storage)) - for source_storage in components_map.values() - if isinstance(source_storage, SourceStorage) - ], - assistants=[ - _get_component_json_schema(type(assistant)) - for assistant in components_map.values() - if isinstance(assistant, Assistant) - ], - ) - - make_session = database.get_sessionmaker(config.database_url) - - @contextlib.contextmanager - def get_session() -> Iterator[database.Session]: - with make_session() as session: # type: ignore[attr-defined] - yield session - - @router.post("/document") - async def create_document_upload_info( - user: UserDependency, - name: Annotated[str, Body(..., embed=True)], - ) -> schemas.DocumentUpload: - with get_session() as session: - document = schemas.Document(name=name) - metadata, parameters = await config.document.get_upload_info( - config=config, user=user, id=document.id, name=document.name - ) - database.add_document( - session, user=user, document=document, metadata=metadata - ) - return schemas.DocumentUpload(parameters=parameters, document=document) - - # TODO: Add UI support and documentation for this endpoint (#406) - @router.post("/documents") - async def create_documents_upload_info( - user: UserDependency, - names: Annotated[list[str], Body(..., embed=True)], - ) -> list[schemas.DocumentUpload]: - with get_session() as session: - document_metadata_collection = [] - document_upload_collection = [] - for name in names: - document = schemas.Document(name=name) - metadata, parameters = await config.document.get_upload_info( - config=config, user=user, id=document.id, name=document.name - ) - document_metadata_collection.append((document, metadata)) - document_upload_collection.append( - schemas.DocumentUpload(parameters=parameters, document=document) - ) - - 
database.add_documents( - session, - user=user, - document_metadata_collection=document_metadata_collection, - ) - return document_upload_collection - - # TODO: Add new endpoint for batch uploading documents (#407) - @router.put("/document") - async def upload_document( - token: Annotated[str, Form()], file: UploadFile - ) -> schemas.Document: - if not issubclass(config.document, ragna.core.LocalDocument): - raise HTTPException( - status_code=400, - detail="Ragna configuration does not support local upload", - ) - with get_session() as session: - user, id = ragna.core.LocalDocument.decode_upload_token(token) - document, metadata = database.get_document(session, user=user, id=id) - - core_document = ragna.core.LocalDocument( - id=document.id, name=document.name, metadata=metadata - ) - core_document.path.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.open(core_document.path, "wb") as document_file: - while content := await file.read(1024): - await document_file.write(content) - - return document - - def schema_to_core_chat( - session: database.Session, *, user: str, chat: schemas.Chat - ) -> ragna.core.Chat: - core_chat = rag.chat( - documents=[ - config.document( - id=document.id, - name=document.name, - metadata=database.get_document( - session, - user=user, - id=document.id, - )[1], - ) - for document in chat.metadata.documents - ], - source_storage=get_component(chat.metadata.source_storage), # type: ignore[arg-type] - assistant=get_component(chat.metadata.assistant), # type: ignore[arg-type] - user=user, - chat_id=chat.id, - chat_name=chat.metadata.name, - **chat.metadata.params, - ) - # FIXME: We need to reconstruct the previous messages here. Right now this is - # not needed, because the chat itself never accesses past messages. However, - # if we implement a chat history feature, i.e. passing past messages to - # the assistant, this becomes crucial. 
- core_chat._messages = [] - core_chat._prepared = chat.prepared - - return core_chat - - @router.post("/chats") - async def create_chat( - user: UserDependency, - chat_metadata: schemas.ChatMetadata, - ) -> schemas.Chat: - with get_session() as session: - chat = schemas.Chat(metadata=chat_metadata) - - # Although we don't need the actual ragna.core.Chat object here, - # we use it to validate the documents and metadata. - schema_to_core_chat(session, user=user, chat=chat) - - database.add_chat(session, user=user, chat=chat) - return chat - - @router.get("/chats") - async def get_chats(user: UserDependency) -> list[schemas.Chat]: - with get_session() as session: - return database.get_chats(session, user=user) - - @router.get("/chats/{id}") - async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat: - with get_session() as session: - return database.get_chat(session, user=user, id=id) - - @router.post("/chats/{id}/prepare") - async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message: - with get_session() as session: - chat = database.get_chat(session, user=user, id=id) - - core_chat = schema_to_core_chat(session, user=user, chat=chat) - - welcome = schemas.Message.from_core(await core_chat.prepare()) - - chat.prepared = True - chat.messages.append(welcome) - database.update_chat(session, user=user, chat=chat) - - return welcome - - @router.post("/chats/{id}/answer") - async def answer( - user: UserDependency, - id: uuid.UUID, - prompt: Annotated[str, Body(..., embed=True)], - stream: Annotated[bool, Body(..., embed=True)] = False, - ) -> schemas.Message: - with get_session() as session: - chat = database.get_chat(session, user=user, id=id) - chat.messages.append( - schemas.Message(content=prompt, role=ragna.core.MessageRole.USER) - ) - core_chat = schema_to_core_chat(session, user=user, chat=chat) - - core_answer = await core_chat.answer(prompt, stream=stream) - - if stream: - - async def message_chunks() -> 
AsyncIterator[BaseModel]: - core_answer_stream = aiter(core_answer) - content_chunk = await anext(core_answer_stream) - - answer = schemas.Message( - content=content_chunk, - role=core_answer.role, - sources=[ - schemas.Source.from_core(source) - for source in core_answer.sources - ], - ) - yield answer - - # Avoid sending the sources multiple times - answer_chunk = answer.model_copy(update=dict(sources=None)) - content_chunks = [answer_chunk.content] - async for content_chunk in core_answer_stream: - content_chunks.append(content_chunk) - answer_chunk.content = content_chunk - yield answer_chunk - - with get_session() as session: - answer.content = "".join(content_chunks) - chat.messages.append(answer) - database.update_chat(session, user=user, chat=chat) - - async def to_jsonl(models: AsyncIterator[Any]) -> AsyncIterator[str]: - async for model in models: - yield f"{model.model_dump_json()}\n" - - return StreamingResponse( # type: ignore[return-value] - to_jsonl(message_chunks()) - ) - else: - answer = schemas.Message.from_core(core_answer) - - with get_session() as session: - chat.messages.append(answer) - database.update_chat(session, user=user, chat=chat) - - return answer - - @router.delete("/chats/{id}") - async def delete_chat(user: UserDependency, id: uuid.UUID) -> None: - with get_session() as session: - database.delete_chat(session, user=user, id=id) - - return router diff --git a/ragna/deploy/_api/database.py b/ragna/deploy/_api/database.py deleted file mode 100644 index 2a61b048..00000000 --- a/ragna/deploy/_api/database.py +++ /dev/null @@ -1,270 +0,0 @@ -from __future__ import annotations - -import functools -import uuid -from typing import Any, Callable, Optional, cast -from urllib.parse import urlsplit - -from sqlalchemy import create_engine, select -from sqlalchemy.orm import Session, joinedload -from sqlalchemy.orm import sessionmaker as _sessionmaker - -from ragna.core import RagnaException - -from . 
import orm, schemas - - -def get_sessionmaker(database_url: str) -> Callable[[], Session]: - components = urlsplit(database_url) - if components.scheme == "sqlite": - connect_args = dict(check_same_thread=False) - else: - connect_args = dict() - engine = create_engine(database_url, connect_args=connect_args) - orm.Base.metadata.create_all(bind=engine) - return _sessionmaker(autocommit=False, autoflush=False, bind=engine) - - -@functools.lru_cache(maxsize=1024) -def _get_user_id(session: Session, username: str) -> uuid.UUID: - user: Optional[orm.User] = session.execute( - select(orm.User).where(orm.User.name == username) - ).scalar_one_or_none() - - if user is None: - # Add a new user if the current username is not registered yet. Since this is - # behind the authentication layer, we don't need any extra security here. - user = orm.User(id=uuid.uuid4(), name=username) - session.add(user) - session.commit() - - return cast(uuid.UUID, user.id) - - -def add_document( - session: Session, *, user: str, document: schemas.Document, metadata: dict[str, Any] -) -> None: - session.add( - orm.Document( - id=document.id, - user_id=_get_user_id(session, user), - name=document.name, - metadata_=metadata, - ) - ) - session.commit() - - -def add_documents( - session: Session, - *, - user: str, - document_metadata_collection: list[tuple[schemas.Document, dict[str, Any]]], -) -> None: - """ - Add multiple documents to the database. - - This function allows adding multiple documents at once by calling `add_all`. This is - important when there is non-negligible latency attached to each database operation. 
- """ - user_id = _get_user_id(session, user) - documents = [ - orm.Document( - id=document.id, - user_id=user_id, - name=document.name, - metadata_=metadata, - ) - for document, metadata in document_metadata_collection - ] - session.add_all(documents) - session.commit() - - -def _orm_to_schema_document(document: orm.Document) -> schemas.Document: - return schemas.Document(id=document.id, name=document.name) - - -@functools.lru_cache(maxsize=1024) -def get_document( - session: Session, *, user: str, id: uuid.UUID -) -> tuple[schemas.Document, dict[str, Any]]: - document = session.execute( - select(orm.Document).where( - (orm.Document.user_id == _get_user_id(session, user)) - & (orm.Document.id == id) - ) - ).scalar_one_or_none() - return _orm_to_schema_document(document), document.metadata_ - - -def add_chat(session: Session, *, user: str, chat: schemas.Chat) -> None: - document_ids = {document.id for document in chat.metadata.documents} - documents = ( - session.execute(select(orm.Document).where(orm.Document.id.in_(document_ids))) - .scalars() - .all() - ) - if len(documents) != len(document_ids): - raise RagnaException( - str(set(document_ids) - {document.id for document in documents}) - ) - session.add( - orm.Chat( - id=chat.id, - user_id=_get_user_id(session, user), - name=chat.metadata.name, - documents=documents, - source_storage=chat.metadata.source_storage, - assistant=chat.metadata.assistant, - params=chat.metadata.params, - prepared=chat.prepared, - ) - ) - session.commit() - - -def _orm_to_schema_chat(chat: orm.Chat) -> schemas.Chat: - documents = [ - schemas.Document(id=document.id, name=document.name) - for document in chat.documents - ] - messages = [ - schemas.Message( - id=message.id, - role=message.role, - content=message.content, - sources=[ - schemas.Source( - id=source.id, - document=_orm_to_schema_document(source.document), - location=source.location, - content=source.content, - num_tokens=source.num_tokens, - ) - for source in message.sources 
- ], - timestamp=message.timestamp, - ) - for message in chat.messages - ] - return schemas.Chat( - id=chat.id, - metadata=schemas.ChatMetadata( - name=chat.name, - documents=documents, - source_storage=chat.source_storage, - assistant=chat.assistant, - params=chat.params, - ), - messages=messages, - prepared=chat.prepared, - ) - - -def _select_chat(*, eager: bool = False) -> Any: - selector = select(orm.Chat) - if eager: - selector = selector.options( # type: ignore[attr-defined] - joinedload(orm.Chat.messages).joinedload(orm.Message.sources), - joinedload(orm.Chat.documents), - ) - return selector - - -def get_chats(session: Session, *, user: str) -> list[schemas.Chat]: - return [ - _orm_to_schema_chat(chat) - for chat in session.execute( - _select_chat(eager=True).where( - orm.Chat.user_id == _get_user_id(session, user) - ) - ) - .scalars() - .unique() - .all() - ] - - -def _get_orm_chat( - session: Session, *, user: str, id: uuid.UUID, eager: bool = False -) -> orm.Chat: - chat: Optional[orm.Chat] = ( - session.execute( - _select_chat(eager=eager).where( - (orm.Chat.id == id) & (orm.Chat.user_id == _get_user_id(session, user)) - ) - ) - .unique() - .scalar_one_or_none() - ) - if chat is None: - raise RagnaException() - return chat - - -def get_chat(session: Session, *, user: str, id: uuid.UUID) -> schemas.Chat: - return _orm_to_schema_chat(_get_orm_chat(session, user=user, id=id, eager=True)) - - -def _schema_to_orm_source(session: Session, source: schemas.Source) -> orm.Source: - orm_source: Optional[orm.Source] = session.execute( - select(orm.Source).where(orm.Source.id == source.id) - ).scalar_one_or_none() - - if orm_source is None: - orm_source = orm.Source( - id=source.id, - document_id=source.document.id, - location=source.location, - content=source.content, - num_tokens=source.num_tokens, - ) - session.add(orm_source) - session.commit() - session.refresh(orm_source) - - return orm_source - - -def _schema_to_orm_message( - session: Session, chat_id: 
uuid.UUID, message: schemas.Message -) -> orm.Message: - orm_message: Optional[orm.Message] = session.execute( - select(orm.Message).where(orm.Message.id == message.id) - ).scalar_one_or_none() - if orm_message is None: - orm_message = orm.Message( - id=message.id, - chat_id=chat_id, - content=message.content, - role=message.role, - sources=[ - _schema_to_orm_source(session, source=source) - for source in message.sources - ], - timestamp=message.timestamp, - ) - session.add(orm_message) - session.commit() - session.refresh(orm_message) - - return orm_message - - -def update_chat(session: Session, user: str, chat: schemas.Chat) -> None: - orm_chat = _get_orm_chat(session, user=user, id=chat.id) - - orm_chat.prepared = chat.prepared - orm_chat.messages = [ # type: ignore[assignment] - _schema_to_orm_message(session, chat_id=chat.id, message=message) - for message in chat.messages - ] - - session.commit() - - -def delete_chat(session: Session, user: str, id: uuid.UUID) -> None: - orm_chat = _get_orm_chat(session, user=user, id=id) - session.delete(orm_chat) # type: ignore[no-untyped-call] - session.commit() diff --git a/ragna/deploy/_core.py b/ragna/deploy/_core.py index 4cdb14d4..67f067e0 100644 --- a/ragna/deploy/_core.py +++ b/ragna/deploy/_core.py @@ -14,6 +14,7 @@ from ._api import make_router as make_api_router from ._config import Config +from ._engine import Engine from ._ui import app as make_ui_app from ._utils import handle_localhost_origins, redirect, set_redirect_root_path @@ -70,14 +71,13 @@ def server_available(): allow_headers=["*"], ) + engine = Engine( + config=config, + ignore_unavailable_components=ignore_unavailable_components, + ) + if api: - app.include_router( - make_api_router( - config, - ignore_unavailable_components=ignore_unavailable_components, - ), - prefix="/api", - ) + app.include_router(make_api_router(config, engine), prefix="/api") if ui: panel_app = make_ui_app(config=config) diff --git a/ragna/deploy/_database.py 
b/ragna/deploy/_database.py new file mode 100644 index 00000000..323ccd21 --- /dev/null +++ b/ragna/deploy/_database.py @@ -0,0 +1,274 @@ +from __future__ import annotations + +import uuid +from typing import Any, Optional +from urllib.parse import urlsplit + +from sqlalchemy import create_engine, select +from sqlalchemy.orm import Session, joinedload, sessionmaker + +from ragna.core import RagnaException + +from . import _orm as orm +from . import _schemas as schemas + + +class Database: + def __init__(self, url: str) -> None: + components = urlsplit(url) + if components.scheme == "sqlite": + connect_args = dict(check_same_thread=False) + else: + connect_args = dict() + engine = create_engine(url, connect_args=connect_args) + orm.Base.metadata.create_all(bind=engine) + + self.get_session = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + self._to_orm = SchemaToOrmConverter() + self._to_schema = OrmToSchemaConverter() + + def _get_user(self, session: Session, *, username: str) -> orm.User: + user: Optional[orm.User] = session.execute( + select(orm.User).where(orm.User.name == username) + ).scalar_one_or_none() + + if user is None: + # Add a new user if the current username is not registered yet. Since this + # is behind the authentication layer, we don't need any extra security here. + user = orm.User(id=uuid.uuid4(), name=username) + session.add(user) + session.commit() + + return user + + def add_document( + self, + session: Session, + *, + user: str, + document: schemas.Document, + metadata: dict[str, Any], + ) -> None: + session.add( + orm.Document( + id=document.id, + user_id=self._get_user(session, username=user).id, + name=document.name, + metadata_=metadata, + ) + ) + session.commit() + + def add_documents( + self, + session: Session, + *, + user: str, + document_metadata_collection: list[tuple[schemas.Document, dict[str, Any]]], + ) -> None: + """ + Add multiple documents to the database. 
+ + This function allows adding multiple documents at once by calling `add_all`. This is + important when there is non-negligible latency attached to each database operation. + """ + documents = [ + orm.Document( + id=document.id, + user_id=self._get_user(session, username=user).id, + name=document.name, + metadata_=metadata, + ) + for document, metadata in document_metadata_collection + ] + session.add_all(documents) + session.commit() + + def get_document( + self, session: Session, *, user: str, id: uuid.UUID + ) -> schemas.Document: + document = session.execute( + select(orm.Document).where( + (orm.Document.user_id == self._get_user(session, username=user).id) + & (orm.Document.id == id) + ) + ).scalar_one_or_none() + return self._to_schema.document(document) + + def add_chat(self, session: Session, *, user: str, chat: schemas.Chat) -> None: + document_ids = {document.id for document in chat.metadata.documents} + # FIXME also check if the user is allowed to access the documents? + documents = ( + session.execute( + select(orm.Document).where(orm.Document.id.in_(document_ids)) + ) + .scalars() + .all() + ) + if len(documents) != len(document_ids): + raise RagnaException( + str(document_ids - {document.id for document in documents}) + ) + + orm_chat = self._to_orm.chat( + chat, + user_id=self._get_user(session, username=user).id, + # We have to pass the documents here, because SQLAlchemy does not allow a + # second instance of orm.Document with the same primary key in the session. 
+ documents=documents, + ) + session.add(orm_chat) + session.commit() + + def _select_chat(self, *, eager: bool = False) -> Any: + selector = select(orm.Chat) + if eager: + selector = selector.options( # type: ignore[attr-defined] + joinedload(orm.Chat.messages).joinedload(orm.Message.sources), + joinedload(orm.Chat.documents), + ) + return selector + + def get_chats(self, session: Session, *, user: str) -> list[schemas.Chat]: + return [ + self._to_schema.chat(chat) + for chat in session.execute( + self._select_chat(eager=True).where( + orm.Chat.user_id == self._get_user(session, username=user).id + ) + ) + .scalars() + .unique() + .all() + ] + + def _get_orm_chat( + self, session: Session, *, user: str, id: uuid.UUID, eager: bool = False + ) -> orm.Chat: + chat: Optional[orm.Chat] = ( + session.execute( + self._select_chat(eager=eager).where( + (orm.Chat.id == id) + & (orm.Chat.user_id == self._get_user(session, username=user).id) + ) + ) + .unique() + .scalar_one_or_none() + ) + if chat is None: + raise RagnaException() + return chat + + def get_chat(self, session: Session, *, user: str, id: uuid.UUID) -> schemas.Chat: + return self._to_schema.chat( + (self._get_orm_chat(session, user=user, id=id, eager=True)) + ) + + def update_chat(self, session: Session, user: str, chat: schemas.Chat) -> None: + orm_chat = self._to_orm.chat( + chat, user_id=self._get_user(session, username=user).id + ) + session.merge(orm_chat) + session.commit() + + def delete_chat(self, session: Session, user: str, id: uuid.UUID) -> None: + orm_chat = self._get_orm_chat(session, user=user, id=id) + session.delete(orm_chat) # type: ignore[no-untyped-call] + session.commit() + + +class SchemaToOrmConverter: + def document( + self, document: schemas.Document, *, user_id: uuid.UUID + ) -> orm.Document: + return orm.Document( + id=document.id, + user_id=user_id, + name=document.name, + metadata_=document.metadata, + ) + + def source(self, source: schemas.Source) -> orm.Source: + return 
orm.Source( + id=source.id, + document_id=source.document.id, + location=source.location, + content=source.content, + num_tokens=source.num_tokens, + ) + + def message(self, message: schemas.Message, *, chat_id: uuid.UUID) -> orm.Message: + return orm.Message( + id=message.id, + chat_id=chat_id, + content=message.content, + role=message.role, + sources=[self.source(source) for source in message.sources], + timestamp=message.timestamp, + ) + + def chat( + self, + chat: schemas.Chat, + *, + user_id: uuid.UUID, + documents: Optional[list[orm.Document]] = None, + ) -> orm.Chat: + if documents is None: + documents = [ + self.document(document, user_id=user_id) + for document in chat.metadata.documents + ] + return orm.Chat( + id=chat.id, + user_id=user_id, + name=chat.metadata.name, + documents=documents, + source_storage=chat.metadata.source_storage, + assistant=chat.metadata.assistant, + params=chat.metadata.params, + messages=[ + self.message(message, chat_id=chat.id) for message in chat.messages + ], + prepared=chat.prepared, + ) + + +class OrmToSchemaConverter: + def document(self, document: orm.Document) -> schemas.Document: + return schemas.Document( + id=document.id, name=document.name, metadata=document.metadata_ + ) + + def source(self, source: orm.Source) -> schemas.Source: + return schemas.Source( + id=source.id, + document=self.document(source.document), + location=source.location, + content=source.content, + num_tokens=source.num_tokens, + ) + + def message(self, message: orm.Message) -> schemas.Message: + return schemas.Message( + id=message.id, + role=message.role, # type: ignore[arg-type] + content=message.content, + sources=[self.source(source) for source in message.sources], + timestamp=message.timestamp, + ) + + def chat(self, chat: orm.Chat) -> schemas.Chat: + return schemas.Chat( + id=chat.id, + metadata=schemas.ChatMetadata( + name=chat.name, + documents=[self.document(document) for document in chat.documents], + 
source_storage=chat.source_storage, + assistant=chat.assistant, + params=chat.params, + ), + messages=[self.message(message) for message in chat.messages], + prepared=chat.prepared, + ) diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py new file mode 100644 index 00000000..847f7a93 --- /dev/null +++ b/ragna/deploy/_engine.py @@ -0,0 +1,205 @@ +import uuid +from typing import Any, AsyncIterator, Optional, Type + +from ragna import Rag, core +from ragna._compat import aiter, anext +from ragna.core._rag import SpecialChatParams +from ragna.deploy import Config + +from . import _schemas as schemas +from ._database import Database + + +class Engine: + def __init__(self, *, config: Config, ignore_unavailable_components: bool) -> None: + self._config = config + + self._database = Database(url=config.database_url) + + self._rag: Rag = Rag( + config=config, + ignore_unavailable_components=ignore_unavailable_components, + ) + + self._to_core = SchemaToCoreConverter(config=config, rag=self._rag) + self._to_schema = CoreToSchemaConverter() + + def _get_component_json_schema( + self, + component: Type[core.Component], + ) -> dict[str, dict[str, Any]]: + json_schema = component._protocol_model().model_json_schema() + # FIXME: there is likely a better way to exclude certain fields builtin in + # pydantic + for special_param in SpecialChatParams.model_fields: + if ( + "properties" in json_schema + and special_param in json_schema["properties"] + ): + del json_schema["properties"][special_param] + if "required" in json_schema and special_param in json_schema["required"]: + json_schema["required"].remove(special_param) + return json_schema + + def get_components(self) -> schemas.Components: + return schemas.Components( + documents=sorted(self._config.document.supported_suffixes()), + source_storages=[ + self._get_component_json_schema(source_storage) + for source_storage in self._rag._components.keys() + if issubclass(source_storage, core.SourceStorage) + ], + 
assistants=[ + self._get_component_json_schema(assistant) + for assistant in self._rag._components.keys() + if issubclass(assistant, core.Assistant) + ], + ) + + def create_chat( + self, *, user: str, chat_metadata: schemas.ChatMetadata + ) -> schemas.Chat: + chat = schemas.Chat(metadata=chat_metadata) + + # Although we don't need the actual core.Chat here, this just performs the input + # validation. + self._to_core.chat(chat, user=user) + + with self._database.get_session() as session: + self._database.add_chat(session, user=user, chat=chat) + + return chat + + def get_chats(self, *, user: str) -> list[schemas.Chat]: + with self._database.get_session() as session: + return self._database.get_chats(session, user=user) + + def get_chat(self, *, user: str, id: uuid.UUID) -> schemas.Chat: + with self._database.get_session() as session: + return self._database.get_chat(session, user=user, id=id) + + async def prepare_chat(self, *, user: str, id: uuid.UUID) -> schemas.Message: + core_chat = self._to_core.chat(self.get_chat(user=user, id=id), user=user) + core_message = await core_chat.prepare() + + with self._database.get_session() as session: + self._database.update_chat( + session, chat=self._to_schema.chat(core_chat), user=user + ) + + return self._to_schema.message(core_message) + + async def answer_stream( + self, *, user: str, chat_id: uuid.UUID, prompt: str + ) -> AsyncIterator[schemas.Message]: + core_chat = self._to_core.chat(self.get_chat(user=user, id=chat_id), user=user) + core_message = await core_chat.answer(prompt, stream=True) + + content_stream = aiter(core_message) + content_chunk = await anext(content_stream) + message = self._to_schema.message(core_message, content_override=content_chunk) + yield message + + # Avoid sending the sources multiple times + message_chunk = message.model_copy(update=dict(sources=None)) + async for content_chunk in content_stream: + message_chunk.content = content_chunk + yield message_chunk + + with 
self._database.get_session() as session: + self._database.update_chat( + session, chat=self._to_schema.chat(core_chat), user=user + ) + + def delete_chat(self, *, user: str, id: uuid.UUID) -> None: + with self._database.get_session() as session: + self._database.delete_chat(session, user=user, id=id) + + +class SchemaToCoreConverter: + def __init__(self, *, config: Config, rag: Rag) -> None: + self._config = config + self._rag = rag + + def document(self, document: schemas.Document) -> core.Document: + return self._config.document( + id=document.id, + name=document.name, + metadata=document.metadata, + ) + + def source(self, source: schemas.Source) -> core.Source: + return core.Source( + id=source.id, + document=self.document(source.document), + location=source.location, + content=source.content, + num_tokens=source.num_tokens, + ) + + def message(self, message: schemas.Message) -> core.Message: + return core.Message( + message.content, + role=message.role, + sources=[self.source(source) for source in message.sources], + ) + + def chat(self, chat: schemas.Chat, *, user: str) -> core.Chat: + core_chat = self._rag.chat( + documents=[self.document(document) for document in chat.metadata.documents], + source_storage=chat.metadata.source_storage, + assistant=chat.metadata.assistant, + user=user, + chat_id=chat.id, + chat_name=chat.metadata.name, + **chat.metadata.params, + ) + core_chat._messages = [self.message(message) for message in chat.messages] + core_chat._prepared = chat.prepared + + return core_chat + + +class CoreToSchemaConverter: + def document(self, document: core.Document) -> schemas.Document: + return schemas.Document( + id=document.id, + name=document.name, + metadata=document.metadata, + ) + + def source(self, source: core.Source) -> schemas.Source: + return schemas.Source( + id=source.id, + document=self.document(source.document), + location=source.location, + content=source.content, + num_tokens=source.num_tokens, + ) + + def message( + self, message: 
core.Message, *, content_override: Optional[str] = None + ) -> schemas.Message: + return schemas.Message( + id=message.id, + content=content_override or message.content, + role=message.role, + sources=[self.source(source) for source in message.sources], + timestamp=message.timestamp, + ) + + def chat(self, chat: core.Chat) -> schemas.Chat: + params = chat.params.copy() + del params["user"] + return schemas.Chat( + id=params.pop("chat_id"), + metadata=schemas.ChatMetadata( + name=params.pop("chat_name"), + source_storage=chat.source_storage.display_name(), + assistant=chat.assistant.display_name(), + params=params, + documents=[self.document(document) for document in chat.documents], + ), + messages=[self.message(message) for message in chat._messages], + prepared=chat._prepared, + ) diff --git a/ragna/deploy/_api/orm.py b/ragna/deploy/_orm.py similarity index 100% rename from ragna/deploy/_api/orm.py rename to ragna/deploy/_orm.py diff --git a/ragna/deploy/_api/schemas.py b/ragna/deploy/_schemas.py similarity index 62% rename from ragna/deploy/_api/schemas.py rename to ragna/deploy/_schemas.py index 53957a74..55ae333f 100644 --- a/ragna/deploy/_api/schemas.py +++ b/ragna/deploy/_schemas.py @@ -18,13 +18,7 @@ class Components(BaseModel): class Document(BaseModel): id: uuid.UUID = Field(default_factory=uuid.uuid4) name: str - - @classmethod - def from_core(cls, document: ragna.core.Document) -> Document: - return cls( - id=document.id, - name=document.name, - ) + metadata: dict[str, Any] = Field(default_factory=dict) class DocumentUpload(BaseModel): @@ -40,16 +34,6 @@ class Source(BaseModel): content: str num_tokens: int - @classmethod - def from_core(cls, source: ragna.core.Source) -> Source: - return cls( - id=source.id, - document=Document.from_core(source.document), - location=source.location, - content=source.content, - num_tokens=source.num_tokens, - ) - class Message(BaseModel): id: uuid.UUID = Field(default_factory=uuid.uuid4) @@ -58,14 +42,6 @@ class 
Message(BaseModel): sources: list[Source] = Field(default_factory=list) timestamp: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) - @classmethod - def from_core(cls, message: ragna.core.Message) -> Message: - return cls( - content=message.content, - role=message.role, - sources=[Source.from_core(source) for source in message.sources], - ) - class ChatMetadata(BaseModel): name: str diff --git a/tests/deploy/api/conftest.py b/tests/deploy/api/conftest.py new file mode 100644 index 00000000..4bc8053c --- /dev/null +++ b/tests/deploy/api/conftest.py @@ -0,0 +1,41 @@ +import contextlib +import json + +import httpx +import pytest + + +@pytest.fixture(scope="package", autouse=True) +def enhance_raise_for_status(package_mocker): + raise_for_status = httpx.Response.raise_for_status + + def enhanced_raise_for_status(self): + __tracebackhide__ = True + + try: + return raise_for_status(self) + except httpx.HTTPStatusError as error: + content = None + with contextlib.suppress(Exception): + content = error.response.read() + content = content.decode() + content = "\n" + json.dumps(json.loads(content), indent=2) + + if content is None: + raise error + + message = f"{error}\nResponse content: {content}" + raise httpx.HTTPStatusError( + message, request=error.request, response=error.response + ) from None + + yield package_mocker.patch( + ".".join( + [ + httpx.Response.__module__, + httpx.Response.__name__, + raise_for_status.__name__, + ] + ), + new=enhanced_raise_for_status, + ) From 40f0b6c7bc0381ca185f413047f7cb4ffbd77718 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 9 Jul 2024 17:09:54 +0200 Subject: [PATCH 05/29] refactor document registering and upload (#441) --- ragna/core/__init__.py | 1 - ragna/core/_document.py | 105 +++++++------------- ragna/core/_rag.py | 22 ++--- ragna/deploy/_api.py | 109 ++++++--------------- ragna/deploy/_core.py | 5 +- ragna/deploy/_database.py | 119 ++++++++--------------- ragna/deploy/_engine.py | 95 
++++++++++++++---- ragna/deploy/_schemas.py | 22 +++-- scripts/add_chats.py | 95 +++++++++--------- tests/deploy/api/test_batch_endpoints.py | 86 ---------------- tests/deploy/api/test_components.py | 19 +--- tests/deploy/api/test_e2e.py | 31 +++--- 12 files changed, 271 insertions(+), 438 deletions(-) delete mode 100644 tests/deploy/api/test_batch_endpoints.py diff --git a/ragna/core/__init__.py b/ragna/core/__init__.py index 0f4b4bdf..44449775 100644 --- a/ragna/core/__init__.py +++ b/ragna/core/__init__.py @@ -34,7 +34,6 @@ from ._document import ( Document, DocumentHandler, - DocumentUploadParameters, DocxDocumentHandler, LocalDocument, Page, diff --git a/ragna/core/_document.py b/ragna/core/_document.py index 436344b4..7a1cef7f 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -2,26 +2,17 @@ import abc import io -import os -import secrets -import time import uuid +from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterator, Optional, Type, TypeVar, Union +from typing import Any, AsyncIterator, Iterator, Optional, Type, TypeVar, Union -import jwt +import aiofiles from pydantic import BaseModel -from ._utils import PackageRequirement, RagnaException, Requirement, RequirementsMixin - -if TYPE_CHECKING: - from ragna.deploy import Config +import ragna - -class DocumentUploadParameters(BaseModel): - method: str - url: str - data: dict +from ._utils import PackageRequirement, RagnaException, Requirement, RequirementsMixin class Document(RequirementsMixin, abc.ABC): @@ -62,16 +53,6 @@ def get_handler(name: str) -> DocumentHandler: return handler - @classmethod - @abc.abstractmethod - async def get_upload_info( - cls, *, config: Config, user: str, id: uuid.UUID, name: str - ) -> tuple[dict[str, Any], DocumentUploadParameters]: - pass - - @abc.abstractmethod - def is_readable(self) -> bool: ... - @abc.abstractmethod def read(self) -> bytes: ... 
@@ -88,12 +69,25 @@ class LocalDocument(Document): [ragna.core.LocalDocument.from_path][]. """ + def __init__( + self, + *, + id: Optional[uuid.UUID] = None, + name: str, + metadata: dict[str, Any], + handler: Optional[DocumentHandler] = None, + ): + super().__init__(id=id, name=name, metadata=metadata, handler=handler) + if "path" not in self.metadata: + metadata["path"] = str(ragna.local_root() / "documents" / str(self.id)) + @classmethod def from_path( cls, path: Union[str, Path], *, id: Optional[uuid.UUID] = None, + name: Optional[str] = None, metadata: Optional[dict[str, Any]] = None, handler: Optional[DocumentHandler] = None, ) -> LocalDocument: @@ -102,6 +96,7 @@ def from_path( Args: path: Local path to the file. id: ID of the document. If omitted, one is generated. + name: Name of the document. If omitted, defaults to the name of the `path`. metadata: Optional metadata of the document. handler: Document handler. If omitted, a builtin handler is selected based on the suffix of the `path`. 
@@ -118,60 +113,34 @@ def from_path( ) path = Path(path).expanduser().resolve() + if name is None: + name = path.name metadata["path"] = str(path) - return cls(id=id, name=path.name, metadata=metadata, handler=handler) + return cls(id=id, name=name, metadata=metadata, handler=handler) - @property + @cached_property def path(self) -> Path: return Path(self.metadata["path"]) - def is_readable(self) -> bool: - return self.path.exists() - - def read(self) -> bytes: - with open(self.path, "rb") as stream: - return stream.read() - - _JWT_SECRET = os.environ.get( - "RAGNA_API_DOCUMENT_UPLOAD_SECRET", secrets.token_urlsafe(32)[:32] - ) - _JWT_ALGORITHM = "HS256" - - @classmethod - async def get_upload_info( - cls, *, config: Config, user: str, id: uuid.UUID, name: str - ) -> tuple[dict[str, Any], DocumentUploadParameters]: - url = f"{config._url}/api/document" - data = { - "token": jwt.encode( - payload={ - "user": user, - "id": str(id), - "exp": time.time() + 5 * 60, - }, - key=cls._JWT_SECRET, - algorithm=cls._JWT_ALGORITHM, - ) - } - metadata = {"path": str(config.local_root / "documents" / str(id))} - return metadata, DocumentUploadParameters(method="PUT", url=url, data=data) - - @classmethod - def decode_upload_token(cls, token: str) -> tuple[str, uuid.UUID]: - try: - payload = jwt.decode( - token, key=cls._JWT_SECRET, algorithms=[cls._JWT_ALGORITHM] - ) - except jwt.InvalidSignatureError: + async def _write(self, stream: AsyncIterator[bytes]) -> None: + if self.path.exists(): raise RagnaException( - "Token invalid", http_status_code=401, http_detail=RagnaException.EVENT + "File already exists", path=self.path, http_detail=RagnaException.EVENT ) - except jwt.ExpiredSignatureError: + + async with aiofiles.open(self.path, "wb") as file: + async for content in stream: + await file.write(content) + + def read(self) -> bytes: + if not self.path.is_file(): raise RagnaException( - "Token expired", http_status_code=401, http_detail=RagnaException.EVENT + "File does not 
exist", path=self.path, http_detail=RagnaException.EVENT ) - return payload["user"], uuid.UUID(payload["id"]) + + with open(self.path, "rb") as file: + return file.read() class Page(BaseModel): diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index 98c32a42..c3da0c76 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -286,20 +286,14 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message: return answer def _parse_documents(self, documents: Iterable[Any]) -> list[Document]: - documents_ = [] - for document in documents: - if not isinstance(document, Document): - document = LocalDocument.from_path(document) - - if not document.is_readable(): - raise RagnaException( - "Document not readable", - document=document, - http_status_code=404, - ) - - documents_.append(document) - return documents_ + return [ + ( + document + if isinstance(document, Document) + else LocalDocument.from_path(document) + ) + for document in documents + ] def _unpack_chat_params( self, params: dict[str, Any] diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py index d3194064..4de2737c 100644 --- a/ragna/deploy/_api.py +++ b/ragna/deploy/_api.py @@ -1,29 +1,23 @@ import uuid -from typing import Annotated, AsyncIterator, cast +from typing import Annotated, AsyncIterator -import aiofiles import pydantic from fastapi import ( APIRouter, Body, Depends, - Form, - HTTPException, UploadFile, ) from fastapi.responses import StreamingResponse -import ragna -import ragna.core from ragna._compat import anext from ragna.core._utils import default_user -from ragna.deploy import Config from . 
import _schemas as schemas from ._engine import Engine -def make_router(config: Config, engine: Engine) -> APIRouter: +def make_router(engine: Engine) -> APIRouter: router = APIRouter(tags=["API"]) def get_user() -> str: @@ -31,77 +25,32 @@ def get_user() -> str: UserDependency = Annotated[str, Depends(get_user)] - # TODO: the document endpoints do not go through the engine, because they'll change - # quite drastically when the UI no longer depends on the API - - _database = engine._database - - @router.post("/document") - async def create_document_upload_info( - user: UserDependency, - name: Annotated[str, Body(..., embed=True)], - ) -> schemas.DocumentUpload: - with _database.get_session() as session: - document = schemas.Document(name=name) - metadata, parameters = await config.document.get_upload_info( - config=config, user=user, id=document.id, name=document.name - ) - document.metadata = metadata - _database.add_document( - session, user=user, document=document, metadata=metadata - ) - return schemas.DocumentUpload(parameters=parameters, document=document) - - # TODO: Add UI support and documentation for this endpoint (#406) @router.post("/documents") - async def create_documents_upload_info( - user: UserDependency, - names: Annotated[list[str], Body(..., embed=True)], - ) -> list[schemas.DocumentUpload]: - with _database.get_session() as session: - document_metadata_collection = [] - document_upload_collection = [] - for name in names: - document = schemas.Document(name=name) - metadata, parameters = await config.document.get_upload_info( - config=config, user=user, id=document.id, name=document.name - ) - document.metadata = metadata - document_metadata_collection.append((document, metadata)) - document_upload_collection.append( - schemas.DocumentUpload(parameters=parameters, document=document) - ) - - _database.add_documents( - session, - user=user, - document_metadata_collection=document_metadata_collection, - ) - return document_upload_collection - - # 
TODO: Add new endpoint for batch uploading documents (#407) - @router.put("/document") - async def upload_document( - token: Annotated[str, Form()], file: UploadFile - ) -> schemas.Document: - if not issubclass(config.document, ragna.core.LocalDocument): - raise HTTPException( - status_code=400, - detail="Ragna configuration does not support local upload", - ) - with _database.get_session() as session: - user, id = ragna.core.LocalDocument.decode_upload_token(token) - document = _database.get_document(session, user=user, id=id) - - core_document = cast( - ragna.core.LocalDocument, engine._to_core.document(document) - ) - core_document.path.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.open(core_document.path, "wb") as document_file: - while content := await file.read(1024): - await document_file.write(content) - - return document + def register_documents( + user: UserDependency, document_registrations: list[schemas.DocumentRegistration] + ) -> list[schemas.Document]: + return engine.register_documents( + user=user, document_registrations=document_registrations + ) + + @router.put("/documents") + async def upload_documents( + user: UserDependency, documents: list[UploadFile] + ) -> None: + def make_content_stream(file: UploadFile) -> AsyncIterator[bytes]: + async def content_stream() -> AsyncIterator[bytes]: + while content := await file.read(16 * 1024): + yield content + + return content_stream() + + await engine.store_documents( + user=user, + ids_and_streams=[ + (uuid.UUID(document.filename), make_content_stream(document)) + for document in documents + ], + ) @router.get("/components") def get_components(_: UserDependency) -> schemas.Components: @@ -110,9 +59,9 @@ def get_components(_: UserDependency) -> schemas.Components: @router.post("/chats") async def create_chat( user: UserDependency, - chat_metadata: schemas.ChatMetadata, + chat_creation: schemas.ChatCreation, ) -> schemas.Chat: - return engine.create_chat(user=user, 
chat_metadata=chat_metadata) + return engine.create_chat(user=user, chat_creation=chat_creation) @router.get("/chats") async def get_chats(user: UserDependency) -> list[schemas.Chat]: diff --git a/ragna/deploy/_core.py b/ragna/deploy/_core.py index 67f067e0..6df4b71b 100644 --- a/ragna/deploy/_core.py +++ b/ragna/deploy/_core.py @@ -27,7 +27,6 @@ def make_app( ignore_unavailable_components: bool, open_browser: bool, ) -> FastAPI: - ragna.local_root(config.local_root) set_redirect_root_path(config.root_path) lifespan: Optional[Callable[[FastAPI], AsyncContextManager]] @@ -38,7 +37,7 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: def target() -> None: client = httpx.Client(base_url=config._url) - def server_available(): + def server_available() -> bool: try: return client.get("/health").is_success except httpx.ConnectError: @@ -77,7 +76,7 @@ def server_available(): ) if api: - app.include_router(make_api_router(config, engine), prefix="/api") + app.include_router(make_api_router(engine), prefix="/api") if ui: panel_app = make_ui_app(config=config) diff --git a/ragna/deploy/_database.py b/ragna/deploy/_database.py index 323ccd21..529fa3b6 100644 --- a/ragna/deploy/_database.py +++ b/ragna/deploy/_database.py @@ -1,7 +1,7 @@ from __future__ import annotations import uuid -from typing import Any, Optional +from typing import Any, Collection, Optional from urllib.parse import urlsplit from sqlalchemy import create_engine, select @@ -42,83 +42,50 @@ def _get_user(self, session: Session, *, username: str) -> orm.User: return user - def add_document( - self, - session: Session, - *, - user: str, - document: schemas.Document, - metadata: dict[str, Any], - ) -> None: - session.add( - orm.Document( - id=document.id, - user_id=self._get_user(session, username=user).id, - name=document.name, - metadata_=metadata, - ) - ) - session.commit() - def add_documents( self, session: Session, *, user: str, - document_metadata_collection: list[tuple[schemas.Document, 
dict[str, Any]]], + documents: list[schemas.Document], ) -> None: - """ - Add multiple documents to the database. - - This function allows adding multiple documents at once by calling `add_all`. This is - important when there is non-negligible latency attached to each database operation. - """ - documents = [ - orm.Document( - id=document.id, - user_id=self._get_user(session, username=user).id, - name=document.name, - metadata_=metadata, - ) - for document, metadata in document_metadata_collection - ] - session.add_all(documents) + user_id = self._get_user(session, username=user).id + session.add_all( + [self._to_orm.document(document, user_id=user_id) for document in documents] + ) session.commit() - def get_document( - self, session: Session, *, user: str, id: uuid.UUID - ) -> schemas.Document: - document = session.execute( - select(orm.Document).where( - (orm.Document.user_id == self._get_user(session, username=user).id) - & (orm.Document.id == id) - ) - ).scalar_one_or_none() - return self._to_schema.document(document) - - def add_chat(self, session: Session, *, user: str, chat: schemas.Chat) -> None: - document_ids = {document.id for document in chat.metadata.documents} - # FIXME also check if the user is allowed to access the documents? + def _get_orm_documents( + self, session: Session, *, user: str, ids: Collection[uuid.UUID] + ) -> list[orm.Document]: + # FIXME also check if the user is allowed to access the documents + # FIXME: maybe just take the user id to avoid getting it twice in add_chat? 
documents = ( - session.execute( - select(orm.Document).where(orm.Document.id.in_(document_ids)) - ) + session.execute(select(orm.Document).where(orm.Document.id.in_(ids))) .scalars() .all() ) - if len(documents) != len(document_ids): + if len(documents) != len(ids): raise RagnaException( - str(document_ids - {document.id for document in documents}) + str(set(ids) - {document.id for document in documents}) ) + return documents # type: ignore[no-any-return] + + def get_documents( + self, session: Session, *, user: str, ids: Collection[uuid.UUID] + ) -> list[schemas.Document]: + return [ + self._to_schema.document(document) + for document in self._get_orm_documents(session, user=user, ids=ids) + ] + + def add_chat(self, session: Session, *, user: str, chat: schemas.Chat) -> None: orm_chat = self._to_orm.chat( - chat, - user_id=self._get_user(session, username=user).id, - # We have to pass the documents here, because SQLAlchemy does not allow a - # second instance of orm.Document with the same primary key in the session. 
- documents=documents, + chat, user_id=self._get_user(session, username=user).id ) - session.add(orm_chat) + # We need to merge and not add here, because the documents are already in the DB + session.merge(orm_chat) session.commit() def _select_chat(self, *, eager: bool = False) -> Any: @@ -213,21 +180,17 @@ def chat( chat: schemas.Chat, *, user_id: uuid.UUID, - documents: Optional[list[orm.Document]] = None, ) -> orm.Chat: - if documents is None: - documents = [ - self.document(document, user_id=user_id) - for document in chat.metadata.documents - ] return orm.Chat( id=chat.id, user_id=user_id, - name=chat.metadata.name, - documents=documents, - source_storage=chat.metadata.source_storage, - assistant=chat.metadata.assistant, - params=chat.metadata.params, + name=chat.name, + documents=[ + self.document(document, user_id=user_id) for document in chat.documents + ], + source_storage=chat.source_storage, + assistant=chat.assistant, + params=chat.params, messages=[ self.message(message, chat_id=chat.id) for message in chat.messages ], @@ -262,13 +225,11 @@ def message(self, message: orm.Message) -> schemas.Message: def chat(self, chat: orm.Chat) -> schemas.Chat: return schemas.Chat( id=chat.id, - metadata=schemas.ChatMetadata( - name=chat.name, - documents=[self.document(document) for document in chat.documents], - source_storage=chat.source_storage, - assistant=chat.assistant, - params=chat.params, - ), + name=chat.name, + documents=[self.document(document) for document in chat.documents], + source_storage=chat.source_storage, + assistant=chat.assistant, + params=chat.params, messages=[self.message(message) for message in chat.messages], prepared=chat.prepared, ) diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 847f7a93..2209a61f 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -1,8 +1,13 @@ import uuid -from typing import Any, AsyncIterator, Optional, Type +from typing import Any, AsyncIterator, Optional, Type, cast +from 
fastapi import status as http_status_code + +import ragna from ragna import Rag, core from ragna._compat import aiter, anext +from ragna._utils import make_directory +from ragna.core import RagnaException from ragna.core._rag import SpecialChatParams from ragna.deploy import Config @@ -13,15 +18,20 @@ class Engine: def __init__(self, *, config: Config, ignore_unavailable_components: bool) -> None: self._config = config + ragna.local_root(config.local_root) + self._documents_root = make_directory(config.local_root / "documents") + self.supports_store_documents = issubclass( + self._config.document, ragna.core.LocalDocument + ) self._database = Database(url=config.database_url) - self._rag: Rag = Rag( + self._rag = Rag( # type: ignore[var-annotated] config=config, ignore_unavailable_components=ignore_unavailable_components, ) - self._to_core = SchemaToCoreConverter(config=config, rag=self._rag) + self._to_core = SchemaToCoreConverter(config=self._config, rag=self._rag) self._to_schema = CoreToSchemaConverter() def _get_component_json_schema( @@ -56,12 +66,59 @@ def get_components(self) -> schemas.Components: ], ) + def register_documents( + self, *, user: str, document_registrations: list[schemas.DocumentRegistration] + ) -> list[schemas.Document]: + # We create core.Document's first, because they might update the metadata + core_documents = [ + self._config.document( + name=registration.name, metadata=registration.metadata + ) + for registration in document_registrations + ] + documents = [self._to_schema.document(document) for document in core_documents] + + with self._database.get_session() as session: + self._database.add_documents(session, user=user, documents=documents) + + return documents + + async def store_documents( + self, + *, + user: str, + ids_and_streams: list[tuple[uuid.UUID, AsyncIterator[bytes]]], + ) -> None: + if not self.supports_store_documents: + raise RagnaException( + "Ragna configuration does not support local upload", + 
http_status_code=http_status_code.HTTP_400_BAD_REQUEST, + ) + + ids, streams = zip(*ids_and_streams) + + with self._database.get_session() as session: + documents = self._database.get_documents(session, user=user, ids=ids) + + for document, stream in zip(documents, streams): + core_document = cast( + ragna.core.LocalDocument, self._to_core.document(document) + ) + await core_document._write(stream) + def create_chat( - self, *, user: str, chat_metadata: schemas.ChatMetadata + self, *, user: str, chat_creation: schemas.ChatCreation ) -> schemas.Chat: - chat = schemas.Chat(metadata=chat_metadata) + params = chat_creation.model_dump() + document_ids = params.pop("document_ids") + with self._database.get_session() as session: + documents = self._database.get_documents( + session, user=user, ids=document_ids + ) + + chat = schemas.Chat(documents=documents, **params) - # Although we don't need the actual core.Chat here, this just performs the input + # Although we don't need the actual core.Chat here, this performs the input # validation. 
self._to_core.chat(chat, user=user) @@ -146,13 +203,13 @@ def message(self, message: schemas.Message) -> core.Message: def chat(self, chat: schemas.Chat, *, user: str) -> core.Chat: core_chat = self._rag.chat( - documents=[self.document(document) for document in chat.metadata.documents], - source_storage=chat.metadata.source_storage, - assistant=chat.metadata.assistant, user=user, chat_id=chat.id, - chat_name=chat.metadata.name, - **chat.metadata.params, + chat_name=chat.name, + documents=[self.document(document) for document in chat.documents], + source_storage=chat.source_storage, + assistant=chat.assistant, + **chat.params, ) core_chat._messages = [self.message(message) for message in chat.messages] core_chat._prepared = chat.prepared @@ -182,7 +239,9 @@ def message( ) -> schemas.Message: return schemas.Message( id=message.id, - content=content_override or message.content, + content=( + content_override if content_override is not None else message.content + ), role=message.role, sources=[self.source(source) for source in message.sources], timestamp=message.timestamp, @@ -193,13 +252,11 @@ def chat(self, chat: core.Chat) -> schemas.Chat: del params["user"] return schemas.Chat( id=params.pop("chat_id"), - metadata=schemas.ChatMetadata( - name=params.pop("chat_name"), - source_storage=chat.source_storage.display_name(), - assistant=chat.assistant.display_name(), - params=params, - documents=[self.document(document) for document in chat.documents], - ), + name=params.pop("chat_name"), + documents=[self.document(document) for document in chat.documents], + source_storage=chat.source_storage.display_name(), + assistant=chat.assistant.display_name(), + params=params, messages=[self.message(message) for message in chat._messages], prepared=chat._prepared, ) diff --git a/ragna/deploy/_schemas.py b/ragna/deploy/_schemas.py index 55ae333f..cc5490b7 100644 --- a/ragna/deploy/_schemas.py +++ b/ragna/deploy/_schemas.py @@ -15,15 +15,15 @@ class Components(BaseModel): 
assistants: list[dict[str, Any]] -class Document(BaseModel): - id: uuid.UUID = Field(default_factory=uuid.uuid4) +class DocumentRegistration(BaseModel): name: str metadata: dict[str, Any] = Field(default_factory=dict) -class DocumentUpload(BaseModel): - parameters: ragna.core.DocumentUploadParameters - document: Document +class Document(BaseModel): + id: uuid.UUID = Field(default_factory=uuid.uuid4) + name: str + metadata: dict[str, Any] class Source(BaseModel): @@ -43,16 +43,20 @@ class Message(BaseModel): timestamp: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) -class ChatMetadata(BaseModel): +class ChatCreation(BaseModel): name: str + document_ids: list[uuid.UUID] source_storage: str assistant: str - params: dict - documents: list[Document] + params: dict[str, Any] = Field(default_factory=dict) class Chat(BaseModel): id: uuid.UUID = Field(default_factory=uuid.uuid4) - metadata: ChatMetadata + name: str + documents: list[Document] + source_storage: str + assistant: str + params: dict[str, Any] messages: list[Message] = Field(default_factory=list) prepared: bool = False diff --git a/scripts/add_chats.py b/scripts/add_chats.py index b8c15194..5f550289 100644 --- a/scripts/add_chats.py +++ b/scripts/add_chats.py @@ -1,71 +1,70 @@ import datetime import json -import os import httpx -from ragna.core._utils import default_user - def main(): client = httpx.Client(base_url="http://127.0.0.1:31476") - client.get("/").raise_for_status() + client.get("/health").raise_for_status() + + # ## authentication + # + # username = default_user() + # token = ( + # client.post( + # "/token", + # data={ + # "username": username, + # "password": os.environ.get( + # "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username + # ), + # }, + # ) + # .raise_for_status() + # .json() + # ) + # client.headers["Authorization"] = f"Bearer {token}" + + print() - ## authentication + ## documents - username = default_user() - token = ( + documents = ( client.post( - "/token", - data={ - 
"username": username, - "password": os.environ.get( - "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username - ), - }, + "/api/documents", json=[{"name": f"document{i}.txt"} for i in range(5)] ) .raise_for_status() .json() ) - client.headers["Authorization"] = f"Bearer {token}" - ## documents - - documents = [] - for i in range(5): - name = f"document{i}.txt" - document_upload = ( - client.post("/document", json={"name": name}).raise_for_status().json() - ) - parameters = document_upload["parameters"] - client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": f"Content of {name}".encode()}, - ).raise_for_status() - documents.append(document_upload["document"]) + client.put( + "/api/documents", + files=[ + ("documents", (document["id"], f"Content of {document['name']}".encode())) + for document in documents + ], + ).raise_for_status() ## chat 1 chat = ( client.post( - "/chats", + "/api/chats", json={ "name": "Test chat", - "documents": documents[:2], + "document_ids": [document["id"] for document in documents[:2]], "source_storage": "Ragna/DemoSourceStorage", "assistant": "Ragna/DemoAssistant", - "params": {}, }, ) .raise_for_status() .json() ) - client.post(f"/chats/{chat['id']}/prepare").raise_for_status() + client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "Hello!"}, ).raise_for_status() @@ -73,55 +72,53 @@ def main(): chat = ( client.post( - "/chats", + "/api/chats", json={ "name": f"Chat {datetime.datetime.now():%x %X}", - "documents": documents[2:4], + "document_ids": [document["id"] for document in documents[2:]], "source_storage": "Ragna/DemoSourceStorage", "assistant": "Ragna/DemoAssistant", - "params": {}, }, ) .raise_for_status() .json() ) - client.post(f"/chats/{chat['id']}/prepare").raise_for_status() + client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() for _ in range(3): client.post( 
- f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna? Please, I need to know!"}, ).raise_for_status() - ## chat 3 + # ## chat 3 chat = ( client.post( - "/chats", + "/api/chats", json={ "name": ( "Really long chat name that likely needs to be truncated somehow. " "If you can read this, truncating failed :boom:" ), - "documents": [documents[i] for i in [0, 2, 4]], + "document_ids": [documents[i]["id"] for i in [0, 2, 4]], "source_storage": "Ragna/DemoSourceStorage", "assistant": "Ragna/DemoAssistant", - "params": {}, }, ) .raise_for_status() .json() ) - client.post(f"/chats/{chat['id']}/prepare").raise_for_status() + client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "Hello!"}, ).raise_for_status() client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "Ok, in that case show me some pretty markdown!"}, ).raise_for_status() - chats = client.get("/chats").raise_for_status().json() + chats = client.get("/api/chats").raise_for_status().json() print(json.dumps(chats)) diff --git a/tests/deploy/api/test_batch_endpoints.py b/tests/deploy/api/test_batch_endpoints.py deleted file mode 100644 index 2736df24..00000000 --- a/tests/deploy/api/test_batch_endpoints.py +++ /dev/null @@ -1,86 +0,0 @@ -from fastapi import status -from fastapi.testclient import TestClient - -from ragna.deploy import Config - -from .utils import authenticate, make_api_app - - -def test_batch_sequential_upload_equivalence(tmp_local_root): - "Check that uploading documents sequentially and in batch gives the same result" - config = Config(local_root=tmp_local_root) - - document_root = config.local_root / "documents" - document_root.mkdir() - document_path1 = document_root / "test1.txt" - with open(document_path1, "w") as file: - file.write("!\n") - document_path2 = document_root / "test2.txt" - with 
open(document_path2, "w") as file: - file.write("?\n") - - with TestClient( - make_api_app(config=Config(), ignore_unavailable_components=False) - ) as client: - authenticate(client) - - document1_upload = ( - client.post("/api/document", json={"name": document_path1.name}) - .raise_for_status() - .json() - ) - document2_upload = ( - client.post("/api/document", json={"name": document_path2.name}) - .raise_for_status() - .json() - ) - - documents_upload = ( - client.post( - "/api/documents", - json={"names": [document_path1.name, document_path2.name]}, - ) - .raise_for_status() - .json() - ) - - assert ( - document1_upload["parameters"]["url"] - == documents_upload[0]["parameters"]["url"] - ) - assert ( - document2_upload["parameters"]["url"] - == documents_upload[1]["parameters"]["url"] - ) - - assert ( - document1_upload["document"]["name"] - == documents_upload[0]["document"]["name"] - ) - assert ( - document2_upload["document"]["name"] - == documents_upload[1]["document"]["name"] - ) - - # assuming that if test passes for first document it will also pass for the other - with open(document_path1, "rb") as file: - response_sequential_upload1 = client.request( - document1_upload["parameters"]["method"], - document1_upload["parameters"]["url"], - data=document1_upload["parameters"]["data"], - files={"file": file}, - ) - response_batch_upload1 = client.request( - documents_upload[0]["parameters"]["method"], - documents_upload[0]["parameters"]["url"], - data=documents_upload[0]["parameters"]["data"], - files={"file": file}, - ) - - assert response_sequential_upload1.status_code == status.HTTP_200_OK - assert response_batch_upload1.status_code == status.HTTP_200_OK - - assert ( - response_sequential_upload1.json()["name"] - == response_batch_upload1.json()["name"] - ) diff --git a/tests/deploy/api/test_components.py b/tests/deploy/api/test_components.py index b459f12e..0d44790c 100644 --- a/tests/deploy/api/test_components.py +++ b/tests/deploy/api/test_components.py 
@@ -67,31 +67,22 @@ def test_unknown_component(tmp_local_root): ) as client: authenticate(client) - document_upload = ( - client.post("/api/document", json={"name": document_path.name}) + document = ( + client.post("/api/documents", json=[{"name": document_path.name}]) .raise_for_status() - .json() + .json()[0] ) - document = document_upload["document"] - assert document["name"] == document_path.name - parameters = document_upload["parameters"] with open(document_path, "rb") as file: - client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": file}, - ) + client.put("/api/documents", files={"documents": (document["id"], file)}) response = client.post( "/api/chats", json={ "name": "test-chat", + "document_ids": [document["id"]], "source_storage": "unknown_source_storage", "assistant": "unknown_assistant", - "params": {}, - "documents": [document], }, ) diff --git a/tests/deploy/api/test_e2e.py b/tests/deploy/api/test_e2e.py index c1a80ad5..61632251 100644 --- a/tests/deploy/api/test_e2e.py +++ b/tests/deploy/api/test_e2e.py @@ -43,26 +43,21 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): assert client.get("/api/chats").raise_for_status().json() == [] - document_upload = ( - client.post("/api/document", json={"name": document_path.name}) + documents = ( + client.post("/api/documents", json=[{"name": document_path.name}]) .raise_for_status() .json() ) - document = document_upload["document"] + assert len(documents) == 1 + document = documents[0] assert document["name"] == document_path.name - parameters = document_upload["parameters"] with open(document_path, "rb") as file: - client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": file}, - ) + client.put("/api/documents", files={"documents": (document["id"], file)}) components = client.get("/api/components").raise_for_status().json() - documents = components["documents"] - assert set(documents) == 
config.document.supported_suffixes() + supported_documents = components["documents"] + assert set(supported_documents) == config.document.supported_suffixes() source_storages = [ json_schema["title"] for json_schema in components["source_storages"] ] @@ -77,15 +72,19 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): source_storage = source_storages[0] assistant = assistants[0] - chat_metadata = { + chat_creation = { "name": "test-chat", + "document_ids": [document["id"]], "source_storage": source_storage, "assistant": assistant, "params": {"multiple_answer_chunks": multiple_answer_chunks}, - "documents": [document], } - chat = client.post("/api/chats", json=chat_metadata).raise_for_status().json() - assert chat["metadata"] == chat_metadata + chat = client.post("/api/chats", json=chat_creation).raise_for_status().json() + for field in ["name", "source_storage", "assistant", "params"]: + assert chat[field] == chat_creation[field] + assert [document["id"] for document in chat["documents"]] == chat_creation[ + "document_ids" + ] assert not chat["prepared"] assert chat["messages"] == [] From 974862a8b3bb540d1848a3a6cf0559b266a22d41 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 11 Jul 2024 10:55:25 +0200 Subject: [PATCH 06/29] use backend engine in UI (#443) --- ragna/_cli/core.py | 22 +-- ragna/deploy/_core.py | 2 +- ragna/deploy/_ui/api_wrapper.py | 67 +++---- ragna/deploy/_ui/app.py | 14 +- ragna/deploy/_ui/central_view.py | 35 ++-- ragna/deploy/_ui/left_sidebar.py | 2 +- ragna/deploy/_ui/modal_configuration.py | 44 ++++- tests/deploy/ui/test_ui.py | 252 +++++++++++------------- 8 files changed, 213 insertions(+), 225 deletions(-) diff --git a/ragna/_cli/core.py b/ragna/_cli/core.py index 2e9030df..64d89e25 100644 --- a/ragna/_cli/core.py +++ b/ragna/_cli/core.py @@ -1,7 +1,6 @@ from pathlib import Path from typing import Annotated, Optional -import httpx import rich import typer import uvicorn @@ -74,13 +73,12 @@ def deploy( *, config: 
ConfigOption = "./ragna.toml", # type: ignore[assignment] api: Annotated[ - Optional[bool], + bool, typer.Option( "--api/--no-api", help="Deploy the Ragna REST API.", - show_default="True if UI is not deployed and otherwise check availability", ), - ] = None, + ] = True, ui: Annotated[ bool, typer.Option( @@ -98,22 +96,14 @@ def deploy( ] = False, open_browser: Annotated[ Optional[bool], - typer.Option(help="Open a browser when Ragna is deployed."), + typer.Option( + help="Open a browser when Ragna is deployed.", + show_default="value of ui / no-ui", + ), ] = None, ) -> None: - def api_available() -> bool: - try: - return httpx.get(f"{config._url}/health").is_success - except httpx.ConnectError: - return False - - if api is None: - api = not api_available() if ui else True - if not (api or ui): raise Exception - elif ui and not api and not api_available(): - raise Exception if open_browser is None: open_browser = ui diff --git a/ragna/deploy/_core.py b/ragna/deploy/_core.py index 6df4b71b..44c672a8 100644 --- a/ragna/deploy/_core.py +++ b/ragna/deploy/_core.py @@ -79,7 +79,7 @@ def server_available() -> bool: app.include_router(make_api_router(engine), prefix="/api") if ui: - panel_app = make_ui_app(config=config) + panel_app = make_ui_app(engine) panel_app.serve_with_fastapi(app, endpoint="/ui") @app.get("/", include_in_schema=False) diff --git a/ragna/deploy/_ui/api_wrapper.py b/ragna/deploy/_ui/api_wrapper.py index 5fb7b42d..170e8bbd 100644 --- a/ragna/deploy/_ui/api_wrapper.py +++ b/ragna/deploy/_ui/api_wrapper.py @@ -1,62 +1,53 @@ -import json +import uuid from datetime import datetime import emoji -import httpx import param +from ragna.core._utils import default_user +from ragna.deploy import _schemas as schemas +from ragna.deploy._engine import Engine -# The goal is this class is to provide ready-to-use functions to interact with the API -class ApiWrapper(param.Parameterized): - def __init__(self, api_url, **params): - self.api_url = api_url - self.client = 
httpx.AsyncClient(base_url=api_url, timeout=60) - super().__init__(**params) +class ApiWrapper(param.Parameterized): + def __init__(self, engine: Engine): + super().__init__() + self._user = default_user() + self._engine = engine async def get_chats(self): - json_data = (await self.client.get("/chats")).raise_for_status().json() + json_data = [ + chat.model_dump(mode="json") + for chat in self._engine.get_chats(user=self._user) + ] for chat in json_data: chat["messages"] = [self.improve_message(msg) for msg in chat["messages"]] return json_data async def answer(self, chat_id, prompt): - async with self.client.stream( - "POST", - f"/chats/{chat_id}/answer", - json={"prompt": prompt, "stream": True}, - ) as response: - async for data in response.aiter_lines(): - yield self.improve_message(json.loads(data)) + async for message in self._engine.answer_stream( + user=self._user, chat_id=uuid.UUID(chat_id), prompt=prompt + ): + yield self.improve_message(message.model_dump(mode="json")) async def get_components(self): - return (await self.client.get("/components")).raise_for_status().json() - - # Upload and related functions - def upload_endpoints(self): - return { - "informations_endpoint": f"{self.api_url}/document", - } + return self._engine.get_components().model_dump(mode="json") async def start_and_prepare( self, name, documents, source_storage, assistant, params ): - response = await self.client.post( - "/chats", - json={ - "name": name, - "documents": documents, - "source_storage": source_storage, - "assistant": assistant, - "params": params, - }, + chat = self._engine.create_chat( + user=self._user, + chat_creation=schemas.ChatCreation( + name=name, + document_ids=[document.id for document in documents], + source_storage=source_storage, + assistant=assistant, + params=params, + ), ) - chat = response.raise_for_status().json() - - response = await self.client.post(f"/chats/{chat['id']}/prepare", timeout=None) - response.raise_for_status() - - return chat["id"] + 
await self._engine.prepare_chat(user=self._user, id=chat.id) + return str(chat.id) def improve_message(self, msg): msg["timestamp"] = datetime.strptime(msg["timestamp"], "%Y-%m-%dT%H:%M:%S.%f") diff --git a/ragna/deploy/_ui/app.py b/ragna/deploy/_ui/app.py index 49b8d628..052ff36d 100644 --- a/ragna/deploy/_ui/app.py +++ b/ragna/deploy/_ui/app.py @@ -6,7 +6,7 @@ from fastapi import FastAPI from fastapi.staticfiles import StaticFiles -from ragna.deploy import Config +from ragna.deploy._engine import Engine from . import js from . import styles as ui @@ -22,13 +22,13 @@ class App(param.Parameterized): - def __init__(self, *, api_url): + def __init__(self, engine: Engine): super().__init__() # Apply the design modifiers to the panel components # It returns all the CSS files of the modifiers self.css_filepaths = ui.apply_design_modifiers() - self.api_url = api_url + self._engine = engine def get_template(self): # A bit hacky, but works. @@ -73,7 +73,7 @@ def get_template(self): return template def index_page(self): - api_wrapper = ApiWrapper(api_url=self.api_url) + api_wrapper = ApiWrapper(self._engine) template = self.get_template() main_page = MainPage(api_wrapper=api_wrapper, template=template) @@ -131,7 +131,7 @@ def get_component_resource(path: str): def serve_with_fastapi(self, app: FastAPI, endpoint: str): self.add_panel_app(app, self.index_page, endpoint) - for dir in ["css", "imgs", "resources"]: + for dir in ["css", "imgs"]: app.mount( f"/{dir}", StaticFiles(directory=str(Path(__file__).parent / dir)), @@ -139,5 +139,5 @@ def serve_with_fastapi(self, app: FastAPI, endpoint: str): ) -def app(*, config: Config) -> App: - return App(api_url=f"{config._url}/api") +def app(engine: Engine) -> App: + return App(engine) diff --git a/ragna/deploy/_ui/central_view.py b/ragna/deploy/_ui/central_view.py index ae0cefc3..00de3e8e 100644 --- a/ragna/deploy/_ui/central_view.py +++ b/ragna/deploy/_ui/central_view.py @@ -189,11 +189,11 @@ def on_click_chat_info_wrapper(self, 
event): pills = "".join( [ f"""
{d['name']}
""" - for d in self.current_chat["metadata"]["documents"] + for d in self.current_chat["documents"] ] ) - grid_height = len(self.current_chat["metadata"]["documents"]) // 3 + grid_height = len(self.current_chat["documents"]) // 3 markdown = "\n".join( [ @@ -202,14 +202,14 @@ def on_click_chat_info_wrapper(self, event): f"
{pills}

\n\n", "----", "**Source Storage**", - f"""{self.current_chat['metadata']['source_storage']}\n""", + f"""{self.current_chat['source_storage']}\n""", "----", "**Assistant**", - f"""{self.current_chat['metadata']['assistant']}\n""", + f"""{self.current_chat['assistant']}\n""", "**Advanced configuration**", *[ f"- **{key.replace('_', ' ').title()}**: {value}" - for key, value in self.current_chat["metadata"]["params"].items() + for key, value in self.current_chat["params"].items() ], ] ) @@ -275,7 +275,7 @@ def get_user_from_role(self, role: Literal["system", "user", "assistant"]) -> st elif role == "user": return cast(str, self.user) elif role == "assistant": - return cast(str, self.current_chat["metadata"]["assistant"]) + return cast(str, self.current_chat["assistant"]) else: raise RuntimeError @@ -301,12 +301,15 @@ async def chat_callback( message.clipboard_button.value = message.content_pane.object message.assistant_toolbar.visible = True - except Exception: + except Exception as exc: + import traceback + yield RagnaChatMessage( - ( - "Sorry, something went wrong. " - "If this problem persists, please contact your administrator." - ), + # ( + # "Sorry, something went wrong. " + # "If this problem persists, please contact your administrator." + # ), + "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)), role="system", user=self.get_user_from_role("system"), ) @@ -358,7 +361,7 @@ def header(self): current_chat_name = "" if self.current_chat is not None: - current_chat_name = self.current_chat["metadata"]["name"] + current_chat_name = self.current_chat["name"] chat_name_header = pn.pane.HTML( f"

{current_chat_name}

", @@ -370,9 +373,9 @@ def header(self): if ( self.current_chat is not None and "metadata" in self.current_chat - and "documents" in self.current_chat["metadata"] + and "documents" in self.current_chat ): - doc_names = [d["name"] for d in self.current_chat["metadata"]["documents"]] + doc_names = [d["name"] for d in self.current_chat["documents"]] # FIXME: Instead of setting a hard limit of 20 documents here, this should # scale automatically with the width of page @@ -385,7 +388,9 @@ def header(self): chat_documents_pills.append(pill) - self.chat_info_button.name = f"{self.current_chat['metadata']['assistant']} | {self.current_chat['metadata']['source_storage']}" + self.chat_info_button.name = ( + f"{self.current_chat['assistant']} | {self.current_chat['source_storage']}" + ) return pn.Row( chat_name_header, diff --git a/ragna/deploy/_ui/left_sidebar.py b/ragna/deploy/_ui/left_sidebar.py index ab8bc1c0..21c747aa 100644 --- a/ragna/deploy/_ui/left_sidebar.py +++ b/ragna/deploy/_ui/left_sidebar.py @@ -62,7 +62,7 @@ def __panel__(self): self.chat_buttons = [] for chat in self.chats: button = pn.widgets.Button( - name=chat["metadata"]["name"], + name=chat["name"], css_classes=["chat_button"], ) button.on_click(lambda event, c=chat: self.on_click_chat_wrapper(event, c)) diff --git a/ragna/deploy/_ui/modal_configuration.py b/ragna/deploy/_ui/modal_configuration.py index 51ec02fb..a8bb44e7 100644 --- a/ragna/deploy/_ui/modal_configuration.py +++ b/ragna/deploy/_ui/modal_configuration.py @@ -1,11 +1,13 @@ from datetime import datetime, timedelta, timezone +from typing import AsyncIterator import panel as pn import param +from ragna.deploy import _schemas as schemas + from . import js from . 
import styles as ui -from .components.file_uploader import FileUploader def get_default_chat_name(timezone_offset=None): @@ -82,15 +84,11 @@ def __init__(self, api_wrapper, **params): self.api_wrapper = api_wrapper - upload_endpoints = self.api_wrapper.upload_endpoints() - self.chat_name_input = pn.widgets.TextInput.from_param( self.param.chat_name, ) - self.document_uploader = FileUploader( - [], # the allowed documents are set in the model_section function - upload_endpoints["informations_endpoint"], - ) + # FIXME: accept + self.document_uploader = pn.widgets.FileInput(multiple=True) # Most widgets (including those that use from_param) should be placed after the super init call self.cancel_button = pn.widgets.Button( @@ -114,12 +112,38 @@ def __init__(self, api_wrapper, **params): self.got_timezone = False - def did_click_on_start_chat_button(self, event): - if not self.document_uploader.can_proceed_to_upload(): + async def did_click_on_start_chat_button(self, event): + if not self.document_uploader.value: self.change_upload_files_label("missing_file") else: self.start_chat_button.disabled = True - self.document_uploader.perform_upload(event, self.did_finish_upload) + documents = self.api_wrapper._engine.register_documents( + user=self.api_wrapper._user, + document_registrations=[ + schemas.DocumentRegistration(name=name) + for name in self.document_uploader.filename + ], + ) + + if self.api_wrapper._engine.supports_store_documents: + + def make_content_stream(data: bytes) -> AsyncIterator[bytes]: + async def content_stream() -> AsyncIterator[bytes]: + yield data + + return content_stream() + + await self.api_wrapper._engine.store_documents( + user=self.api_wrapper._user, + ids_and_streams=[ + (document.id, make_content_stream(data)) + for document, data in zip( + documents, self.document_uploader.value + ) + ], + ) + + await self.did_finish_upload(documents) async def did_finish_upload(self, uploaded_documents): # at this point, the UI has uploaded the files to 
the API. diff --git a/tests/deploy/ui/test_ui.py b/tests/deploy/ui/test_ui.py index 85699278..e815011b 100644 --- a/tests/deploy/ui/test_ui.py +++ b/tests/deploy/ui/test_ui.py @@ -1,14 +1,13 @@ +import contextlib +import multiprocessing import socket -import subprocess -import sys import time import httpx -import panel as pn import pytest -from playwright.sync_api import Page, expect +from playwright.sync_api import expect -from ragna._utils import timeout_after +from ragna._cli.core import deploy as _deploy from ragna.deploy import Config from tests.deploy.utils import TestAssistant @@ -19,139 +18,118 @@ def get_available_port(): return s.getsockname()[1] +@contextlib.contextmanager +def deploy(config): + process = multiprocessing.Process( + target=_deploy, + kwargs=dict( + config=config, + api=False, + ui=True, + ignore_unavailable_components=False, + open_browser=False, + ), + ) + try: + process.start() + + client = httpx.Client(base_url=config._url) + + # FIXME: create a generic utility for this + def server_available() -> bool: + try: + return client.get("/health").is_success + except httpx.ConnectError: + return False + + while not server_available(): + time.sleep(0.1) + + yield process + finally: + process.terminate() + process.join() + process.close() + + @pytest.fixture -def config( - tmp_local_root, -): - config = Config( +def default_config(tmp_local_root): + return Config( local_root=tmp_local_root, assistants=[TestAssistant], - ui=dict(port=get_available_port()), - api=dict(port=get_available_port()), + port=get_available_port(), ) - path = tmp_local_root / "ragna.toml" - config.to_file(path) - return config - - -class Server: - def __init__(self, config): - self.config = config - self.base_url = f"http://{config.ui.hostname}:{config.ui.port}" - - def server_up(self): - try: - return httpx.get(self.base_url).is_success - except httpx.ConnectError: - return False - - @timeout_after(60) - def start(self): - self.proc = subprocess.Popen( - [ - 
sys.executable, - "-m", - "ragna", - "ui", - "--config", - self.config.local_root / "ragna.toml", - "--start-api", - "--ignore-unavailable-components", - "--no-open-browser", - ], - stdout=sys.stdout, - stderr=sys.stderr, - ) - - while not self.server_up(): - time.sleep(1) - - def stop(self): - self.proc.kill() - pn.state.kill_all_servers() - - def __enter__(self): - self.start() - return self - - def __exit__(self, *args): - self.stop() - - -def test_health(config, page: Page) -> None: - with Server(config) as server: - health_url = f"{server.base_url}/health" - response = page.goto(health_url) - assert response.ok - - -def test_start_chat(config, page: Page) -> None: - with Server(config) as server: - # Index page, no auth - index_url = server.base_url - page.goto(index_url) - expect(page.get_by_role("button", name="Sign In")).to_be_visible() - - # Authorize with no credentials - page.get_by_role("button", name="Sign In").click() - expect(page.get_by_role("button", name=" New Chat")).to_be_visible() - - # expect auth token to be set - cookies = page.context.cookies() - assert len(cookies) == 1 - cookie = cookies[0] - assert cookie.get("name") == "auth_token" - auth_token = cookie.get("value") - assert auth_token is not None - - # New page button - new_chat_button = page.get_by_role("button", name=" New Chat") - expect(new_chat_button).to_be_visible() - new_chat_button.click() - - document_root = config.local_root / "documents" - document_root.mkdir() - document_name = "test.txt" - document_path = document_root / document_name - with open(document_path, "w") as file: - file.write("!\n") - - # File upload selector - with page.expect_file_chooser() as fc_info: - page.locator(".fileUpload").click() - file_chooser = fc_info.value - file_chooser.set_files(document_path) - - # Upload document and expect to see it listed - file_list = page.locator(".fileListContainer") - expect(file_list.first).to_have_text(str(document_name)) - - chat_dialog = 
page.get_by_role("dialog") - expect(chat_dialog).to_be_visible() - start_chat_button = page.get_by_role("button", name="Start Conversation") - expect(start_chat_button).to_be_visible() - time.sleep(0.5) # hack while waiting for button to be fully clickable - start_chat_button.click(delay=5) - - chat_box_row = page.locator(".chat-interface-input-row") - expect(chat_box_row).to_be_visible() - - chat_box = chat_box_row.get_by_role("textbox") - expect(chat_box).to_be_visible() - - # Document should be in the database - chats_url = f"http://{config.api.hostname}:{config.api.port}/chats" - chats = httpx.get( - chats_url, headers={"Authorization": f"Bearer {auth_token}"} - ).json() - assert len(chats) == 1 - chat = chats[0] - chat_documents = chat["metadata"]["documents"] - assert len(chat_documents) == 1 - assert chat_documents[0]["name"] == document_name - - chat_box.fill("Tell me about the documents") - - chat_button = chat_box_row.get_by_role("button") - expect(chat_button).to_be_visible() - chat_button.click() + + +@pytest.fixture +def index_page(default_config, page): + config = default_config + with deploy(default_config): + page.goto(f"http://{config.hostname}:{config.port}/ui") + yield page + + +def test_start_chat(index_page, tmp_path) -> None: + # expect(page.get_by_role("button", name="Sign In")).to_be_visible() + + # # Authorize with no credentials + # page.get_by_role("button", name="Sign In").click() + # expect(page.get_by_role("button", name=" New Chat")).to_be_visible() + # + # # expect auth token to be set + # cookies = page.context.cookies() + # assert len(cookies) == 1 + # cookie = cookies[0] + # assert cookie.get("name") == "auth_token" + # auth_token = cookie.get("value") + # assert auth_token is not None + + # New page button + new_chat_button = index_page.get_by_role("button", name=" New Chat") + expect(new_chat_button).to_be_visible() + new_chat_button.click() + + # document_name = "test.txt" + # document_path = tmp_path / document_name + # with 
open(document_path, "w") as file: + # file.write("!\n") + + # # File upload selector + # with index_page.expect_file_chooser() as fc_info: + # index_page.locator(".fileUpload").click() + # file_chooser = fc_info.value + # file_chooser.set_files(document_path) + + # # Upload document and expect to see it listed + # file_list = page.locator(".fileListContainer") + # expect(file_list.first).to_have_text(str(document_name)) + # + # chat_dialog = page.get_by_role("dialog") + # expect(chat_dialog).to_be_visible() + # start_chat_button = page.get_by_role("button", name="Start Conversation") + # expect(start_chat_button).to_be_visible() + # time.sleep(0.5) # hack while waiting for button to be fully clickable + # start_chat_button.click(delay=5) + # + # chat_box_row = page.locator(".chat-interface-input-row") + # expect(chat_box_row).to_be_visible() + # + # chat_box = chat_box_row.get_by_role("textbox") + # expect(chat_box).to_be_visible() + # + # # Document should be in the database + # chats_url = f"http://{config.api.hostname}:{config.api.port}/chats" + # chats = httpx.get( + # chats_url, headers={"Authorization": f"Bearer {auth_token}"} + # ).json() + # assert len(chats) == 1 + # chat = chats[0] + # chat_documents = chat["metadata"]["documents"] + # assert len(chat_documents) == 1 + # assert chat_documents[0]["name"] == document_name + # + # chat_box.fill("Tell me about the documents") + # + # chat_button = chat_box_row.get_by_role("button") + # expect(chat_button).to_be_visible() + # chat_button.click() From f41aa30c0cc02a857a4c0683e8286828141709b6 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 11 Dec 2024 10:18:39 +0100 Subject: [PATCH 07/29] use bokeh_fastapi through panel (#503) Co-authored-by: Kim Pevey --- .github/actions/setup-env/action.yml | 8 - .github/workflows/test.yml | 127 ++++---- pyproject.toml | 9 +- ragna/core/_components.py | 6 +- ragna/deploy/_core.py | 13 +- ragna/deploy/_engine.py | 2 +- ragna/deploy/_ui/api_wrapper.py | 4 +- 
ragna/deploy/_ui/app.py | 60 ---- ragna/deploy/_ui/central_view.py | 6 +- .../{card.css => chatinterface.css} | 0 .../_ui/css/modal_configuration/fileinput.css | 5 + ragna/deploy/_ui/modal_configuration.py | 16 +- ragna/deploy/_ui/styles.py | 3 +- requirements-docker.lock | 290 +++++++++--------- 14 files changed, 241 insertions(+), 308 deletions(-) rename ragna/deploy/_ui/css/chat_interface/{card.css => chatinterface.css} (100%) create mode 100644 ragna/deploy/_ui/css/modal_configuration/fileinput.css diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index 795d52bd..b39f4b18 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -16,7 +16,6 @@ runs: - name: Setup mambaforge and development environment uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest activate-environment: ragna-deploy-dev @@ -57,13 +56,6 @@ runs: shell: bash -el {0} run: playwright install - - name: Install dev dependencies - shell: bash -el {0} - run: | - pip install \ - git+https://github.com/bokeh/bokeh-fastapi.git@main \ - git+https://github.com/holoviz/panel@7377c9e99bef0d32cbc65e94e908e365211f4421 - - name: Install ragna shell: bash -el {0} run: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 611f5e13..1fd5080f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,9 +38,7 @@ jobs: matrix: os: - ubuntu-latest - # FIXME - # Building panel from source on Windows does not work through pip - # - windows-latest + - windows-latest - macos-latest python-version: ["3.10"] include: @@ -81,65 +79,64 @@ jobs: uses: pmeier/pytest-results-action@v0.3.0 with: path: test-results.xml - - pytest-ui: - strategy: - matrix: - os: - - ubuntu-latest - - windows-latest - - macos-latest - browser: - - chromium - - firefox - python-version: - - "3.10" - - "3.10" - - "3.12" - exclude: - - python-version: "3.11" - os: windows-latest 
- - python-version: "3.12" - os: windows-latest - - python-version: "3.11" - os: macos-latest - - python-version: "3.12" - os: macos-latest - include: - - browser: webkit - os: macos-latest - python-version: "3.10" - - fail-fast: false - - runs-on: ${{ matrix.os }} - - defaults: - run: - shell: bash -el {0} - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Setup environment - uses: ./.github/actions/setup-env - with: - python-version: ${{ matrix.python-version }} - - - name: Run unit tests - id: tests - run: | - pytest tests/deploy/ui \ - --browser ${{ matrix.browser }} \ - --video=retain-on-failure - - - name: Upload playwright video - if: failure() - uses: actions/upload-artifact@v4 - with: - name: - playwright-${{ matrix.os }}-${{ matrix.python-version}}-${{ github.run_id }} - path: test-results +# pytest-ui: +# strategy: +# matrix: +# os: +# - ubuntu-latest +# - windows-latest +# - macos-latest +# browser: +# - chromium +# - firefox +# python-version: +# - "3.10" +# - "3.10" +# - "3.12" +# exclude: +# - python-version: "3.11" +# os: windows-latest +# - python-version: "3.12" +# os: windows-latest +# - python-version: "3.11" +# os: macos-latest +# - python-version: "3.12" +# os: macos-latest +# include: +# - browser: webkit +# os: macos-latest +# python-version: "3.10" +# +# fail-fast: false +# +# runs-on: ${{ matrix.os }} +# +# defaults: +# run: +# shell: bash -el {0} +# +# steps: +# - name: Checkout repository +# uses: actions/checkout@v4 +# with: +# fetch-depth: 0 +# +# - name: Setup environment +# uses: ./.github/actions/setup-env +# with: +# python-version: ${{ matrix.python-version }} +# +# - name: Run unit tests +# id: tests +# run: | +# pytest tests/deploy/ui \ +# --browser ${{ matrix.browser }} \ +# --video=retain-on-failure +# +# - name: Upload playwright video +# if: failure() +# uses: actions/upload-artifact@v4 +# with: +# name: +# playwright-${{ matrix.os }}-${{ matrix.python-version}}-${{ 
github.run_id }} +# path: test-results diff --git a/pyproject.toml b/pyproject.toml index f44c823f..67ffe501 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,21 +15,22 @@ authors = [ readme = "README.md" classifiers = [ "Development Status :: 4 - Beta", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] requires-python = ">=3.10" dependencies = [ "aiofiles", + # Remove this and instead depend on panel[fastapi] + # after https://github.com/holoviz/panel/pull/7495 is released + "bokeh_fastapi==0.1.1", "emoji", "eval_type_backport; python_version<'3.10'", "fastapi", "httpx", "packaging", - # FIXME: pin them to released versions - "bokeh-fastapi", - "panel", + "panel==1.5.4", "pydantic>=2", "pydantic-core", "pydantic-settings>=2", diff --git a/ragna/core/_components.py b/ragna/core/_components.py index ecec015b..a4d3e519 100644 --- a/ragna/core/_components.py +++ b/ragna/core/_components.py @@ -1,11 +1,11 @@ from __future__ import annotations import abc -import datetime import enum import functools import inspect import uuid +from datetime import datetime, timezone from typing import ( AsyncIterable, AsyncIterator, @@ -185,7 +185,7 @@ def __init__( role: MessageRole = MessageRole.SYSTEM, sources: Optional[list[Source]] = None, id: Optional[uuid.UUID] = None, - timestamp: Optional[datetime.datetime] = None, + timestamp: Optional[datetime] = None, ) -> None: if isinstance(content, str): self._content: str = content @@ -200,7 +200,7 @@ def __init__( self.id = id if timestamp is None: - timestamp = datetime.datetime.utcnow() + timestamp = datetime.now(timezone.utc) self.timestamp = timestamp async def __aiter__(self) -> AsyncIterator[str]: diff --git a/ragna/deploy/_core.py b/ragna/deploy/_core.py index 44c672a8..65cca3cd 100644 --- a/ragna/deploy/_core.py +++ b/ragna/deploy/_core.py @@ -2,12 +2,15 @@ import threading import time import 
webbrowser +from pathlib import Path from typing import AsyncContextManager, AsyncIterator, Callable, Optional, cast import httpx +import panel.io.fastapi from fastapi import FastAPI, Request, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response +from fastapi.staticfiles import StaticFiles import ragna from ragna.core import RagnaException @@ -79,8 +82,14 @@ def server_available() -> bool: app.include_router(make_api_router(engine), prefix="/api") if ui: - panel_app = make_ui_app(engine) - panel_app.serve_with_fastapi(app, endpoint="/ui") + ui_app = make_ui_app(engine) + panel.io.fastapi.add_applications({"/ui": ui_app.index_page}, app=app) + for dir in ["css", "imgs"]: + app.mount( + f"/{dir}", + StaticFiles(directory=str(Path(__file__).parent / "_ui" / dir)), + name=dir, + ) @app.get("/", include_in_schema=False) async def base_redirect() -> Response: diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 631ba5f0..6df48460 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -40,7 +40,7 @@ def _get_component_json_schema( json_schema = component._protocol_model().model_json_schema() # FIXME: there is likely a better way to exclude certain fields builtin in # pydantic - for special_param in SpecialChatParams.model_fields: + for special_param in SpecialChatParams.__pydantic_fields__: if ( "properties" in json_schema and special_param in json_schema["properties"] diff --git a/ragna/deploy/_ui/api_wrapper.py b/ragna/deploy/_ui/api_wrapper.py index 6ee2f8a3..9dcc6b16 100644 --- a/ragna/deploy/_ui/api_wrapper.py +++ b/ragna/deploy/_ui/api_wrapper.py @@ -30,8 +30,8 @@ async def answer(self, chat_id, prompt): ): yield self.improve_message(message.model_dump(mode="json")) - async def get_components(self): - return self._engine.get_components().model_dump(mode="json") + def get_components(self): + return self._engine.get_components() async def start_and_prepare( self, name, documents, 
source_storage, assistant, params diff --git a/ragna/deploy/_ui/app.py b/ragna/deploy/_ui/app.py index 052ff36d..4163c378 100644 --- a/ragna/deploy/_ui/app.py +++ b/ragna/deploy/_ui/app.py @@ -1,10 +1,5 @@ -from pathlib import Path -from typing import cast - import panel as pn import param -from fastapi import FastAPI -from fastapi.staticfiles import StaticFiles from ragna.deploy._engine import Engine @@ -83,61 +78,6 @@ def index_page(self): def health_page(self): return pn.pane.HTML("

Ok

") - def add_panel_app(self, server, panel_app_fn, endpoint): - # FIXME: this code will ultimately be distributed as part of panel - from functools import partial - - import panel as pn - from bokeh.application import Application - from bokeh.application.handlers.function import FunctionHandler - from bokeh_fastapi import BokehFastAPI - from bokeh_fastapi.handler import WSHandler - from fastapi.responses import FileResponse - from panel.io.document import extra_socket_handlers - from panel.io.resources import COMPONENT_PATH - from panel.io.server import ComponentResourceHandler - from panel.io.state import set_curdoc - - def dispatch_fastapi(conn, events=None, msg=None): - if msg is None: - msg = conn.protocol.create("PATCH-DOC", events) - return [conn._socket.send_message(msg)] - - extra_socket_handlers[WSHandler] = dispatch_fastapi - - def panel_app(doc): - doc.on_event("document_ready", partial(pn.state._schedule_on_load, doc)) - - with set_curdoc(doc): - panel_app = panel_app_fn() - panel_app.server_doc(doc) - - handler = FunctionHandler(panel_app) - application = Application(handler) - - BokehFastAPI({endpoint: application}, server=server) - - @server.get( - f"/{COMPONENT_PATH.rstrip('/')}" + "/{path:path}", include_in_schema=False - ) - def get_component_resource(path: str): - # ComponentResourceHandler.parse_url_path only ever accesses - # self._resource_attrs, which fortunately is a class attribute. 
Thus, we can - # get away with using the method without actually instantiating the class - self_ = cast(ComponentResourceHandler, ComponentResourceHandler) - resolved_path = ComponentResourceHandler.parse_url_path(self_, path) - return FileResponse(resolved_path) - - def serve_with_fastapi(self, app: FastAPI, endpoint: str): - self.add_panel_app(app, self.index_page, endpoint) - - for dir in ["css", "imgs"]: - app.mount( - f"/{dir}", - StaticFiles(directory=str(Path(__file__).parent / dir)), - name=dir, - ) - def app(engine: Engine) -> App: return App(engine) diff --git a/ragna/deploy/_ui/central_view.py b/ragna/deploy/_ui/central_view.py index 644e3a06..81ac9072 100644 --- a/ragna/deploy/_ui/central_view.py +++ b/ragna/deploy/_ui/central_view.py @@ -368,11 +368,7 @@ def header(self): ) chat_documents_pills = [] - if ( - self.current_chat is not None - and "metadata" in self.current_chat - and "documents" in self.current_chat - ): + if self.current_chat is not None: doc_names = [d["name"] for d in self.current_chat["documents"]] # FIXME: Instead of setting a hard limit of 20 documents here, this should diff --git a/ragna/deploy/_ui/css/chat_interface/card.css b/ragna/deploy/_ui/css/chat_interface/chatinterface.css similarity index 100% rename from ragna/deploy/_ui/css/chat_interface/card.css rename to ragna/deploy/_ui/css/chat_interface/chatinterface.css diff --git a/ragna/deploy/_ui/css/modal_configuration/fileinput.css b/ragna/deploy/_ui/css/modal_configuration/fileinput.css new file mode 100644 index 00000000..6807a6c1 --- /dev/null +++ b/ragna/deploy/_ui/css/modal_configuration/fileinput.css @@ -0,0 +1,5 @@ +:host(.file-input) .bk-input { + height: 80px; + border: var(--accent-color) dashed 2px; + border-radius: 10px; +} diff --git a/ragna/deploy/_ui/modal_configuration.py b/ragna/deploy/_ui/modal_configuration.py index a8bb44e7..53ecb756 100644 --- a/ragna/deploy/_ui/modal_configuration.py +++ b/ragna/deploy/_ui/modal_configuration.py @@ -87,8 +87,11 @@ def 
__init__(self, api_wrapper, **params): self.chat_name_input = pn.widgets.TextInput.from_param( self.param.chat_name, ) - # FIXME: accept - self.document_uploader = pn.widgets.FileInput(multiple=True) + self.document_uploader = pn.widgets.FileInput( + multiple=True, + css_classes=["file-input"], + accept=",".join(self.api_wrapper.get_components().documents), + ) # Most widgets (including those that use from_param) should be placed after the super init call self.cancel_button = pn.widgets.Button( @@ -182,19 +185,19 @@ async def model_section(self): # prevents re-rendering the section if self.config is None: # Retrieve the components from the API and build a config object - components = await self.api_wrapper.get_components() + components = self.api_wrapper.get_components() # TODO : use the components to set up the default values for the various params config = ChatConfig() - config.allowed_documents = components["documents"] + config.allowed_documents = components.documents - assistants = [component["title"] for component in components["assistants"]] + assistants = [assistant["title"] for assistant in components.assistants] config.param.assistant_name.objects = assistants config.assistant_name = assistants[0] source_storages = [ - component["title"] for component in components["source_storages"] + source_storage["title"] for source_storage in components.source_storages ] config.param.source_storage_name.objects = source_storages config.source_storage_name = source_storages[0] @@ -202,7 +205,6 @@ async def model_section(self): # Now that the config object is set, we can assign it to the param. 
# This will trigger the update of the advanced_config_ui section self.config = config - self.document_uploader.allowed_documents = config.allowed_documents return pn.Row( pn.Column( diff --git a/ragna/deploy/_ui/styles.py b/ragna/deploy/_ui/styles.py index 793b4ca6..44bb1c4e 100644 --- a/ragna/deploy/_ui/styles.py +++ b/ragna/deploy/_ui/styles.py @@ -37,10 +37,10 @@ "central_view": [pn.Column, pn.Row, pn.pane.HTML], "chat_interface": [ pn.widgets.TextInput, - pn.layout.Card, pn.pane.Markdown, pn.widgets.button.Button, pn.Column, + pn.chat.ChatInterface, ], "right_sidebar": [pn.widgets.Button, pn.Column, pn.pane.Markdown], "left_sidebar": [pn.widgets.Button, pn.pane.HTML, pn.Column], @@ -50,6 +50,7 @@ pn.layout.Card, pn.Row, pn.widgets.Button, + pn.widgets.FileInput, ], } diff --git a/requirements-docker.lock b/requirements-docker.lock index 0010d3c0..14264c8e 100644 --- a/requirements-docker.lock +++ b/requirements-docker.lock @@ -4,85 +4,82 @@ # # pip-compile --extra=all --output-file=requirements-docker.lock --strip-extras pyproject.toml # -aiofiles==23.2.1 +aiofiles==24.1.0 # via Ragna (pyproject.toml) -annotated-types==0.6.0 +annotated-types==0.7.0 # via pydantic -anyio==4.2.0 +anyio==4.7.0 # via # httpx # starlette # watchfiles -asgiref==3.7.2 +asgiref==3.8.1 # via opentelemetry-instrumentation-asgi -attrs==23.2.0 - # via lancedb backoff==2.2.1 - # via - # opentelemetry-exporter-otlp-proto-common - # opentelemetry-exporter-otlp-proto-grpc - # posthog -bcrypt==4.1.2 + # via posthog +bcrypt==4.2.1 # via chromadb -bleach==6.1.0 +bleach==6.2.0 # via panel -bokeh==3.4.1 - # via panel -build==1.0.3 - # via chromadb -cachetools==5.3.2 +bokeh==3.6.2 # via - # google-auth - # lancedb -certifi==2023.11.17 + # bokeh-fastapi + # panel +bokeh-fastapi==0.1.1 + # via Ragna (pyproject.toml) +build==1.2.2.post1 + # via chromadb +cachetools==5.5.0 + # via google-auth +certifi==2024.8.30 # via # httpcore # httpx # kubernetes - # pulsar-client # requests 
-charset-normalizer==3.3.2 +charset-normalizer==3.4.0 # via requests -chroma-hnswlib==0.7.3 +chroma-hnswlib==0.7.6 # via chromadb -chromadb==0.4.22 +chromadb==0.5.23 # via Ragna (pyproject.toml) click==8.1.7 # via - # lancedb # typer # uvicorn coloredlogs==15.0.1 # via onnxruntime -contourpy==1.2.0 +contourpy==1.3.1 # via bokeh -decorator==5.1.1 - # via retry -deprecated==1.2.14 +deprecated==1.2.15 # via # opentelemetry-api # opentelemetry-exporter-otlp-proto-grpc + # opentelemetry-semantic-conventions deprecation==2.1.0 # via lancedb -emoji==2.9.0 +durationpy==0.9 + # via kubernetes +emoji==2.14.0 # via Ragna (pyproject.toml) -fastapi==0.109.0 +fastapi==0.115.6 # via # Ragna (pyproject.toml) + # bokeh-fastapi # chromadb -filelock==3.13.1 +filelock==3.16.1 # via huggingface-hub -flatbuffers==23.5.26 +flatbuffers==24.3.25 # via onnxruntime -fsspec==2023.12.2 +fsspec==2024.10.0 # via huggingface-hub -google-auth==2.26.2 +google-auth==2.36.0 # via kubernetes -googleapis-common-protos==1.62.0 +googleapis-common-protos==1.66.0 # via opentelemetry-exporter-otlp-proto-grpc -greenlet==3.0.3 +greenlet==3.1.1 # via sqlalchemy -grpcio==1.60.0 +grpcio==1.68.1 # via # chromadb # opentelemetry-exporter-otlp-proto-grpc @@ -90,61 +87,63 @@ h11==0.14.0 # via # httpcore # uvicorn -httpcore==1.0.2 +httpcore==1.0.7 # via httpx -httptools==0.6.1 +httptools==0.6.4 # via uvicorn -httpx==0.26.0 - # via Ragna (pyproject.toml) +httpx==0.28.1 + # via + # Ragna (pyproject.toml) + # chromadb httpx-sse==0.4.0 # via Ragna (pyproject.toml) -huggingface-hub==0.20.2 +huggingface-hub==0.26.5 # via tokenizers humanfriendly==10.0 # via coloredlogs -idna==3.6 +idna==3.10 # via # anyio # httpx # requests -ijson==3.2.3 +ijson==3.3.0 # via Ragna (pyproject.toml) -importlib-metadata==6.11.0 +importlib-metadata==8.5.0 # via opentelemetry-api -importlib-resources==6.1.1 +importlib-resources==6.4.5 # via chromadb -jinja2==3.1.3 +jinja2==3.1.4 # via bokeh -kubernetes==29.0.0 +kubernetes==31.0.0 # via chromadb 
-lancedb==0.4.4 +lancedb==0.17.0 # via Ragna (pyproject.toml) -linkify-it-py==2.0.2 +linkify-it-py==2.0.3 # via panel -lxml==5.1.0 +lxml==5.3.0 # via # python-docx # python-pptx -markdown==3.5.2 +markdown==3.7 # via panel markdown-it-py==3.0.0 # via # mdit-py-plugins # panel # rich -markupsafe==2.1.3 +markupsafe==3.0.2 # via jinja2 -mdit-py-plugins==0.4.0 +mdit-py-plugins==0.4.2 # via panel mdurl==0.1.2 # via markdown-it-py -mmh3==4.1.0 +mmh3==5.0.1 # via chromadb monotonic==1.6 # via posthog mpmath==1.3.0 # via sympy -numpy==1.26.3 +numpy==2.2.0 # via # bokeh # chroma-hnswlib @@ -152,15 +151,14 @@ numpy==1.26.3 # contourpy # onnxruntime # pandas - # pyarrow # pylance oauthlib==3.2.2 # via # kubernetes # requests-oauthlib -onnxruntime==1.19.0 +onnxruntime==1.20.1 # via chromadb -opentelemetry-api==1.22.0 +opentelemetry-api==1.28.2 # via # chromadb # opentelemetry-exporter-otlp-proto-grpc @@ -168,204 +166,200 @@ opentelemetry-api==1.22.0 # opentelemetry-instrumentation-asgi # opentelemetry-instrumentation-fastapi # opentelemetry-sdk -opentelemetry-exporter-otlp-proto-common==1.22.0 + # opentelemetry-semantic-conventions +opentelemetry-exporter-otlp-proto-common==1.28.2 # via opentelemetry-exporter-otlp-proto-grpc -opentelemetry-exporter-otlp-proto-grpc==1.22.0 +opentelemetry-exporter-otlp-proto-grpc==1.28.2 # via chromadb -opentelemetry-instrumentation==0.43b0 +opentelemetry-instrumentation==0.49b2 # via # opentelemetry-instrumentation-asgi # opentelemetry-instrumentation-fastapi -opentelemetry-instrumentation-asgi==0.43b0 +opentelemetry-instrumentation-asgi==0.49b2 # via opentelemetry-instrumentation-fastapi -opentelemetry-instrumentation-fastapi==0.43b0 +opentelemetry-instrumentation-fastapi==0.49b2 # via chromadb -opentelemetry-proto==1.22.0 +opentelemetry-proto==1.28.2 # via # opentelemetry-exporter-otlp-proto-common # opentelemetry-exporter-otlp-proto-grpc -opentelemetry-sdk==1.22.0 +opentelemetry-sdk==1.28.2 # via # chromadb # 
opentelemetry-exporter-otlp-proto-grpc -opentelemetry-semantic-conventions==0.43b0 +opentelemetry-semantic-conventions==0.49b2 # via + # opentelemetry-instrumentation # opentelemetry-instrumentation-asgi # opentelemetry-instrumentation-fastapi # opentelemetry-sdk -opentelemetry-util-http==0.43b0 +opentelemetry-util-http==0.49b2 # via # opentelemetry-instrumentation-asgi # opentelemetry-instrumentation-fastapi -overrides==7.4.0 +orjson==3.10.12 + # via chromadb +overrides==7.7.0 # via # chromadb # lancedb -packaging==23.2 +packaging==24.2 # via # Ragna (pyproject.toml) # bokeh # build # deprecation # huggingface-hub + # lancedb # onnxruntime -pandas==2.1.4 + # opentelemetry-instrumentation + # panel +pandas==2.2.3 # via # bokeh # panel -panel==1.4.4 +panel==1.5.4 # via Ragna (pyproject.toml) param==2.1.1 # via # panel # pyviz-comms -pillow==10.2.0 +pillow==11.0.0 # via # bokeh # python-pptx -posthog==3.3.1 +posthog==3.7.4 # via chromadb prompt-toolkit==3.0.36 # via questionary -protobuf==4.25.2 +protobuf==5.29.1 # via # googleapis-common-protos # onnxruntime # opentelemetry-proto -pulsar-client==3.4.0 - # via chromadb -py==1.11.0 - # via retry -pyarrow==14.0.2 +pyarrow==18.1.0 # via # Ragna (pyproject.toml) # pylance -pyasn1==0.5.1 +pyasn1==0.6.1 # via # pyasn1-modules # rsa -pyasn1-modules==0.3.0 +pyasn1-modules==0.4.1 # via google-auth -pydantic==2.5.3 +pydantic==2.10.3 # via # Ragna (pyproject.toml) # chromadb # fastapi # lancedb # pydantic-settings -pydantic-core==2.14.6 +pydantic-core==2.27.1 # via # Ragna (pyproject.toml) # pydantic -pydantic-settings==2.1.0 +pydantic-settings==2.6.1 # via Ragna (pyproject.toml) -pygments==2.17.2 +pygments==2.18.0 # via rich -pyjwt==2.8.0 +pyjwt==2.10.1 # via Ragna (pyproject.toml) -pylance==0.9.6 +pylance==0.20.0 # via lancedb -pymupdf==1.23.15 +pymupdf==1.25.0 # via Ragna (pyproject.toml) -pymupdfb==1.23.9 - # via pymupdf pypika==0.48.9 # via chromadb -pyproject-hooks==1.0.0 +pyproject-hooks==1.2.0 # via build 
-python-dateutil==2.8.2 +python-dateutil==2.9.0.post0 # via # kubernetes # pandas # posthog -python-docx==1.1.0 +python-docx==1.1.2 # via Ragna (pyproject.toml) -python-dotenv==1.0.0 +python-dotenv==1.0.1 # via # pydantic-settings # uvicorn -python-multipart==0.0.6 +python-multipart==0.0.19 # via Ragna (pyproject.toml) -python-pptx==0.6.23 +python-pptx==1.0.2 # via Ragna (pyproject.toml) -pytz==2023.3.post1 +pytz==2024.2 # via pandas -pyviz-comms==3.0.1 +pyviz-comms==3.0.3 # via panel -pyyaml==6.0.1 +pyyaml==6.0.2 # via # bokeh # chromadb # huggingface-hub # kubernetes - # lancedb # uvicorn questionary==2.0.1 # via Ragna (pyproject.toml) -ratelimiter==1.2.0.post0 - # via lancedb -regex==2023.12.25 +regex==2024.11.6 # via tiktoken -requests==2.31.0 +requests==2.32.3 # via - # chromadb # huggingface-hub # kubernetes - # lancedb # panel # posthog # requests-oauthlib # tiktoken -requests-oauthlib==1.3.1 +requests-oauthlib==2.0.0 # via kubernetes -retry==0.9.2 - # via lancedb -rich==13.7.0 - # via Ragna (pyproject.toml) +rich==13.9.4 + # via + # Ragna (pyproject.toml) + # chromadb + # typer rsa==4.9 # via google-auth -semver==3.0.2 - # via lancedb -six==1.16.0 +shellingham==1.5.4 + # via typer +six==1.17.0 # via - # bleach # kubernetes # posthog # python-dateutil -sniffio==1.3.0 - # via - # anyio - # httpx -sqlalchemy==2.0.25 +sniffio==1.3.1 + # via anyio +sqlalchemy==2.0.36 # via Ragna (pyproject.toml) -starlette==0.35.1 +starlette==0.41.3 # via # Ragna (pyproject.toml) + # bokeh-fastapi # fastapi -sympy==1.12 +sympy==1.13.3 # via onnxruntime -tenacity==8.2.3 +tenacity==9.0.0 # via chromadb -tiktoken==0.5.2 +tiktoken==0.8.0 # via Ragna (pyproject.toml) -tokenizers==0.15.0 +tokenizers==0.20.3 # via chromadb -tomlkit==0.12.3 +tomlkit==0.13.2 # via Ragna (pyproject.toml) -tornado==6.4 +tornado==6.4.2 # via bokeh -tqdm==4.66.1 +tqdm==4.67.1 # via # chromadb # huggingface-hub # lancedb # panel -typer==0.9.0 +typer==0.15.1 # via # Ragna (pyproject.toml) # chromadb 
-typing-extensions==4.9.0 +typing-extensions==4.12.2 # via + # anyio # chromadb # fastapi # huggingface-hub @@ -374,44 +368,40 @@ typing-extensions==4.9.0 # pydantic # pydantic-core # python-docx + # python-pptx # sqlalchemy # typer -tzdata==2023.4 +tzdata==2024.2 # via pandas -uc-micro-py==1.0.2 +uc-micro-py==1.0.3 # via linkify-it-py -urllib3==2.1.0 +urllib3==2.2.3 # via # kubernetes # requests -uvicorn==0.26.0 +uvicorn==0.32.1 # via # Ragna (pyproject.toml) # chromadb -uvloop==0.19.0 +uvloop==0.21.0 # via uvicorn -watchfiles==0.21.0 +watchfiles==1.0.3 # via uvicorn wcwidth==0.2.13 # via prompt-toolkit webencodings==0.5.1 # via bleach -websocket-client==1.7.0 +websocket-client==1.8.0 # via kubernetes -websockets==12.0 +websockets==14.1 # via uvicorn -wrapt==1.16.0 +wrapt==1.17.0 # via # deprecated # opentelemetry-instrumentation -xlsxwriter==3.1.9 +xlsxwriter==3.2.0 # via python-pptx -xyzservices==2023.10.1 - # via - # bokeh - # panel -zipp==3.17.0 +xyzservices==2024.9.0 + # via bokeh +zipp==3.21.0 # via importlib-metadata - -# The following packages are considered to be unsafe in a requirements file: -# setuptools From f0f09dc9b5751567a190f9a312abe1e798c0030e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 11 Dec 2024 12:36:26 +0100 Subject: [PATCH 08/29] install bokeh_fastapi through panel (#513) --- pyproject.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 67ffe501..ad841dcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,15 +22,12 @@ classifiers = [ requires-python = ">=3.10" dependencies = [ "aiofiles", - # Remove this and instead depend on panel[fastapi] - # after https://github.com/holoviz/panel/pull/7495 is released - "bokeh_fastapi==0.1.1", "emoji", "eval_type_backport; python_version<'3.10'", "fastapi", "httpx", "packaging", - "panel==1.5.4", + "panel[fastapi]==1.5.4", "pydantic>=2", "pydantic-core", "pydantic-settings>=2", From ac13a3cb316fd75090167c528562a1be6e25970d Mon 
Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 11 Dec 2024 15:54:54 +0100 Subject: [PATCH 09/29] add session based auth workflow (#464) --- docs/examples/gallery_streaming.py | 22 +- docs/references/config.md | 38 +- docs/references/{rest-api.md => deploy.md} | 2 +- docs/references/release-notes.md | 6 +- docs/tutorials/gallery_custom_components.py | 36 +- docs/tutorials/gallery_rest_api.py | 103 +++-- environment-dev.yml | 3 +- mkdocs.yml | 2 +- pyproject.toml | 7 +- ragna/_docs.py | 181 ++++----- ragna/_utils.py | 108 ++++- ragna/core/__init__.py | 1 - ragna/core/_rag.py | 54 +-- ragna/core/_utils.py | 16 +- ragna/deploy/__init__.py | 13 +- ragna/deploy/_api.py | 31 +- ragna/deploy/_auth.py | 416 ++++++++++++++++++++ ragna/deploy/_authentication.py | 132 ------- ragna/deploy/_config.py | 19 +- ragna/deploy/_core.py | 23 ++ ragna/deploy/_database.py | 122 +++++- ragna/deploy/_engine.py | 48 ++- ragna/deploy/_key_value_store.py | 118 ++++++ ragna/deploy/_orm.py | 22 +- ragna/deploy/_schemas.py | 52 ++- ragna/deploy/_templates/__init__.py | 16 + ragna/deploy/_templates/base.html | 49 +++ ragna/deploy/_templates/basic_auth.css | 6 + ragna/deploy/_templates/basic_auth.html | 33 ++ ragna/deploy/_templates/oauth.html | 7 + ragna/deploy/_ui/api_wrapper.py | 4 +- ragna/deploy/_ui/left_sidebar.py | 2 + ragna/source_storages/_vector_database.py | 2 +- scripts/add_chats.py | 27 +- scripts/docs/gen_files.py | 16 +- tests/assistants/test_api.py | 28 +- tests/conftest.py | 2 +- tests/deploy/api/test_components.py | 19 +- tests/deploy/api/test_e2e.py | 9 +- tests/deploy/api/utils.py | 35 ++ tests/deploy/utils.py | 32 +- tests/utils.py | 13 - 42 files changed, 1317 insertions(+), 558 deletions(-) rename docs/references/{rest-api.md => deploy.md} (91%) create mode 100644 ragna/deploy/_auth.py delete mode 100644 ragna/deploy/_authentication.py create mode 100644 ragna/deploy/_key_value_store.py create mode 100644 ragna/deploy/_templates/__init__.py create mode 100644 
ragna/deploy/_templates/base.html create mode 100644 ragna/deploy/_templates/basic_auth.css create mode 100644 ragna/deploy/_templates/basic_auth.html create mode 100644 ragna/deploy/_templates/oauth.html create mode 100644 tests/deploy/api/utils.py diff --git a/docs/examples/gallery_streaming.py b/docs/examples/gallery_streaming.py index 09a39abd..846897e6 100644 --- a/docs/examples/gallery_streaming.py +++ b/docs/examples/gallery_streaming.py @@ -107,29 +107,30 @@ def answer(self, messages): config = Config(assistants=[DemoStreamingAssistant]) -rest_api = ragna_docs.RestApi() +ragna_deploy = ragna_docs.RagnaDeploy(config) -client, document = rest_api.start(config, authenticate=True, upload_document=True) +client, document = ragna_deploy.get_http_client( + authenticate=True, upload_document=True +) # %% # Start and prepare the chat chat = ( client.post( - "/chats", + "/api/chats", json={ "name": "Tutorial REST API", - "documents": [document], + "document_ids": [document["id"]], "source_storage": source_storages.RagnaDemoSourceStorage.display_name(), "assistant": DemoStreamingAssistant.display_name(), - "params": {}, }, ) .raise_for_status() .json() ) -client.post(f"/chats/{chat['id']}/prepare").raise_for_status() +client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() # %% # Streaming the response is performed with [JSONL](https://jsonlines.org/). Each line @@ -140,7 +141,7 @@ def answer(self, messages): with client.stream( "POST", - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna?", "stream": True}, ) as response: chunks = [json.loads(data) for data in response.iter_lines()] @@ -163,7 +164,8 @@ def answer(self, messages): print("".join(chunk["content"] for chunk in chunks)) # %% -# Before we close the example, let's stop the REST API and have a look at what would -# have printed in the terminal if we had started it with the `ragna api` command. 
+# Before we close the example, let's terminate the REST API and have a look at what +# would have printed in the terminal if we had started it with the `ragna deploy` +# command. -rest_api.stop() +ragna_deploy.terminate() diff --git a/docs/references/config.md b/docs/references/config.md index c6ed18b9..ba83263a 100644 --- a/docs/references/config.md +++ b/docs/references/config.md @@ -69,9 +69,9 @@ is equivalent to `RAGNA_API_ORIGINS='["http://localhost:31477"]'`. Local root directory Ragna uses for storing files. See [ragna.local_root][]. -### `authentication` +### `auth` -[ragna.deploy.Authentication][] class to use for authenticating users. +[ragna.deploy.Auth][] class to use for authenticating users. ### `document` @@ -85,48 +85,26 @@ Local root directory Ragna uses for storing files. See [ragna.local_root][]. [ragna.core.Assistant][]s to be available for the user to use. -### `api` - -#### `hostname` +### `hostname` Hostname the REST API will be bound to. -#### `port` +### `port` Port the REST API will be bound to. -#### `root_path` +### `root_path` A path prefix handled by a proxy that is not seen by the REST API, but is seen by external clients. -#### `url` - -URL of the REST API to be accessed by the web UI. Make sure to include the -[`root_path`](#root_path) if set. - -#### `origins` +### `origins` [CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) origins that are allowed -to connect to the REST API. The URL of the web UI is required for it to function. +to connect to the REST API. -#### `database_url` +### `database_url` URL of a SQL database that will be used to store the Ragna state. See [SQLAlchemy documentation](https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls) on how to format the URL. - -### `ui` - -#### `hostname` - -Hostname the web UI will be bound to. - -#### `port` - -Port the web UI will be bound to. 
- -#### `origins` - -[CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) origins that are allowed -to connect to the web UI. diff --git a/docs/references/rest-api.md b/docs/references/deploy.md similarity index 91% rename from docs/references/rest-api.md rename to docs/references/deploy.md index bfd77544..6f39d899 100644 --- a/docs/references/rest-api.md +++ b/docs/references/deploy.md @@ -1,7 +1,7 @@ # REST API reference diff --git a/docs/references/release-notes.md b/docs/references/release-notes.md index 17457078..40ab7ff6 100644 --- a/docs/references/release-notes.md +++ b/docs/references/release-notes.md @@ -137,9 +137,9 @@ -- The classes [ragna.deploy.Authentication][], [ragna.deploy.RagnaDemoAuthentication][], - and [ragna.deploy.Config][] moved from the [ragna.core][] module to a new - [ragna.deploy][] module. +- The classes `ragna.deploy.Authentication`, `ragna.deploy.RagnaDemoAuthentication`, and + [ragna.deploy.Config][] moved from the [ragna.core][] module to a new [ragna.deploy][] + module. - [ragna.core.Component][], which is the superclass for [ragna.core.Assistant][] and [ragna.core.SourceStorage][], no longer takes a [ragna.deploy.Config][] to instantiate. For example diff --git a/docs/tutorials/gallery_custom_components.py b/docs/tutorials/gallery_custom_components.py index 8a411f81..6c556043 100644 --- a/docs/tutorials/gallery_custom_components.py +++ b/docs/tutorials/gallery_custom_components.py @@ -186,9 +186,11 @@ def answer(self, messages: list[Message]) -> Iterator[str]: assistants=[TutorialAssistant], ) -rest_api = ragna_docs.RestApi() +ragna_deploy = ragna_docs.RagnaDeploy(config) -client, document = rest_api.start(config, authenticate=True, upload_document=True) +client, document = ragna_deploy.get_http_client( + authenticate=True, upload_document=True +) # %% # To select our custom components, we pass their display names to the chat creation. 
@@ -201,10 +203,10 @@ def answer(self, messages: list[Message]) -> Iterator[str]: import json response = client.post( - "/chats", + "/api/chats", json={ "name": "Tutorial REST API", - "documents": [document], + "document_ids": [document["id"]], "source_storage": TutorialSourceStorage.display_name(), "assistant": TutorialAssistant.display_name(), "params": {}, @@ -212,10 +214,10 @@ def answer(self, messages: list[Message]) -> Iterator[str]: ).raise_for_status() chat = response.json() -client.post(f"/chats/{chat['id']}/prepare").raise_for_status() +client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() response = client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna?"}, ).raise_for_status() answer = response.json() @@ -225,7 +227,7 @@ def answer(self, messages: list[Message]) -> Iterator[str]: # Let's stop the REST API and have a look at what would have printed in the terminal if # we had started it with the `ragna api` command. -rest_api.stop() +ragna_deploy.terminate() # %% # ### Web UI @@ -263,9 +265,7 @@ def answer( my_optional_parameter: str = "foo", ) -> Iterator[str]: print(f"Running {type(self).__name__}().answer()") - yield ( - f"I was given {my_required_parameter=} and {my_optional_parameter=}." - ) + yield f"I was given {my_required_parameter=} and {my_optional_parameter=}." # %% @@ -319,19 +319,21 @@ def answer( assistants=[ElaborateTutorialAssistant], ) -rest_api = ragna_docs.RestApi() +ragna_deploy = ragna_docs.RagnaDeploy(config) -client, document = rest_api.start(config, authenticate=True, upload_document=True) +client, document = ragna_deploy.get_http_client( + authenticate=True, upload_document=True +) # %% # To pass custom parameters, define them in the `params` mapping when creating a new # chat. 
response = client.post( - "/chats", + "/api/chats", json={ "name": "Tutorial REST API", - "documents": [document], + "document_ids": [document["id"]], "source_storage": TutorialSourceStorage.display_name(), "assistant": ElaborateTutorialAssistant.display_name(), "params": { @@ -344,10 +346,10 @@ def answer( # %% -client.post(f"/chats/{chat['id']}/prepare").raise_for_status() +client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() response = client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna?"}, ).raise_for_status() answer = response.json() @@ -357,7 +359,7 @@ def answer( # Let's stop the REST API and have a look at what would have printed in the terminal if # we had started it with the `ragna api` command. -rest_api.stop() +ragna_deploy.terminate() # %% # ### Web UI diff --git a/docs/tutorials/gallery_rest_api.py b/docs/tutorials/gallery_rest_api.py index befcbfb3..ede8833d 100644 --- a/docs/tutorials/gallery_rest_api.py +++ b/docs/tutorials/gallery_rest_api.py @@ -3,9 +3,9 @@ Ragna was designed to help you quickly build custom RAG powered web applications. For this you can leverage the built-in -[REST API](../../references/rest-api.md). +[REST API](../../references/deploy.md). -This tutorial walks you through basic steps of using Ragnas REST API. +This tutorial walks you through basic steps of using Ragna's REST API. """ # %% @@ -39,42 +39,39 @@ config = Config() -rest_api = ragna_docs.RestApi() -_ = rest_api.start(config) +ragna_deploy = ragna_docs.RagnaDeploy(config=config) # %% # Let's make sure the REST API is started correctly and can be reached. import httpx -client = httpx.Client(base_url=config.api.url) -client.get("/").raise_for_status() +client = httpx.Client(base_url=f"http://{config.hostname}:{config.port}") +client.get("/health").raise_for_status() # %% # ## Step 2: Authentication # -# In order to use Ragnas REST API, we need to authenticate first. 
To forge an API token -# we send a request to the `/token` endpoint. This is processed by the -# [`Authentication`][ragna.deploy.Authentication], which can be overridden through the -# config. For this tutorial, we use the default -# [ragna.deploy.RagnaDemoAuthentication][], which requires a matching username and -# password. +# In order to use Ragna's REST API, we need to authenticate first. This is handled by +# the [ragna.deploy.Auth][] class, which can be overridden through the config. By +# default, [ragna.deploy.NoAuth][] is used. By hitting the `/login` endpoint, we get a +# session cookie, which is later used to authorize our requests. -username = password = "Ragna" - -response = client.post( - "/token", - data={"username": username, "password": password}, -).raise_for_status() -token = response.json() +client.get("/login", follow_redirects=True) +dict(client.cookies) # %% -# We set the API token on our HTTP client so we don't have to manually supply it for -# each request below. - -client.headers["Authorization"] = f"Bearer {token}" - +# !!! note +# +# In a regular deployment, you'll have login through your browser and create an API +# key in your profile page. The API key is used as +# [bearer token](https://swagger.io/docs/specification/authentication/bearer-authentication/) +# and can be set with +# +# ```python +# httpx.Client(..., headers={"Authorization": f"Bearer {RAGNA_API_KEY}"}) +# ``` # %% # ## Step 3: Uploading documents @@ -84,7 +81,7 @@ import json -response = client.get("/components").raise_for_status() +response = client.get("/api/components").raise_for_status() print(json.dumps(response.json(), indent=2)) # %% @@ -102,38 +99,28 @@ # %% # The upload process in Ragna consists of two parts: # -# 1. Announce the file to be uploaded. Under the hood this pre-registers the document -# in Ragnas database and returns information about how the upload is to be performed. -# This is handled by the [ragna.core.Document][] class. 
By default, -# [ragna.core.LocalDocument][] is used, which uploads the files to the local file -# system. -# 2. Perform the actual upload with the information from step 1. +# 1. Announce the file to be uploaded. Under the hood this registers the document +# in Ragna's database and returns the document ID, which is needed for the upload. response = client.post( - "/document", json={"name": document_path.name} + "/api/documents", json=[{"name": document_path.name}] ).raise_for_status() -document_upload = response.json() -print(json.dumps(response.json(), indent=2)) +documents = response.json() +print(json.dumps(documents, indent=2)) # %% -# The returned JSON contains two parts: the document object that we are later going to -# use to create a chat as well as the upload parameters. -# !!! note +# 2. Perform the actual upload with the information from step 1. through a +# [multipart request](https://swagger.io/docs/specification/describing-request-body/multipart-requests/) +# with the following parameters: # -# The `"token"` in the response is *not* the Ragna REST API token, but rather a -# separate one to perform the document upload. -# -# We perform the actual upload with the latter now. - -document = document_upload["document"] +# - The field is `documents` for all entries +# - The field name is the ID of the document returned by step 1. +# - The field value is the binary content of the document. -parameters = document_upload["parameters"] -client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": open(document_path, "rb")}, -).raise_for_status() +client.put( + "/api/documents", + files=[("documents", (documents[0]["id"], open(document_path, "rb")))], +) # %% # ## Step 4: Select a source storage and assistant @@ -155,13 +142,12 @@ # be used, we can create a new chat. 
response = client.post( - "/chats", + "/api/chats", json={ "name": "Tutorial REST API", - "documents": [document], + "document_ids": [document["id"] for document in documents], "source_storage": source_storage, "assistant": assistant, - "params": {}, }, ).raise_for_status() chat = response.json() @@ -171,13 +157,13 @@ # As can be seen by the `"prepared"` field in the `chat` JSON object we still need to # prepare it. -client.post(f"/chats/{chat['id']}/prepare").raise_for_status() +client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() # %% # Finally, we can get answers to our questions. response = client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna?"}, ).raise_for_status() answer = response.json() @@ -188,7 +174,8 @@ print(answer["content"]) # %% -# Before we close the tutorial, let's stop the REST API and have a look at what would -# have printed in the terminal if we had started it with the `ragna api` command. +# Before we close the tutorial, let's terminate the REST API and have a look at what +# would have printed in the terminal if we had started it with the `ragna deploy` +# command. 
-rest_api.stop() +ragna_deploy.terminate() diff --git a/environment-dev.yml b/environment-dev.yml index ae7ffe0c..1684b030 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -1,4 +1,4 @@ -name: ragna-deploy-dev +name: ragna-dev channels: - conda-forge dependencies: @@ -15,6 +15,7 @@ dependencies: - pytest-asyncio - pytest-playwright - mypy ==1.10.0 + - types-redis - pre-commit - types-aiofiles - sqlalchemy-stubs diff --git a/mkdocs.yml b/mkdocs.yml index a31d6ffe..787bae60 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -91,7 +91,7 @@ nav: - community/contribute.md - References: - references/python-api.md - - references/rest-api.md + - references/deploy.md - references/cli.md - references/config.md - references/faq.md diff --git a/pyproject.toml b/pyproject.toml index ad841dcf..83594126 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -172,9 +172,12 @@ disable_error_code = [ ] [[tool.mypy.overrides]] -# It is a fundamental feature of the components to request more parameters than the base -# class. Thus, we just silence mypy here. +# 1. We automatically handle user-defined sync and async methods +# 2. It is a fundamental feature of the RAG components to request more parameters than +# the base class. +# Thus, we just silence mypy where it would complain about the points above. module = [ + "ragna.deploy._auth", "ragna.source_storages.*", "ragna.assistants.*" ] diff --git a/ragna/_docs.py b/ragna/_docs.py index 0d6191d4..d06fd215 100644 --- a/ragna/_docs.py +++ b/ragna/_docs.py @@ -11,11 +11,12 @@ import httpx -from ragna._utils import timeout_after from ragna.core import RagnaException from ragna.deploy import Config -__all__ = ["SAMPLE_CONTENT", "RestApi"] +from ._utils import BackgroundSubprocess + +__all__ = ["SAMPLE_CONTENT", "RagnaDeploy"] SAMPLE_CONTENT = """\ Ragna is an open source project built by Quansight. 
It is designed to allow @@ -29,51 +30,25 @@ """ -class RestApi: - def __init__(self) -> None: - self._process: Optional[subprocess.Popen] = None - # In case the documentation errors before we call RestApi.stop, we still need to - # stop the server to avoid zombie processes - atexit.register(self.stop, quiet=True) - - def start( - self, - config: Config, - *, - authenticate: bool = False, - upload_document: bool = False, - ) -> tuple[httpx.Client, Optional[dict]]: - if upload_document and not authenticate: - raise RagnaException( - "Cannot upload a document without authenticating first. " - "Set authenticate=True when using upload_document=True." - ) - python_path, config_path = self._prepare_config(config) - - client = httpx.Client(base_url=config.api.url) - - self._process = self._start_api(config_path, python_path, client) +class RagnaDeploy: + def __init__(self, config: Config) -> None: + self.config = config + python_path, config_path = self._prepare_config() + self._process = self._deploy(config, config_path, python_path) + # In case the documentation errors before we call RagnaDeploy.terminate, + # we still need to stop the server to avoid zombie processes + atexit.register(self.terminate, quiet=True) - if authenticate: - self._authenticate(client) - - if upload_document: - document = self._upload_document(client) - else: - document = None - - return client, document - - def _prepare_config(self, config: Config) -> tuple[str, str]: + def _prepare_config(self) -> tuple[str, str]: deploy_directory = Path(tempfile.mkdtemp()) - python_path = ( - f"{deploy_directory}{os.pathsep}{os.environ.get('PYTHONPATH', '')}" + python_path = os.pathsep.join( + [str(deploy_directory), os.environ.get("PYTHONPATH", "")] ) config_path = str(deploy_directory / "ragna.toml") - config.local_root = deploy_directory - config.api.database_url = f"sqlite:///{deploy_directory / 'ragna.db'}" + self.config.local_root = deploy_directory + self.config.database_url = 
f"sqlite:///{deploy_directory / 'ragna.db'}" sys.modules["__main__"].__file__ = inspect.getouterframes( inspect.currentframe() @@ -88,98 +63,92 @@ def _prepare_config(self, config: Config) -> tuple[str, str]: file.write("from ragna import *\n") file.write("from ragna.core import *\n") - for component in itertools.chain(config.source_storages, config.assistants): + for component in itertools.chain( + self.config.source_storages, self.config.assistants + ): if component.__module__ == "__main__": custom_components.add(component) file.write(f"{textwrap.dedent(inspect.getsource(component))}\n\n") component.__module__ = custom_module - config.to_file(config_path) + self.config.to_file(config_path) for component in custom_components: component.__module__ = "__main__" return python_path, config_path - def _start_api( - self, config_path: str, python_path: str, client: httpx.Client - ) -> subprocess.Popen: + def _deploy( + self, config: Config, config_path: str, python_path: str + ) -> BackgroundSubprocess: env = os.environ.copy() env["PYTHONPATH"] = python_path - process = subprocess.Popen( - [sys.executable, "-m", "ragna", "api", "--config", config_path], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - ) - - def check_api_available() -> bool: + def startup_fn() -> bool: try: - return client.get("/").is_success + return httpx.get(f"{config._url}/health").is_success except httpx.ConnectError: return False - failure_message = "Failed to the start the Ragna REST API." 
- - @timeout_after(60, message=failure_message) - def wait_for_api() -> None: - print("Starting Ragna REST API") - while not check_api_available(): - try: - stdout, stderr = process.communicate(timeout=1) - except subprocess.TimeoutExpired: - print(".", end="") - continue - else: - parts = [failure_message] - if stdout: - parts.append(f"\n\nSTDOUT:\n\n{stdout.decode()}") - if stderr: - parts.append(f"\n\nSTDERR:\n\n{stderr.decode()}") - - raise RuntimeError("".join(parts)) - - print() - - wait_for_api() - return process - - def _authenticate(self, client: httpx.Client) -> None: - username = password = "Ragna" + if startup_fn(): + raise RagnaException("ragna server is already running") + + return BackgroundSubprocess( + sys.executable, + "-m", + "ragna", + "deploy", + "--api", + "--no-ui", + "--config", + config_path, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + startup_fn=startup_fn, + startup_timeout=60, + ) - response = client.post( - "/token", - data={"username": username, "password": password}, - ).raise_for_status() - token = response.json() + def get_http_client( + self, + *, + authenticate: bool = False, + upload_document: bool = False, + ) -> tuple[httpx.Client, Optional[dict[str, Any]]]: + if upload_document and not authenticate: + raise RagnaException( + "Cannot upload a document without authenticating first. " + "Set authenticate=True when using upload_document=True." 
+ ) - client.headers["Authorization"] = f"Bearer {token}" + client = httpx.Client(base_url=self.config._url) - def _upload_document(self, client: httpx.Client) -> dict[str, Any]: - name, content = "ragna.txt", SAMPLE_CONTENT + if authenticate: + client.get("/login", follow_redirects=True) - response = client.post("/document", json={"name": name}).raise_for_status() - document_upload = response.json() + if upload_document: + name, content = "ragna.txt", SAMPLE_CONTENT - document = cast(dict[str, Any], document_upload["document"]) + response = client.post( + "/api/documents", json=[{"name": name}] + ).raise_for_status() + document = cast(dict[str, Any], response.json()[0]) - parameters = document_upload["parameters"] - client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": content}, - ).raise_for_status() + client.put( + "/api/documents", + files=[("documents", (document["id"], content.encode()))], + ) + else: + document = None - return document + return client, document - def stop(self, *, quiet: bool = False) -> None: + def terminate(self, quiet: bool = False) -> None: if self._process is None: return - self._process.terminate() - stdout, _ = self._process.communicate() + output = self._process.terminate() - if not quiet: - print(stdout.decode()) + if output and not quiet: + stdout, _ = output + print(stdout) diff --git a/ragna/_utils.py b/ragna/_utils.py index 6ef5eb5c..32bb24c5 100644 --- a/ragna/_utils.py +++ b/ragna/_utils.py @@ -1,10 +1,31 @@ +from __future__ import annotations + +import contextlib import functools +import getpass import inspect import os +import shlex +import subprocess import sys import threading +import time from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Iterator, + Optional, + TypeVar, + Union, + cast, +) + +from starlette.concurrency import iterate_in_threadpool, run_in_threadpool + +T = 
TypeVar("T") _LOCAL_ROOT = ( Path(os.environ.get("RAGNA_LOCAL_ROOT", "~/.cache/ragna")).expanduser().resolve() @@ -110,3 +131,88 @@ def is_debugging() -> bool: if any(part.startswith(name) for part in parts): return True return False + + +def as_awaitable( + fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any, **kwargs: Any +) -> Awaitable[T]: + if inspect.iscoroutinefunction(fn): + fn = cast(Callable[..., Awaitable[T]], fn) + awaitable = fn(*args, **kwargs) + else: + fn = cast(Callable[..., T], fn) + awaitable = run_in_threadpool(fn, *args, **kwargs) + + return awaitable + + +def as_async_iterator( + fn: Union[Callable[..., Iterator[T]], Callable[..., AsyncIterator[T]]], + *args: Any, + **kwargs: Any, +) -> AsyncIterator[T]: + if inspect.isasyncgenfunction(fn): + fn = cast(Callable[..., AsyncIterator[T]], fn) + async_iterator = fn(*args, **kwargs) + else: + fn = cast(Callable[..., Iterator[T]], fn) + async_iterator = iterate_in_threadpool(fn(*args, **kwargs)) + + return async_iterator + + +def default_user() -> str: + with contextlib.suppress(Exception): + return getpass.getuser() + with contextlib.suppress(Exception): + return os.getlogin() + return "Bodil" + + +class BackgroundSubprocess: + def __init__( + self, + *cmd: str, + stdout: Any = sys.stdout, + stderr: Any = sys.stdout, + text: bool = True, + startup_fn: Optional[Callable[[], bool]] = None, + startup_timeout: float = 10, + terminate_timeout: float = 10, + **subprocess_kwargs: Any, + ) -> None: + self._process = subprocess.Popen( + cmd, stdout=stdout, stderr=stderr, **subprocess_kwargs + ) + try: + if startup_fn: + + @timeout_after(startup_timeout, message=shlex.join(cmd)) + def wait() -> None: + while not startup_fn(): + time.sleep(0.2) + + wait() + except Exception: + self.terminate() + raise + + self._terminate_timeout = terminate_timeout + + def terminate(self) -> tuple[str, str]: + @timeout_after(self._terminate_timeout) + def terminate() -> tuple[str, str]: + 
self._process.terminate() + return self._process.communicate() + + try: + return terminate() # type: ignore[no-any-return] + except TimeoutError: + self._process.kill() + return self._process.communicate() + + def __enter__(self) -> BackgroundSubprocess: + return self + + def __exit__(self, *exc_info: Any) -> None: + self.terminate() diff --git a/ragna/core/__init__.py b/ragna/core/__init__.py index 44449775..1cdbc667 100644 --- a/ragna/core/__init__.py +++ b/ragna/core/__init__.py @@ -4,7 +4,6 @@ "Component", "Document", "DocumentHandler", - "DocumentUploadParameters", "DocxDocumentHandler", "PptxDocumentHandler", "EnvVarRequirement", diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index d963c15b..be5282aa 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -2,7 +2,6 @@ import contextlib import datetime -import inspect import itertools import uuid from collections import defaultdict @@ -24,11 +23,12 @@ import pydantic import pydantic_core from fastapi import status -from starlette.concurrency import iterate_in_threadpool, run_in_threadpool + +from ragna._utils import as_async_iterator, as_awaitable, default_user from ._components import Assistant, Component, Message, MessageRole, SourceStorage from ._document import Document, LocalDocument -from ._utils import RagnaException, default_user, merge_models +from ._utils import RagnaException, merge_models if TYPE_CHECKING: from ragna.deploy import Config @@ -145,7 +145,6 @@ def chat( Args: documents: Documents to use. If any item is not a [ragna.core.Document][], [ragna.core.LocalDocument.from_path][] is invoked on it. - FIXME source_storage: Source storage to use. assistant: Assistant to use. **params: Additional parameters passed to the source storage and assistant. 
@@ -153,8 +152,8 @@ def chat( return Chat( self, documents=documents, - source_storage=cast(SourceStorage, self._load_component(source_storage)), # type: ignore[arg-type] - assistant=cast(Assistant, self._load_component(assistant)), # type: ignore[arg-type] + source_storage=cast(SourceStorage, self._load_component(source_storage)), # type: ignore[arg-type] + assistant=cast(Assistant, self._load_component(assistant)), # type: ignore[arg-type] **params, ) @@ -241,11 +240,11 @@ async def prepare(self) -> Message: raise RagnaException( "Chat is already prepared", chat=self, - http_status_code=400, + http_status_code=status.HTTP_400_BAD_REQUEST, detail=RagnaException.EVENT, ) - await self._run(self.source_storage.store, self.documents) + await self._as_awaitable(self.source_storage.store, self.documents) self._prepared = True welcome = Message( @@ -269,17 +268,21 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message: raise RagnaException( "Chat is not prepared", chat=self, - http_status_code=400, + http_status_code=status.HTTP_400_BAD_REQUEST, detail=RagnaException.EVENT, ) - sources = await self._run(self.source_storage.retrieve, self.documents, prompt) + sources = await self._as_awaitable( + self.source_storage.retrieve, self.documents, prompt + ) question = Message(content=prompt, role=MessageRole.USER, sources=sources) self._messages.append(question) answer = Message( - content=self._run_gen(self.assistant.answer, self._messages.copy()), + content=self._as_async_iterator( + self.assistant.answer, self._messages.copy() + ), role=MessageRole.ASSISTANT, sources=sources, ) @@ -361,7 +364,7 @@ def format_error( formatted_error = f"- {param}" if annotation: annotation_ = cast( - type, model_cls.model_fields[param].annotation + type, model_cls.__pydantic_fields__[param].annotation ).__name__ formatted_error += f": {annotation_}" @@ -417,34 +420,17 @@ def format_error( raise RagnaException("\n".join(parts)) - async def _run( + def _as_awaitable( self, 
fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any - ) -> T: - kwargs = self._unpacked_params[fn] - if inspect.iscoroutinefunction(fn): - fn = cast(Callable[..., Awaitable[T]], fn) - coro = fn(*args, **kwargs) - else: - fn = cast(Callable[..., T], fn) - coro = run_in_threadpool(fn, *args, **kwargs) + ) -> Awaitable[T]: + return as_awaitable(fn, *args, **self._unpacked_params[fn]) - return await coro - - async def _run_gen( + def _as_async_iterator( self, fn: Union[Callable[..., Iterator[T]], Callable[..., AsyncIterator[T]]], *args: Any, ) -> AsyncIterator[T]: - kwargs = self._unpacked_params[fn] - if inspect.isasyncgenfunction(fn): - fn = cast(Callable[..., AsyncIterator[T]], fn) - async_gen = fn(*args, **kwargs) - else: - fn = cast(Callable[..., Iterator[T]], fn) - async_gen = iterate_in_threadpool(fn(*args, **kwargs)) - - async for item in async_gen: - yield item + return as_async_iterator(fn, *args, **self._unpacked_params[fn]) async def __aenter__(self) -> Chat: await self.prepare() diff --git a/ragna/core/_utils.py b/ragna/core/_utils.py index 972b0926..34ac2e7d 100644 --- a/ragna/core/_utils.py +++ b/ragna/core/_utils.py @@ -1,15 +1,13 @@ from __future__ import annotations import abc -import contextlib import enum import functools -import getpass import importlib import importlib.metadata import os from collections import defaultdict -from typing import Any, Collection, Optional, Type, Union, cast +from typing import Any, Callable, Collection, Optional, Type, Union, cast import packaging.requirements import pydantic @@ -121,14 +119,6 @@ def __repr__(self) -> str: return self._name -def default_user() -> str: - with contextlib.suppress(Exception): - return getpass.getuser() - with contextlib.suppress(Exception): - return os.getlogin() - return "Ragna" - - def merge_models( model_name: str, *models: Type[pydantic.BaseModel], @@ -136,14 +126,14 @@ def merge_models( ) -> Type[pydantic.BaseModel]: raw_field_definitions = defaultdict(list) for 
model_cls in models: - for name, field in model_cls.model_fields.items(): + for name, field in model_cls.__pydantic_fields__.items(): type_ = field.annotation default: Any if field.is_required(): default = ... elif field.default is pydantic_core.PydanticUndefined: - default = field.default_factory() # type: ignore[misc] + default = cast(Callable[[], Any], field.default_factory)() else: default = field.default diff --git a/ragna/deploy/__init__.py b/ragna/deploy/__init__.py index f3a86255..cdb2ba44 100644 --- a/ragna/deploy/__init__.py +++ b/ragna/deploy/__init__.py @@ -1,11 +1,18 @@ __all__ = [ - "Authentication", + "Auth", "Config", - "RagnaDemoAuthentication", + "DummyBasicAuth", + "GithubOAuth", + "InMemoryKeyValueStore", + "JupyterhubServerProxyAuth", + "KeyValueStore", + "NoAuth", + "RedisKeyValueStore", ] -from ._authentication import Authentication, RagnaDemoAuthentication +from ._auth import Auth, DummyBasicAuth, GithubOAuth, JupyterhubServerProxyAuth, NoAuth from ._config import Config +from ._key_value_store import InMemoryKeyValueStore, KeyValueStore, RedisKeyValueStore # isort: split diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py index 242730f6..788bfc38 100644 --- a/ragna/deploy/_api.py +++ b/ragna/deploy/_api.py @@ -2,34 +2,23 @@ from typing import Annotated, AsyncIterator import pydantic -from fastapi import ( - APIRouter, - Body, - Depends, - UploadFile, -) +from fastapi import APIRouter, Body, UploadFile from fastapi.responses import StreamingResponse -from ragna.core._utils import default_user - from . 
import _schemas as schemas +from ._auth import UserDependency from ._engine import Engine def make_router(engine: Engine) -> APIRouter: router = APIRouter(tags=["API"]) - def get_user() -> str: - return default_user() - - UserDependency = Annotated[str, Depends(get_user)] - @router.post("/documents") def register_documents( user: UserDependency, document_registrations: list[schemas.DocumentRegistration] ) -> list[schemas.Document]: return engine.register_documents( - user=user, document_registrations=document_registrations + user=user.name, document_registrations=document_registrations ) @router.put("/documents") @@ -44,7 +33,7 @@ async def content_stream() -> AsyncIterator[bytes]: return content_stream() await engine.store_documents( - user=user, + user=user.name, ids_and_streams=[ (uuid.UUID(document.filename), make_content_stream(document)) for document in documents @@ -60,19 +49,19 @@ async def create_chat( user: UserDependency, chat_creation: schemas.ChatCreation, ) -> schemas.Chat: - return engine.create_chat(user=user, chat_creation=chat_creation) + return engine.create_chat(user=user.name, chat_creation=chat_creation) @router.get("/chats") async def get_chats(user: UserDependency) -> list[schemas.Chat]: - return engine.get_chats(user=user) + return engine.get_chats(user=user.name) @router.get("/chats/{id}") async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat: - return engine.get_chat(user=user, id=id) + return engine.get_chat(user=user.name, id=id) @router.post("/chats/{id}/prepare") async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message: - return await engine.prepare_chat(user=user, id=id) + return await engine.prepare_chat(user=user.name, id=id) @router.post("/chats/{id}/answer") async def answer( @@ -81,7 +70,7 @@ async def answer( prompt: Annotated[str, Body(..., embed=True)], stream: Annotated[bool, Body(..., embed=True)] = False, ) -> schemas.Message: - message_stream = engine.answer_stream(user=user, 
chat_id=id, prompt=prompt) + message_stream = engine.answer_stream(user=user.name, chat_id=id, prompt=prompt) answer = await anext(message_stream) if not stream: @@ -106,6 +95,6 @@ async def to_jsonl( @router.delete("/chats/{id}") async def delete_chat(user: UserDependency, id: uuid.UUID) -> None: - engine.delete_chat(user=user, id=id) + engine.delete_chat(user=user.name, id=id) return router diff --git a/ragna/deploy/_auth.py b/ragna/deploy/_auth.py new file mode 100644 index 00000000..06492bf4 --- /dev/null +++ b/ragna/deploy/_auth.py @@ -0,0 +1,416 @@ +from __future__ import annotations + +import abc +import base64 +import contextlib +import json +import os +import re +import uuid +from typing import TYPE_CHECKING, Annotated, Awaitable, Callable, Optional, Union, cast + +import httpx +import panel as pn +import pydantic +from fastapi import Depends, FastAPI, Request, status +from fastapi.responses import HTMLResponse, RedirectResponse, Response +from fastapi.security.utils import get_authorization_scheme_param +from starlette.middleware.base import BaseHTTPMiddleware +from tornado.web import create_signed_value + +from ragna._utils import as_awaitable, default_user +from ragna.core import RagnaException + +from . import _schemas as schemas +from . import _templates as templates +from ._utils import redirect + +if TYPE_CHECKING: + from ._config import Config + from ._engine import Engine + from ._key_value_store import KeyValueStore + + +class Session(pydantic.BaseModel): + user: schemas.User + + +CallNext = Callable[[Request], Awaitable[Response]] + + +class SessionMiddleware(BaseHTTPMiddleware): + # panel uses cookies to transfer user information (see _cookie_dispatch() below) and + # signs them for security. However, since this happens after our authentication + # check, we can use an arbitrary, hardcoded value here. 
+ _PANEL_COOKIE_SECRET = "ragna" + + def __init__( + self, app: FastAPI, *, config: Config, engine: Engine, api: bool, ui: bool + ) -> None: + super().__init__(app) + self._config = config + self._engine = engine + self._api = api + self._ui = ui + self._sessions: KeyValueStore[Session] = config.key_value_store() + + if ui: + pn.config.cookie_secret = self._PANEL_COOKIE_SECRET # type: ignore[misc] + + _COOKIE_NAME = "ragna" + + async def dispatch(self, request: Request, call_next: CallNext) -> Response: + if (authorization := request.headers.get("Authorization")) is not None: + return await self._api_token_dispatch( + request, call_next, authorization=authorization + ) + elif (cookie := request.cookies.get(self._COOKIE_NAME)) is not None: + return await self._cookie_dispatch(request, call_next, cookie=cookie) + elif request.url.path in {"/login", "/oauth-callback"}: + return await self._login_dispatch(request, call_next) + elif self._api and request.url.path.startswith("/api"): + return self._unauthorized("Missing authorization header") + elif self._ui and request.url.path.startswith("/ui"): + return redirect("/login") + else: + # Either an unknown route or something on the default router. In any case, + # this doesn't need a session and so we let it pass. 
+ request.state.session = None + return await call_next(request) + + async def _api_token_dispatch( + self, request: Request, call_next: CallNext, authorization: str + ) -> Response: + scheme, api_key = get_authorization_scheme_param(authorization) + if scheme.lower() != "bearer": + return self._unauthorized("Bearer authentication scheme required") + + user, expired = self._engine.get_user_by_api_key(api_key) + if user is None or expired: + self._sessions.delete(api_key) + reason = "Invalid" if user is None else "Expired" + return self._unauthorized(f"{reason} API key") + + session = self._sessions.get(api_key) + if session is None: + # First time the API key is used + session = Session(user=user) + # We are using the API key value instead of its ID as session key for two + # reasons: + # 1. Similar to its ID, the value is unique and thus can be safely used as + # key. + # 2. If an API key was deleted, we lose its ID, but still need to be able to + # remove its corresponding session. + self._sessions.set(api_key, session, expires_after=3600) + + request.state.session = session + return await call_next(request) + + async def _cookie_dispatch( + self, request: Request, call_next: CallNext, *, cookie: str + ) -> Response: + session = self._sessions.get(cookie) + response: Response + if session is None: + # Invalid cookie + response = redirect("/login") + self._delete_cookie(response) + return response + + request.state.session = session + if self._ui and request.method == "GET" and request.url.path == "/ui": + # panel.state.user and panel.state.user_info are based on the two cookies + # below that the panel auth flow sets. Since we don't want extra cookies + # just for panel, we just inject them into the scope here, which will be + # parsed by panel down the line. After this initial request, the values are + # tied to the active session and don't have to be set again. 
+ extra_cookies: dict[str, Union[str, bytes]] = { + "user": session.user.name, + "id_token": base64.b64encode(json.dumps(session.user.data).encode()), + } + extra_values = [ + ( + f"{key}=".encode() + + create_signed_value( + self._PANEL_COOKIE_SECRET, key, value, version=1 + ) + ) + for key, value in extra_cookies.items() + ] + + cookie_key = b"cookie" + idx, value = next( + (idx, value) + for idx, (key, value) in enumerate(request.scope["headers"]) + if key == cookie_key + ) + # We are not setting request.cookies or request.headers here, because any + # changes to them are not reflected back to the scope, which is the only + # safe way to transfer data between the middleware and an endpoint. + request.scope["headers"][idx] = ( + cookie_key, + b";".join([value, *extra_values]), + ) + + response = await call_next(request) + + if request.url.path == "/logout": + self._sessions.delete(cookie) + self._delete_cookie(response) + else: + self._sessions.refresh(cookie, expires_after=self._config.session_lifetime) + self._add_cookie(response, cookie) + + return response + + async def _login_dispatch(self, request: Request, call_next: CallNext) -> Response: + request.state.session = None + response = await call_next(request) + session = request.state.session + + if session is not None: + cookie = str(uuid.uuid4()) + self._sessions.set( + cookie, session, expires_after=self._config.session_lifetime + ) + self._add_cookie(response, cookie=cookie) + + return response + + def _unauthorized(self, message: str) -> Response: + return Response( + content=message, + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Bearer"}, + ) + + def _add_cookie(self, response: Response, cookie: str) -> None: + response.set_cookie( + key=self._COOKIE_NAME, + value=cookie, + max_age=self._config.session_lifetime, + httponly=True, + samesite="lax", + ) + + def _delete_cookie(self, response: Response) -> None: + response.delete_cookie( + key=self._COOKIE_NAME, + 
httponly=True, + samesite="lax", + ) + + +async def _get_session(request: Request) -> Session: + session = cast(Optional[Session], request.state.session) + if session is None: + raise RagnaException( + "Not authenticated", + http_detail=RagnaException.EVENT, + http_status_code=status.HTTP_401_UNAUTHORIZED, + ) + return session + + +SessionDependency = Annotated[Session, Depends(_get_session)] + + +async def _get_user(session: SessionDependency) -> schemas.User: + return session.user + + +UserDependency = Annotated[schemas.User, Depends(_get_user)] + + +class Auth(abc.ABC): + """ + ADDME + """ + + @classmethod + def _add_to_app( + cls, app: FastAPI, *, config: Config, engine: Engine, api: bool, ui: bool + ) -> None: + self = cls() + + @app.get("/login", include_in_schema=False) + async def login_page(request: Request) -> Response: + return await as_awaitable(self.login_page, request) + + async def _login(request: Request) -> Response: + result = await as_awaitable(self.login, request) + if not isinstance(result, schemas.User): + return result + + engine.maybe_add_user(result) + request.state.session = Session(user=result) + return redirect("/") + + @app.post("/login", include_in_schema=False) + async def login(request: Request) -> Response: + return await _login(request) + + @app.get("/oauth-callback", include_in_schema=False) + async def oauth_callback(request: Request) -> Response: + return await _login(request) + + @app.get("/logout", include_in_schema=False) + async def logout() -> RedirectResponse: + return redirect("/") + + app.add_middleware( + SessionMiddleware, + config=config, + engine=engine, + api=api, + ui=ui, + ) + + @abc.abstractmethod + def login_page(self, request: Request) -> Response: ... + + @abc.abstractmethod + def login(self, request: Request) -> Union[schemas.User, Response]: ... 
+ + +class _AutomaticLoginAuthBase(Auth): + def login_page(self, request: Request) -> Response: + # To invoke the Auth.login() method, the client either needs to + # - POST /login or + # - GET /oauth-callback + # Since we cannot instruct a browser to post when sending a redirect response, we + # use the OAuth callback endpoint here, although this might have nothing to do + # with OAuth. + return redirect("/oauth-callback") + + +class NoAuth(_AutomaticLoginAuthBase): + """ + ADDME + """ + + def login(self, request: Request) -> schemas.User: + return schemas.User( + name=request.headers.get("X-Forwarded-User", default_user()) + ) + + +class DummyBasicAuth(Auth): + """Dummy OAuth2 password authentication without requirements. + + !!! danger + + As the name implies, this authentication is just for testing or demo purposes and + should not be used in production. + """ + + def __init__(self) -> None: + self._password = os.environ.get("RAGNA_DUMMY_BASIC_AUTH_PASSWORD") + + def login_page( + self, + request: Request, + *, + username: Optional[str] = None, + fail_reason: Optional[str] = None, + ) -> HTMLResponse: + return HTMLResponse( + templates.render( + "basic_auth.html", username=username, fail_reason=fail_reason + ) + ) + + async def login(self, request: Request) -> Union[schemas.User, Response]: + async with request.form() as form: + username = cast(str, form.get("username")) + password = cast(str, form.get("password")) + + if username is None or password is None: + # This can only happen if the endpoint is not hit through the login page. + # Thus, instead of returning the failed login page like below, we just + # return an error. 
+ raise RagnaException( + "Field 'username' or 'password' is missing from the form data.", + http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + http_detail=RagnaException.MESSAGE, + ) + + if not username: + return self.login_page(request, fail_reason="Username cannot be empty") + elif (self._password is not None and password != self._password) or ( + self._password is None and password != username + ): + return self.login_page( + request, username=username, fail_reason="Password incorrect" + ) + + return schemas.User(name=username) + + +class GithubOAuth(Auth): + def __init__(self) -> None: + # FIXME: requirements + self._client_id = os.environ["RAGNA_GITHUB_OAUTH_CLIENT_ID"] + self._client_secret = os.environ["RAGNA_GITHUB_OAUTH_CLIENT_SECRET"] + + def login_page(self, request: Request) -> HTMLResponse: + return HTMLResponse( + templates.render( + "oauth.html", + service="GitHub", + url=f"https://github.com/login/oauth/authorize?client_id={self._client_id}", + ) + ) + + async def login(self, request: Request) -> Union[schemas.User, Response]: + async with httpx.AsyncClient(headers={"Accept": "application/json"}) as client: + response = await client.post( + "https://github.com/login/oauth/access_token", + json={ + "client_id": self._client_id, + "client_secret": self._client_secret, + "code": request.query_params["code"], + }, + ) + access_token = response.json()["access_token"] + client.headers["Authorization"] = f"Bearer {access_token}" + + user_data = (await client.get("https://api.github.com/user")).json() + + organizations_data = ( + await client.get(user_data["organizations_url"]) + ).json() + organizations = { + organization_data["login"] for organization_data in organizations_data + } + if not (organizations & {"Quansight", "Quansight-Labs"}): + # FIXME: send the login page again with a failure message + return HTMLResponse("Unauthorized!") + + return schemas.User(name=user_data["login"]) + + +class JupyterhubServerProxyAuth(_AutomaticLoginAuthBase): 
+ _JUPYTERHUB_ENV_VAR_PATTERN = re.compile(r"JUPYTERHUB_(?P<key>.+)") + + def __init__(self) -> None: + data = {} + for env_var, value in os.environ.items(): + match = self._JUPYTERHUB_ENV_VAR_PATTERN.match(env_var) + if match is None: + continue + + key = match["key"].lower() + with contextlib.suppress(json.JSONDecodeError): + value = json.loads(value) + + data[key] = value + + name = data.pop("user") + if name is None: + raise RagnaException + + self._user = schemas.User(name=name, data=data) + + def login(self, request: Request) -> schemas.User: + return self._user diff --git a/ragna/deploy/_authentication.py b/ragna/deploy/_authentication.py deleted file mode 100644 index b8a4cbb8..00000000 --- a/ragna/deploy/_authentication.py +++ /dev/null @@ -1,132 +0,0 @@ -import abc -import os -import secrets -import time -from typing import cast - -import jwt -import rich -from fastapi import HTTPException, Request, status -from fastapi.security.utils import get_authorization_scheme_param - - -class Authentication(abc.ABC): - """Abstract base class for authentication used by the REST API.""" - - @abc.abstractmethod - async def create_token(self, request: Request) -> str: - """Authenticate user and create an authorization token. - - Args: - request: Request send to the `/token` endpoint of the REST API. - - Returns: - Authorization token. - """ - pass - - @abc.abstractmethod - async def get_user(self, request: Request) -> str: - """ - Args: - request: Request send to any endpoint of the REST API that requires - authorization. - - Returns: - Authorized user. - """ - pass - - -class RagnaDemoAuthentication(Authentication): - """Demo OAuth2 password authentication without requirements. - - !!! danger - - As the name implies, this authentication is just for demo purposes and should - not be used in production. 
- """ - - def __init__(self) -> None: - msg = f"INFO:\t{type(self).__name__}: You can log in with any username" - self._password = os.environ.get("RAGNA_DEMO_AUTHENTICATION_PASSWORD") - if self._password is None: - msg = f"{msg} and a matching password." - else: - msg = f"{msg} and the password {self._password}" - rich.print(msg) - - _JWT_SECRET = os.environ.get( - "RAGNA_DEMO_AUTHENTICATION_SECRET", secrets.token_urlsafe(32)[:32] - ) - _JWT_ALGORITHM = "HS256" - _JWT_TTL = int(os.environ.get("RAGNA_DEMO_AUTHENTICATION_TTL", 60 * 60 * 24 * 7)) - - async def create_token(self, request: Request) -> str: - """Authenticate user and create an authorization token. - - User name is arbitrary. Authentication is possible in two ways: - - 1. If the `RAGNA_DEMO_AUTHENTICATION_PASSWORD` environment variable is set, the - password is checked against that. - 2. Otherwise, the password has to match the user name. - - Args: - request: Request send to the `/token` endpoint of the REST API. Must include - the `"username"` and `"password"` as form data. - - Returns: - Authorization [JWT](https://jwt.io/) that expires after one week. - """ - async with request.form() as form: - username = form.get("username") - password = form.get("password") - - if username is None or password is None: - raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY) - - if (self._password is not None and password != self._password) or ( - self._password is None and password != username - ): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - - return jwt.encode( - payload={"user": username, "exp": time.time() + self._JWT_TTL}, - key=self._JWT_SECRET, - algorithm=self._JWT_ALGORITHM, - ) - - async def get_user(self, request: Request) -> str: - """Get user from an authorization token. - - Token has to be supplied in the - [Bearer authentication scheme](https://swagger.io/docs/specification/authentication/bearer-authentication/), - i.e. including a `Authorization: Bearer {token}` header. 
- - Args: - request: Request send to any endpoint of the REST API that requires - authorization. - - Returns: - Authorized user. - """ - - unauthorized = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not authenticated", - headers={"WWW-Authenticate": "Bearer"}, - ) - - authorization = request.headers.get("Authorization") - scheme, token = get_authorization_scheme_param(authorization) - if not authorization or scheme.lower() != "bearer": - raise unauthorized - - try: - payload = jwt.decode( - token, key=self._JWT_SECRET, algorithms=[self._JWT_ALGORITHM] - ) - except (jwt.InvalidSignatureError, jwt.ExpiredSignatureError): - raise unauthorized - - return cast(str, payload["user"]) diff --git a/ragna/deploy/_config.py b/ragna/deploy/_config.py index e960f831..3c6f46c3 100644 --- a/ragna/deploy/_config.py +++ b/ragna/deploy/_config.py @@ -2,7 +2,15 @@ import itertools from pathlib import Path -from typing import Annotated, Any, Callable, Generic, Type, TypeVar, Union +from typing import ( + Annotated, + Any, + Callable, + Generic, + Type, + TypeVar, + Union, +) import tomlkit import tomlkit.container @@ -18,7 +26,8 @@ from ragna._utils import make_directory from ragna.core import Assistant, Document, RagnaException, SourceStorage -from ._authentication import Authentication +from ._auth import Auth +from ._key_value_store import KeyValueStore T = TypeVar("T") @@ -79,8 +88,9 @@ def settings_customise_sources( default_factory=ragna.local_root ) - authentication: ImportString[type[Authentication]] = ( - "ragna.deploy.RagnaDemoAuthentication" # type: ignore[assignment] + auth: ImportString[type[Auth]] = "ragna.deploy.NoAuth" # type: ignore[assignment] + key_value_store: ImportString[type[KeyValueStore]] = ( + "ragna.deploy.InMemoryKeyValueStore" # type: ignore[assignment] ) document: ImportString[type[Document]] = "ragna.core.LocalDocument" # type: ignore[assignment] @@ -97,6 +107,7 @@ def settings_customise_sources( origins: list[str] = 
AfterConfigValidateDefault.make( lambda config: [f"http://{config.hostname}:{config.port}"] ) + session_lifetime: int = 60 * 60 * 24 database_url: str = AfterConfigValidateDefault.make( lambda config: f"sqlite:///{config.local_root}/ragna.db", diff --git a/ragna/deploy/_core.py b/ragna/deploy/_core.py index 65cca3cd..64de4a27 100644 --- a/ragna/deploy/_core.py +++ b/ragna/deploy/_core.py @@ -1,6 +1,7 @@ import contextlib import threading import time +import uuid import webbrowser from pathlib import Path from typing import AsyncContextManager, AsyncIterator, Callable, Optional, cast @@ -15,7 +16,9 @@ import ragna from ragna.core import RagnaException +from . import _schemas as schemas from ._api import make_router as make_api_router +from ._auth import UserDependency from ._config import Config from ._engine import Engine from ._ui import app as make_ui_app @@ -78,6 +81,8 @@ def server_available() -> bool: ignore_unavailable_components=ignore_unavailable_components, ) + config.auth._add_to_app(app, config=config, engine=engine, api=api, ui=ui) + if api: app.include_router(make_api_router(engine), prefix="/api") @@ -103,6 +108,24 @@ async def health() -> Response: async def version() -> str: return ragna.__version__ + @app.get("/user") + async def user(user: UserDependency) -> schemas.User: + return user + + @app.get("/api-keys") + def list_api_keys(user: UserDependency) -> list[schemas.ApiKey]: + return engine.list_api_keys(user=user.name) + + @app.post("/api-keys") + def create_api_key( + user: UserDependency, api_key_creation: schemas.ApiKeyCreation + ) -> schemas.ApiKey: + return engine.create_api_key(user=user.name, api_key_creation=api_key_creation) + + @app.delete("/api-keys/{id}") + def delete_api_key(user: UserDependency, id: uuid.UUID) -> None: + return engine.delete_api_key(user=user.name, id=id) + @app.exception_handler(RagnaException) async def ragna_exception_handler( request: Request, exc: RagnaException diff --git a/ragna/deploy/_database.py 
b/ragna/deploy/_database.py index 529fa3b6..54f7b9f2 100644 --- a/ragna/deploy/_database.py +++ b/ragna/deploy/_database.py @@ -1,7 +1,7 @@ from __future__ import annotations import uuid -from typing import Any, Collection, Optional +from typing import Any, Collection, Optional, cast from urllib.parse import urlsplit from sqlalchemy import create_engine, select @@ -13,6 +13,14 @@ from . import _schemas as schemas +class UnknownUser(Exception): + def __init__( + self, name: Optional[str] = None, api_key: Optional[str] = None + ) -> None: + self.name = name + self.api_key = api_key + + class Database: def __init__(self, url: str) -> None: components = urlsplit(url) @@ -28,20 +36,88 @@ def __init__(self, url: str) -> None: self._to_orm = SchemaToOrmConverter() self._to_schema = OrmToSchemaConverter() - def _get_user(self, session: Session, *, username: str) -> orm.User: - user: Optional[orm.User] = session.execute( - select(orm.User).where(orm.User.name == username) - ).scalar_one_or_none() + def _get_orm_user_by_name(self, session: Session, *, name: str) -> orm.User: + user = cast( + Optional[orm.User], + session.execute( + select(orm.User).where(orm.User.name == name) + ).scalar_one_or_none(), + ) if user is None: - # Add a new user if the current username is not registered yet. Since this - # is behind the authentication layer, we don't need any extra security here. 
- user = orm.User(id=uuid.uuid4(), name=username) - session.add(user) - session.commit() + raise UnknownUser(name) return user + def maybe_add_user(self, session: Session, *, user: schemas.User) -> None: + try: + self._get_orm_user_by_name(session, name=user.name) + except UnknownUser: + orm_user = orm.User(id=uuid.uuid4(), name=user.name) + session.add(orm_user) + session.commit() + + def add_api_key( + self, session: Session, *, user: str, api_key: schemas.ApiKey + ) -> None: + user_id = self._get_orm_user_by_name(session, name=user).id + orm_api_key = orm.ApiKey( + id=uuid.uuid4(), + user_id=user_id, + name=api_key.name, + value=api_key.value, + expires_at=api_key.expires_at, + ) + session.add(orm_api_key) + session.commit() + + def get_api_keys(self, session: Session, *, user: str) -> list[schemas.ApiKey]: + return [ + self._to_schema.api_key(api_key) + for api_key in session.execute( + select(orm.ApiKey).where( + orm.ApiKey.user_id + == self._get_orm_user_by_name(session, name=user).id + ) + ) + .scalars() + .all() + ] + + def delete_api_key(self, session: Session, *, user: str, id: uuid.UUID) -> None: + orm_api_key = session.execute( + select(orm.ApiKey).where( + (orm.ApiKey.id == id) + & ( + orm.ApiKey.user_id + == self._get_orm_user_by_name(session, name=user).id + ) + ) + ).scalar_one_or_none() + + if orm_api_key is None: + raise RagnaException + + session.delete(orm_api_key) # type: ignore[no-untyped-call] + session.commit() + + def get_user_by_api_key( + self, session: Session, api_key_value: str + ) -> Optional[tuple[schemas.User, schemas.ApiKey]]: + orm_api_key = session.execute( + select(orm.ApiKey) # type: ignore[attr-defined] + .options(joinedload(orm.ApiKey.user)) + .where(orm.ApiKey.value == api_key_value) + ).scalar_one_or_none() + + if orm_api_key is None: + return None + + return ( + self._to_schema.user(orm_api_key.user), + self._to_schema.api_key(orm_api_key), + ) + def add_documents( self, session: Session, @@ -49,7 +125,7 @@ def 
add_documents( user: str, documents: list[schemas.Document], ) -> None: - user_id = self._get_user(session, username=user).id + user_id = self._get_orm_user_by_name(session, name=user).id session.add_all( [self._to_orm.document(document, user_id=user_id) for document in documents] ) @@ -82,7 +158,7 @@ def get_documents( def add_chat(self, session: Session, *, user: str, chat: schemas.Chat) -> None: orm_chat = self._to_orm.chat( - chat, user_id=self._get_user(session, username=user).id + chat, user_id=self._get_orm_user_by_name(session, name=user).id ) # We need to merge and not add here, because the documents are already in the DB session.merge(orm_chat) @@ -102,7 +178,8 @@ def get_chats(self, session: Session, *, user: str) -> list[schemas.Chat]: self._to_schema.chat(chat) for chat in session.execute( self._select_chat(eager=True).where( - orm.Chat.user_id == self._get_user(session, username=user).id + orm.Chat.user_id + == self._get_orm_user_by_name(session, name=user).id ) ) .scalars() @@ -117,7 +194,10 @@ def _get_orm_chat( session.execute( self._select_chat(eager=eager).where( (orm.Chat.id == id) - & (orm.Chat.user_id == self._get_user(session, username=user).id) + & ( + orm.Chat.user_id + == self._get_orm_user_by_name(session, name=user).id + ) ) ) .unique() @@ -134,7 +214,7 @@ def get_chat(self, session: Session, *, user: str, id: uuid.UUID) -> schemas.Cha def update_chat(self, session: Session, user: str, chat: schemas.Chat) -> None: orm_chat = self._to_orm.chat( - chat, user_id=self._get_user(session, username=user).id + chat, user_id=self._get_orm_user_by_name(session, name=user).id ) session.merge(orm_chat) session.commit() @@ -199,6 +279,18 @@ def chat( class OrmToSchemaConverter: + def user(self, user: orm.User) -> schemas.User: + return schemas.User(name=user.name) + + def api_key(self, api_key: orm.ApiKey) -> schemas.ApiKey: + return schemas.ApiKey( + id=api_key.id, + name=api_key.name, + expires_at=api_key.expires_at, + obfuscated=True, + 
value=api_key.value, + ) + def document(self, document: orm.Document) -> schemas.Document: return schemas.Document( id=document.id, name=document.name, metadata=document.metadata_ diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 6df48460..7d668c78 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -1,16 +1,17 @@ +import secrets import uuid from typing import Any, AsyncIterator, Optional, Type, cast from fastapi import status as http_status_code import ragna -from ragna import Rag, core +from ragna import core from ragna._utils import make_directory -from ragna.core import RagnaException +from ragna.core import Rag, RagnaException from ragna.core._rag import SpecialChatParams -from ragna.deploy import Config from . import _schemas as schemas +from ._config import Config from ._database import Database @@ -33,6 +34,47 @@ def __init__(self, *, config: Config, ignore_unavailable_components: bool) -> No self._to_core = SchemaToCoreConverter(config=self._config, rag=self._rag) self._to_schema = CoreToSchemaConverter() + def maybe_add_user(self, user: schemas.User) -> None: + with self._database.get_session() as session: + return self._database.maybe_add_user(session, user=user) + + def get_user_by_api_key( + self, api_key_value: str + ) -> tuple[Optional[schemas.User], bool]: + with self._database.get_session() as session: + data = self._database.get_user_by_api_key( + session, api_key_value=api_key_value + ) + + if data is None: + return None, False + + user, api_key = data + return user, api_key.expired + + def create_api_key( + self, user: str, api_key_creation: schemas.ApiKeyCreation + ) -> schemas.ApiKey: + api_key = schemas.ApiKey( + name=api_key_creation.name, + expires_at=api_key_creation.expires_at, + obfuscated=False, + value=secrets.token_urlsafe(32)[:32], + ) + + with self._database.get_session() as session: + self._database.add_api_key(session, user=user, api_key=api_key) + + return api_key + + def list_api_keys(self, 
user: str) -> list[schemas.ApiKey]: + with self._database.get_session() as session: + return self._database.get_api_keys(session, user=user) + + def delete_api_key(self, user: str, id: uuid.UUID) -> None: + with self._database.get_session() as session: + self._database.delete_api_key(session, user=user, id=id) + def _get_component_json_schema( self, component: Type[core.Component], diff --git a/ragna/deploy/_key_value_store.py b/ragna/deploy/_key_value_store.py new file mode 100644 index 00000000..8218aacc --- /dev/null +++ b/ragna/deploy/_key_value_store.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import abc +import os +import time +from typing import Any, Callable, Generic, Optional, TypeVar, Union, cast + +import pydantic + +from ragna.core import PackageRequirement, Requirement +from ragna.core._utils import RequirementsMixin + +M = TypeVar("M", bound=pydantic.BaseModel) + + +class SerializableModel(pydantic.BaseModel, Generic[M]): + cls: pydantic.ImportString[type[M]] + obj: dict[str, Any] + + @classmethod + def from_model(cls, model: M) -> SerializableModel[M]: + return SerializableModel(cls=type(model), obj=model.model_dump(mode="json")) + + def to_model(self) -> M: + return self.cls.model_validate(self.obj) + + +class KeyValueStore(abc.ABC, RequirementsMixin, Generic[M]): + def serialize(self, model: M) -> str: + return SerializableModel.from_model(model).model_dump_json() + + def deserialize(self, data: Union[str, bytes]) -> M: + return SerializableModel.model_validate_json(data).to_model() + + @abc.abstractmethod + def set( + self, key: str, model: M, *, expires_after: Optional[int] = None + ) -> None: ... + + @abc.abstractmethod + def get(self, key: str) -> Optional[M]: ... + + @abc.abstractmethod + def delete(self, key: str) -> None: ... + + @abc.abstractmethod + def refresh(self, key: str, *, expires_after: Optional[int] = None) -> None: ... 
+ + +class InMemoryKeyValueStore(KeyValueStore[M]): + def __init__(self) -> None: + self._store: dict[str, tuple[M, Optional[float]]] = {} + self._timer: Callable[[], float] = time.monotonic + + def set(self, key: str, model: M, *, expires_after: Optional[int] = None) -> None: + if expires_after is not None: + expires_at = self._timer() + expires_after + else: + expires_at = None + self._store[key] = (model, expires_at) + + def get(self, key: str) -> Optional[M]: + value = self._store.get(key) + if value is None: + return None + + model, expires_at = value + if expires_at is not None and self._timer() >= expires_at: + self.delete(key) + return None + + return model + + def delete(self, key: str) -> None: + if key not in self._store: + return + + del self._store[key] + + def refresh(self, key: str, *, expires_after: Optional[int] = None) -> None: + value = self._store.get(key) + if value is None: + return + + model, _ = value + self.set(key, model, expires_after=expires_after) + + +class RedisKeyValueStore(KeyValueStore[M]): + @classmethod + def requirements(cls) -> list[Requirement]: + return [PackageRequirement("redis")] + + def __init__(self) -> None: + import redis + + self._r = redis.Redis( + host=os.environ.get("RAGNA_REDIS_HOST", "localhost"), + port=int(os.environ.get("RAGNA_REDIS_PORT", 6379)), + ) + + def set(self, key: str, model: M, *, expires_after: Optional[int] = None) -> None: + self._r.set(key, self.serialize(model), ex=expires_after) + + def get(self, key: str) -> Optional[M]: + value = cast(bytes, self._r.get(key)) + if value is None: + return None + return self.deserialize(value) + + def delete(self, key: str) -> None: + self._r.delete(key) + + def refresh(self, key: str, *, expires_after: Optional[int] = None) -> None: + if expires_after is None: + self._r.persist(key) + else: + self._r.expire(key, expires_after) diff --git a/ragna/deploy/_orm.py b/ragna/deploy/_orm.py index 74f3560f..3017e660 100644 --- a/ragna/deploy/_orm.py +++ 
b/ragna/deploy/_orm.py @@ -31,7 +31,7 @@ def process_result_value( return json.loads(value) -class UtcDateTime(types.TypeDecorator): +class UtcAwareDateTime(types.TypeDecorator): """UTC timezone aware datetime type. This is needed because sqlalchemy.types.DateTime(timezone=True) does not @@ -63,14 +63,25 @@ class Base(DeclarativeBase): pass -# FIXME: Do we actually need this table? If we are sure that a user is unique and has to -# be authenticated from the API layer, it seems having an extra mapping here is not -# needed? class User(Base): __tablename__ = "users" id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined] + name = Column(types.String, nullable=False, unique=True) + api_keys = relationship("ApiKey", back_populates="user") + + +class ApiKey(Base): + __tablename__ = "api_keys" + + id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined] + + user_id = Column(ForeignKey("users.id")) + user = relationship("User", back_populates="api_keys") + name = Column(types.String, nullable=False) + value = Column(types.String, nullable=False, unique=True) + expires_at = Column(UtcAwareDateTime, nullable=False) document_chat_association_table = Table( @@ -163,4 +174,5 @@ class Message(Base): secondary=source_message_association_table, back_populates="messages", ) - timestamp = Column(UtcDateTime, nullable=False) + + timestamp = Column(UtcAwareDateTime, nullable=False) diff --git a/ragna/deploy/_schemas.py b/ragna/deploy/_schemas.py index 3c080f43..0b73aa23 100644 --- a/ragna/deploy/_schemas.py +++ b/ragna/deploy/_schemas.py @@ -4,11 +4,61 @@ from datetime import datetime, timezone from typing import Annotated, Any -from pydantic import AfterValidator, BaseModel, Field +from pydantic import ( + AfterValidator, + BaseModel, + Field, + ValidationInfo, + computed_field, + field_validator, +) import ragna.core +class User(BaseModel): + name: str + data: dict[str, Any] = Field(default_factory=dict) + + +class ApiKeyCreation(BaseModel): + name: str 
+ expires_at: datetime + + +class ApiKey(BaseModel): + id: uuid.UUID = Field(default_factory=uuid.uuid4) + name: str + expires_at: datetime + obfuscated: bool = True + value: str + + @field_validator("expires_at") + @classmethod + def _set_utc_timezone(cls, v: datetime) -> datetime: + if v.tzinfo is None: + return v.replace(tzinfo=timezone.utc) + else: + return v.astimezone(timezone.utc) + + @computed_field # type: ignore[misc] + @property + def expired(self) -> bool: + return datetime.now(timezone.utc) >= self.expires_at + + @field_validator("value") + @classmethod + def _maybe_obfuscate(cls, v: str, info: ValidationInfo) -> str: + if not info.data["obfuscated"]: + return v + + i = min(len(v) // 6, 3) + if i > 0: + return f"{v[:i]}***{v[-i:]}" + else: + return "***" + + def _set_utc_timezone(v: datetime) -> datetime: if v.tzinfo is None: return v.replace(tzinfo=timezone.utc) diff --git a/ragna/deploy/_templates/__init__.py b/ragna/deploy/_templates/__init__.py new file mode 100644 index 00000000..7c977794 --- /dev/null +++ b/ragna/deploy/_templates/__init__.py @@ -0,0 +1,16 @@ +import contextlib +from pathlib import Path +from typing import Any + +from jinja2 import Environment, FileSystemLoader, TemplateNotFound + +ENVIRONMENT = Environment(loader=FileSystemLoader(Path(__file__).parent)) + + +def render(template: str, **context: Any) -> str: + with contextlib.suppress(TemplateNotFound): + css_template = ENVIRONMENT.get_template(str(Path(template).with_suffix(".css"))) + context["__template_css__"] = css_template.render(**context) + + template = ENVIRONMENT.get_template(template) + return template.render(**context) diff --git a/ragna/deploy/_templates/base.html b/ragna/deploy/_templates/base.html new file mode 100644 index 00000000..e4f61d54 --- /dev/null +++ b/ragna/deploy/_templates/base.html @@ -0,0 +1,49 @@ + + + + + Ragna + + + + + + + + +
+ {% block content %}{% endblock %} +
+ + diff --git a/ragna/deploy/_templates/basic_auth.css b/ragna/deploy/_templates/basic_auth.css new file mode 100644 index 00000000..80a7b918 --- /dev/null +++ b/ragna/deploy/_templates/basic_auth.css @@ -0,0 +1,6 @@ +.basic-auth { + height: 100%; + display: flex; + flex-direction: column; + justify-content: space-between; +} diff --git a/ragna/deploy/_templates/basic_auth.html b/ragna/deploy/_templates/basic_auth.html new file mode 100644 index 00000000..5269334f --- /dev/null +++ b/ragna/deploy/_templates/basic_auth.html @@ -0,0 +1,33 @@ +{% extends "base.html" %} {% block content %} +
+

Log in

+ {% if fail_reason %} +
{{ fail_reason }}
+ {% endif %} +
+
+ Username + +
+
+ Password + +
+
+
+ +
+
+{% endblock %} diff --git a/ragna/deploy/_templates/oauth.html b/ragna/deploy/_templates/oauth.html new file mode 100644 index 00000000..7a21f0cd --- /dev/null +++ b/ragna/deploy/_templates/oauth.html @@ -0,0 +1,7 @@ +{% extends "base.html" %} {% block content %} + +{% endblock %} diff --git a/ragna/deploy/_ui/api_wrapper.py b/ragna/deploy/_ui/api_wrapper.py index 9dcc6b16..b03f2232 100644 --- a/ragna/deploy/_ui/api_wrapper.py +++ b/ragna/deploy/_ui/api_wrapper.py @@ -2,9 +2,9 @@ from datetime import datetime import emoji +import panel as pn import param -from ragna.core._utils import default_user from ragna.deploy import _schemas as schemas from ragna.deploy._engine import Engine @@ -12,7 +12,7 @@ class ApiWrapper(param.Parameterized): def __init__(self, engine: Engine): super().__init__() - self._user = default_user() + self._user = pn.state.user self._engine = engine async def get_chats(self): diff --git a/ragna/deploy/_ui/left_sidebar.py b/ragna/deploy/_ui/left_sidebar.py index 267a3b77..3d45849f 100644 --- a/ragna/deploy/_ui/left_sidebar.py +++ b/ragna/deploy/_ui/left_sidebar.py @@ -14,6 +14,7 @@ class LeftSidebar(pn.viewable.Viewer): def __init__(self, api_wrapper, **params): super().__init__(**params) + self.api_wrapper = api_wrapper self.on_click_chat = None self.on_click_new_chat = None @@ -104,6 +105,7 @@ def __panel__(self): + self.chat_buttons + [ pn.layout.VSpacer(), + pn.pane.HTML(f"user: {self.api_wrapper._user}"), pn.pane.HTML(f"version: {ragna_version}"), # self.footer() ] diff --git a/ragna/source_storages/_vector_database.py b/ragna/source_storages/_vector_database.py index 81ec2df5..3ed5ca35 100644 --- a/ragna/source_storages/_vector_database.py +++ b/ragna/source_storages/_vector_database.py @@ -89,7 +89,7 @@ def _chunk_pages( ): tokens, page_numbers = zip(*window) yield Chunk( - text=self._tokenizer.decode(tokens), # type: ignore[arg-type] + text=self._tokenizer.decode(tokens), page_numbers=list(filter(lambda n: n is not None, page_numbers)) 
or None, num_tokens=len(tokens), diff --git a/scripts/add_chats.py b/scripts/add_chats.py index 5f550289..14d8827a 100644 --- a/scripts/add_chats.py +++ b/scripts/add_chats.py @@ -8,25 +8,14 @@ def main(): client = httpx.Client(base_url="http://127.0.0.1:31476") client.get("/health").raise_for_status() - # ## authentication - # - # username = default_user() - # token = ( - # client.post( - # "/token", - # data={ - # "username": username, - # "password": os.environ.get( - # "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username - # ), - # }, - # ) - # .raise_for_status() - # .json() - # ) - # client.headers["Authorization"] = f"Bearer {token}" - - print() + ## authentication + + # This only works if Ragna was deployed with ragna.core.NoAuth + # If that is not the case, login in whatever way is required, grab the API token and + # use the following instead + # client.headers["Authorization"] = f"Bearer {api_token}" + + client.get("/login", follow_redirects=True).raise_for_status() ## documents diff --git a/scripts/docs/gen_files.py b/scripts/docs/gen_files.py index a350f035..42bd6107 100644 --- a/scripts/docs/gen_files.py +++ b/scripts/docs/gen_files.py @@ -7,14 +7,14 @@ import mkdocs_gen_files import typer.rich_utils +from ragna._cli import app as cli_app # noqa: E402 from ragna.deploy import Config -from ragna.deploy._api import app as api_app -from ragna.deploy._cli import app as cli_app +from ragna.deploy._core import make_app as make_deploy_app def main(): cli_reference() - api_reference() + deploy_reference() config_reference() @@ -43,8 +43,14 @@ def get_doc(command): file.write(get_doc(command.name or command.callback.__name__)) -def api_reference(): - app = api_app(config=Config(), ignore_unavailable_components=False) +def deploy_reference(): + app = make_deploy_app( + config=Config(), + api=True, + ui=True, + ignore_unavailable_components=False, + open_browser=False, + ) openapi_json = fastapi.openapi.utils.get_openapi( title=app.title, version=app.version, diff 
--git a/tests/assistants/test_api.py b/tests/assistants/test_api.py index f7c9c594..3f10089b 100644 --- a/tests/assistants/test_api.py +++ b/tests/assistants/test_api.py @@ -2,17 +2,16 @@ import itertools import json import os -import time from pathlib import Path import httpx import pytest from ragna import assistants -from ragna._utils import timeout_after +from ragna._utils import BackgroundSubprocess from ragna.assistants._http_api import HttpApiAssistant, HttpStreamingProtocol from ragna.core import Message, RagnaException -from tests.utils import background_subprocess, get_available_port, skip_on_windows +from tests.utils import get_available_port, skip_on_windows HTTP_API_ASSISTANTS = [ assistant @@ -43,26 +42,19 @@ def streaming_server(): port = get_available_port() base_url = f"http://localhost:{port}" - with background_subprocess( + def check_fn(): + try: + return httpx.get(f"{base_url}/health").is_success + except httpx.ConnectError: + return False + + with BackgroundSubprocess( "uvicorn", f"--app-dir={Path(__file__).parent}", f"--port={port}", "streaming_server:app", + startup_fn=check_fn, ): - - def up(): - try: - return httpx.get(f"{base_url}/health").is_success - except httpx.ConnectError: - return False - - @timeout_after(10, message="Failed to start streaming server") - def wait(): - while not up(): - time.sleep(0.2) - - wait() - yield base_url diff --git a/tests/conftest.py b/tests/conftest.py index 6f380be2..4f44218e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import ragna -@pytest.fixture +@pytest.fixture(autouse=True) def tmp_local_root(tmp_path): old = ragna.local_root() try: diff --git a/tests/deploy/api/test_components.py b/tests/deploy/api/test_components.py index 8e23cdeb..9ccdb2ba 100644 --- a/tests/deploy/api/test_components.py +++ b/tests/deploy/api/test_components.py @@ -1,11 +1,10 @@ import pytest from fastapi import status -from fastapi.testclient import TestClient from ragna import assistants from 
ragna.core import RagnaException from ragna.deploy import Config -from tests.deploy.utils import authenticate_with_api, make_api_app +from tests.deploy.utils import make_api_app, make_api_client @pytest.mark.parametrize("ignore_unavailable_components", [True, False]) @@ -19,14 +18,10 @@ def test_ignore_unavailable_components(ignore_unavailable_components): config = Config(assistants=[available_assistant, unavailable_assistant]) if ignore_unavailable_components: - with TestClient( - make_api_app( - config=config, - ignore_unavailable_components=ignore_unavailable_components, - ) + with make_api_client( + config=config, + ignore_unavailable_components=ignore_unavailable_components, ) as client: - authenticate_with_api(client) - components = client.get("/api/components").raise_for_status().json() assert [assistant["title"] for assistant in components["assistants"]] == [ available_assistant.display_name() @@ -61,11 +56,9 @@ def test_unknown_component(tmp_local_root): with open(document_path, "w") as file: file.write("!\n") - with TestClient( - make_api_app(config=Config(), ignore_unavailable_components=False) + with make_api_client( + config=Config(), ignore_unavailable_components=False ) as client: - authenticate_with_api(client) - document = ( client.post("/api/documents", json=[{"name": document_path.name}]) .raise_for_status() diff --git a/tests/deploy/api/test_e2e.py b/tests/deploy/api/test_e2e.py index e023c0ee..fa342d91 100644 --- a/tests/deploy/api/test_e2e.py +++ b/tests/deploy/api/test_e2e.py @@ -1,10 +1,9 @@ import json import pytest -from fastapi.testclient import TestClient from ragna.deploy import Config -from tests.deploy.utils import TestAssistant, authenticate_with_api, make_api_app +from tests.deploy.utils import TestAssistant, make_api_client from tests.utils import skip_on_windows @@ -20,11 +19,7 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): with open(document_path, "w") as file: file.write("!\n") - with TestClient( - 
make_api_app(config=config, ignore_unavailable_components=False) - ) as client: - authenticate_with_api(client) - + with make_api_client(config=config, ignore_unavailable_components=False) as client: assert client.get("/api/chats").raise_for_status().json() == [] documents = ( diff --git a/tests/deploy/api/utils.py b/tests/deploy/api/utils.py new file mode 100644 index 00000000..6dd3fb78 --- /dev/null +++ b/tests/deploy/api/utils.py @@ -0,0 +1,35 @@ +import os + +from fastapi.testclient import TestClient + +from ragna._utils import default_user +from ragna.deploy._core import make_app + + +def make_api_app(*, config, ignore_unavailable_components): + return make_app( + config, + api=True, + ui=False, + ignore_unavailable_components=ignore_unavailable_components, + open_browser=False, + ) + + +def authenticate(client: TestClient) -> None: + return + username = default_user() + token = ( + client.post( + "/token", + data={ + "username": username, + "password": os.environ.get( + "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username + ), + }, + ) + .raise_for_status() + .json() + ) + client.headers["Authorization"] = f"Bearer {token}" diff --git a/tests/deploy/utils.py b/tests/deploy/utils.py index f8d1277a..e30fa758 100644 --- a/tests/deploy/utils.py +++ b/tests/deploy/utils.py @@ -1,10 +1,10 @@ -import os +import contextlib import time from fastapi.testclient import TestClient from ragna.assistants import RagnaDemoAssistant -from ragna.core._utils import default_user +from ragna.deploy._auth import SessionMiddleware from ragna.deploy._core import make_app @@ -38,19 +38,17 @@ def make_api_app(*, config, ignore_unavailable_components): def authenticate_with_api(client: TestClient) -> None: - return - username = default_user() - token = ( - client.post( - "/token", - data={ - "username": username, - "password": os.environ.get( - "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username - ), - }, + client.get("/login", follow_redirects=True).raise_for_status() + assert 
SessionMiddleware._COOKIE_NAME in client.cookies + + +@contextlib.contextmanager +def make_api_client(*, config, ignore_unavailable_components): + with TestClient( + make_api_app( + config=config, + ignore_unavailable_components=ignore_unavailable_components, ) - .raise_for_status() - .json() - ) - client.headers["Authorization"] = f"Bearer {token}" + ) as client: + authenticate_with_api(client) + yield client diff --git a/tests/utils.py b/tests/utils.py index 3fc31732..87081f65 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,5 @@ -import contextlib import platform import socket -import subprocess -import sys import pytest @@ -11,16 +8,6 @@ ) -@contextlib.contextmanager -def background_subprocess(*args, stdout=sys.stdout, stderr=sys.stdout, **kwargs): - process = subprocess.Popen(args, stdout=stdout, stderr=stderr, **kwargs) - try: - yield process - finally: - process.kill() - process.communicate() - - def get_available_port(): with socket.socket() as s: s.bind(("", 0)) From d444255a53090252f014c404c65db3fe42011b32 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 10:42:01 +0100 Subject: [PATCH 10/29] fix env name in CI --- .github/actions/setup-env/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index b39f4b18..a527f590 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -17,7 +17,7 @@ runs: uses: conda-incubator/setup-miniconda@v3 with: miniforge-version: latest - activate-environment: ragna-deploy-dev + activate-environment: ragna-dev - name: Display conda info shell: bash -el {0} From 0d1e09b07591929be5e1c0accb4e04f5132a2c80 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 11:33:21 +0100 Subject: [PATCH 11/29] fix docs --- docs/assets/images/web-ui-login.png | 3 --- docs/examples/gallery_streaming.py | 2 +- docs/references/config.md | 8 ++++++++ 
docs/tutorials/gallery_custom_components.py | 9 +++++---- docs/tutorials/gallery_rest_api.py | 12 ++++++------ docs/tutorials/gallery_web_ui.py | 12 ++---------- ragna/_docs.py | 8 ++++---- ragna/_utils.py | 12 ++++++------ 8 files changed, 32 insertions(+), 34 deletions(-) delete mode 100644 docs/assets/images/web-ui-login.png diff --git a/docs/assets/images/web-ui-login.png b/docs/assets/images/web-ui-login.png deleted file mode 100644 index 0fc87981..00000000 --- a/docs/assets/images/web-ui-login.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:35f2191ef348b1fb470ed11e6722bc8b10ce5d268e3f245af13f2b987746c16b -size 9246 diff --git a/docs/examples/gallery_streaming.py b/docs/examples/gallery_streaming.py index 846897e6..27af3e24 100644 --- a/docs/examples/gallery_streaming.py +++ b/docs/examples/gallery_streaming.py @@ -110,7 +110,7 @@ def answer(self, messages): ragna_deploy = ragna_docs.RagnaDeploy(config) client, document = ragna_deploy.get_http_client( - authenticate=True, upload_document=True + authenticate=True, upload_sample_document=True ) # %% diff --git a/docs/references/config.md b/docs/references/config.md index ba83263a..53fd9ef3 100644 --- a/docs/references/config.md +++ b/docs/references/config.md @@ -73,6 +73,10 @@ Local root directory Ragna uses for storing files. See [ragna.local_root][]. [ragna.deploy.Auth][] class to use for authenticating users. +### `key_value_store` + +[ragna.deploy.KeyValueStore][] class to use for temporary storage. + ### `document` [ragna.core.Document][] class to use to upload and read documents. @@ -103,6 +107,10 @@ external clients. [CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) origins that are allowed to connect to the REST API. +### `session_lifetime` + +Number of seconds of inactivity after a user has to login again. + ### `database_url` URL of a SQL database that will be used to store the Ragna state. 
See diff --git a/docs/tutorials/gallery_custom_components.py b/docs/tutorials/gallery_custom_components.py index 6c556043..22b43604 100644 --- a/docs/tutorials/gallery_custom_components.py +++ b/docs/tutorials/gallery_custom_components.py @@ -189,7 +189,7 @@ def answer(self, messages: list[Message]) -> Iterator[str]: ragna_deploy = ragna_docs.RagnaDeploy(config) client, document = ragna_deploy.get_http_client( - authenticate=True, upload_document=True + authenticate=True, upload_sample_document=True ) # %% @@ -205,7 +205,7 @@ def answer(self, messages: list[Message]) -> Iterator[str]: response = client.post( "/api/chats", json={ - "name": "Tutorial REST API", + "name": "Tutorial Custom Components", "document_ids": [document["id"]], "source_storage": TutorialSourceStorage.display_name(), "assistant": TutorialAssistant.display_name(), @@ -322,7 +322,7 @@ def answer( ragna_deploy = ragna_docs.RagnaDeploy(config) client, document = ragna_deploy.get_http_client( - authenticate=True, upload_document=True + authenticate=True, upload_sample_document=True ) # %% @@ -332,7 +332,7 @@ def answer( response = client.post( "/api/chats", json={ - "name": "Tutorial REST API", + "name": "Tutorial Elaborate Custom Components", "document_ids": [document["id"]], "source_storage": TutorialSourceStorage.display_name(), "assistant": ElaborateTutorialAssistant.display_name(), @@ -343,6 +343,7 @@ def answer( }, ).raise_for_status() chat = response.json() +print(json.dumps(chat, indent=2)) # %% diff --git a/docs/tutorials/gallery_rest_api.py b/docs/tutorials/gallery_rest_api.py index ede8833d..77186562 100644 --- a/docs/tutorials/gallery_rest_api.py +++ b/docs/tutorials/gallery_rest_api.py @@ -14,7 +14,7 @@ # Ragnas REST API is normally started from a terminal with # # ```bash -# $ ragna api +# $ ragna deploy # ``` # # For this tutorial we use our helper that does the equivalent just from Python. @@ -99,8 +99,8 @@ # %% # The upload process in Ragna consists of two parts: # -# 1. 
Announce the file to be uploaded. Under the hood this registers the document -# in Ragna's database and returns the document ID, which is needed for the upload. +# 1. Register the document in Ragna's database. This returns the document ID, which is +# needed for the upload. response = client.post( "/api/documents", json=[{"name": document_path.name}] @@ -109,7 +109,7 @@ print(json.dumps(documents, indent=2)) # %% -# 2. Perform the actual upload with the information from step 1. through a +# 2. Perform the upload through a # [multipart request](https://swagger.io/docs/specification/describing-request-body/multipart-requests/) # with the following parameters: # @@ -154,8 +154,8 @@ print(json.dumps(chat, indent=2)) # %% -# As can be seen by the `"prepared"` field in the `chat` JSON object we still need to -# prepare it. +# As can be seen by the `"prepared": false` value in the `chat` JSON object we still +# need to prepare it. client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() diff --git a/docs/tutorials/gallery_web_ui.py b/docs/tutorials/gallery_web_ui.py index 51ff57a5..f51ce7e7 100644 --- a/docs/tutorials/gallery_web_ui.py +++ b/docs/tutorials/gallery_web_ui.py @@ -13,15 +13,7 @@ # You can launch the web application from the CLI: # # ```bash -# ragna ui +# ragna deploy # ``` # -# It will open in a browser window automatically and you should see the login screen: -# -# ![ragna UI login screen with a username and password form. Both the username and -# password are filled with the string "ragna" and the password is visible and not -# censored](../../assets/images/web-ui-login.png) -# -# By default [ragna.deploy.RagnaDemoAuthentication][] is used. Thus, you can use any -# username and a matching password to authenticate. For example, you can leave both fields -# blank or use `ragna` / `ragna` as in the image above. +# The UI will open in a browser window automatically. 
diff --git a/ragna/_docs.py b/ragna/_docs.py index d06fd215..a9a3977c 100644 --- a/ragna/_docs.py +++ b/ragna/_docs.py @@ -113,12 +113,12 @@ def get_http_client( self, *, authenticate: bool = False, - upload_document: bool = False, + upload_sample_document: bool = False, ) -> tuple[httpx.Client, Optional[dict[str, Any]]]: - if upload_document and not authenticate: + if upload_sample_document and not authenticate: raise RagnaException( "Cannot upload a document without authenticating first. " - "Set authenticate=True when using upload_document=True." + "Set authenticate=True when using upload_sample_document=True." ) client = httpx.Client(base_url=self.config._url) @@ -126,7 +126,7 @@ def get_http_client( if authenticate: client.get("/login", follow_redirects=True) - if upload_document: + if upload_sample_document: name, content = "ragna.txt", SAMPLE_CONTENT response = client.post( diff --git a/ragna/_utils.py b/ragna/_utils.py index 32bb24c5..952b4d20 100644 --- a/ragna/_utils.py +++ b/ragna/_utils.py @@ -175,14 +175,16 @@ def __init__( *cmd: str, stdout: Any = sys.stdout, stderr: Any = sys.stdout, - text: bool = True, startup_fn: Optional[Callable[[], bool]] = None, startup_timeout: float = 10, terminate_timeout: float = 10, + text: bool = True, **subprocess_kwargs: Any, ) -> None: + self._terminate_timeout = terminate_timeout + self._process = subprocess.Popen( - cmd, stdout=stdout, stderr=stderr, **subprocess_kwargs + cmd, stdout=stdout, stderr=stderr, text=text, **subprocess_kwargs ) try: if startup_fn: @@ -197,11 +199,9 @@ def wait() -> None: self.terminate() raise - self._terminate_timeout = terminate_timeout - - def terminate(self) -> tuple[str, str]: + def terminate(self) -> tuple[str | bytes, str | bytes]: @timeout_after(self._terminate_timeout) - def terminate() -> tuple[str, str]: + def terminate() -> tuple[str | bytes, str | bytes]: self._process.terminate() return self._process.communicate() From dc7595d18dde0a0a78fabfb3e6609a86245c6ba4 Mon Sep 17 
00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 11:48:56 +0100 Subject: [PATCH 12/29] debug windows CI failure --- .github/actions/setup-env/action.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index a527f590..09fbd262 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -54,7 +54,10 @@ runs: - name: Install playwright shell: bash -el {0} - run: playwright install + run: | + ls -la "${CONDA_PREFIX}/bin" + python -m playwright --help + playwright install - name: Install ragna shell: bash -el {0} From 1cbe83a860b85431723cbf25f6d54797a8ffb9af Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 11:58:06 +0100 Subject: [PATCH 13/29] Revert "debug windows CI failure" This reverts commit dc7595d18dde0a0a78fabfb3e6609a86245c6ba4. --- .github/actions/setup-env/action.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index 09fbd262..a527f590 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -54,10 +54,7 @@ runs: - name: Install playwright shell: bash -el {0} - run: | - ls -la "${CONDA_PREFIX}/bin" - python -m playwright --help - playwright install + run: playwright install - name: Install ragna shell: bash -el {0} From 57658742a920bbf554527d1be723f2d4889ffa77 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 12:00:28 +0100 Subject: [PATCH 14/29] disable UI tests completely --- .github/actions/setup-env/action.yml | 6 +++--- pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index a527f590..d68ce95e 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -52,9 +52,9 @@ runs: mamba env update --file 
environment-dev.yml git checkout -- environment-dev.yml - - name: Install playwright - shell: bash -el {0} - run: playwright install + # - name: Install playwright + # shell: bash -el {0} + # run: playwright install - name: Install ragna shell: bash -el {0} diff --git a/pyproject.toml b/pyproject.toml index 83594126..d7832f0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ ignore = ["E501"] [tool.pytest.ini_options] minversion = "6.0" -addopts = "-ra --tb=short --asyncio-mode=auto" +addopts = "-ra --tb=short --asyncio-mode=auto --ignore tests/deploy/ui" asyncio_default_fixture_loop_scope = "function" testpaths = [ "tests", From 388e676e1a26c72e091f564cfc457374e925b72d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 12:16:49 +0100 Subject: [PATCH 15/29] trigger CI From b540c9e10d2c184ad8be1823c0d8385bf33c75ed Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 12:22:55 +0100 Subject: [PATCH 16/29] disable cache restore keys --- .github/actions/setup-env/action.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index d68ce95e..fc349832 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -40,8 +40,6 @@ runs: env-${{ runner.os }}-${{ runner.arch }}-${{ inputs.python-version }}|${{steps.cache-key.outputs.date }}-${{ hashFiles('environment-dev.yml', 'pyproject.toml') }} - restore-keys: | - env-${{ runner.os }}-${{ runner.arch }}-${{ inputs.python-version }} - name: Update conda environment if necessary if: steps.cache.outputs.cache-hit != 'true' From 26da321b3c7801999b11c234c3370900500712a6 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 12:27:40 +0100 Subject: [PATCH 17/29] no progress bars in CI --- .github/actions/setup-env/action.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/actions/setup-env/action.yml 
b/.github/actions/setup-env/action.yml index fc349832..f55d584d 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -47,7 +47,7 @@ runs: run: | sed -i'' -e 's/python *= *[0-9.]\+/python =${{ inputs.python-version }}/g' environment-dev.yml cat environment-dev.yml - mamba env update --file environment-dev.yml + mamba env update --quiet --file environment-dev.yml git checkout -- environment-dev.yml # - name: Install playwright @@ -63,7 +63,7 @@ runs: else PROJECT_PATH='.' fi - pip install --verbose --editable "${PROJECT_PATH}" + pip install --verbose --progress-bar=off --editable "${PROJECT_PATH}" - name: Display development environment shell: bash -el {0} From 3c363f095e7543286745ebcfd710a5e9155ddb68 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 12:32:06 +0100 Subject: [PATCH 18/29] pipefail --- .github/actions/setup-env/action.yml | 6 +++--- .github/workflows/docker.yml | 2 +- .github/workflows/lint.yml | 6 +++--- .github/workflows/test.yml | 6 +++--- .github/workflows/update-docker-requirements.yml | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index f55d584d..f7b18aaf 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -20,7 +20,7 @@ runs: activate-environment: ragna-dev - name: Display conda info - shell: bash -el {0} + shell: bash -elo pipefail {0} run: conda info - name: Set cache date @@ -55,7 +55,7 @@ runs: # run: playwright install - name: Install ragna - shell: bash -el {0} + shell: bash -elo pipefail {0} run: | if [[ ${{ inputs.optional-dependencies }} == true ]] then @@ -66,5 +66,5 @@ runs: pip install --verbose --progress-bar=off --editable "${PROJECT_PATH}" - name: Display development environment - shell: bash -el {0} + shell: bash -elo pipefail {0} run: conda list diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 
5e26e6f9..7bd4b8f0 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -28,7 +28,7 @@ jobs: runs-on: ubuntu-latest defaults: run: - shell: bash -el {0} + shell: bash -elo pipefail {0} steps: - name: Checkout repository diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 44264ad2..b38db3bb 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest defaults: run: - shell: bash -el {0} + shell: bash -elo pipefail {0} steps: - name: Checkout repository @@ -33,7 +33,7 @@ jobs: runs-on: ubuntu-latest defaults: run: - shell: bash -el {0} + shell: bash -elo pipefail {0} steps: - name: Checkout repository @@ -63,7 +63,7 @@ jobs: runs-on: ubuntu-latest defaults: run: - shell: bash -el {0} + shell: bash -elo pipefail {0} steps: - name: Checkout repository diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1fd5080f..0a0b33f6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest defaults: run: - shell: bash -el {0} + shell: bash -elo pipefail {0} steps: - name: Checkout repository @@ -53,7 +53,7 @@ jobs: defaults: run: - shell: bash -el {0} + shell: bash -elo pipefail {0} steps: - name: Checkout repository @@ -113,7 +113,7 @@ jobs: # # defaults: # run: -# shell: bash -el {0} +# shell: bash -elo pipefail {0} # # steps: # - name: Checkout repository diff --git a/.github/workflows/update-docker-requirements.yml b/.github/workflows/update-docker-requirements.yml index a1a2c860..be1989ec 100644 --- a/.github/workflows/update-docker-requirements.yml +++ b/.github/workflows/update-docker-requirements.yml @@ -19,7 +19,7 @@ jobs: runs-on: ubuntu-latest defaults: run: - shell: bash -el {0} + shell: bash -elo pipefail {0} steps: - name: Checkout repository From 769c645105a83dedb7b37926f0ea553a0c343772 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 12:40:23 +0100 Subject: 
[PATCH 19/29] improve cache key step --- .github/actions/setup-env/action.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index f7b18aaf..f0cb413c 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -13,7 +13,7 @@ runs: using: composite steps: - - name: Setup mambaforge and development environment + - name: Setup miniforge and empty development environment uses: conda-incubator/setup-miniconda@v3 with: miniforge-version: latest @@ -28,8 +28,7 @@ runs: shell: bash run: | DATE=$(date +'%Y%m%d') - echo $DATE - echo "DATE=$DATE" >> $GITHUB_OUTPUT + echo "date=${DATE}" | tee --append "${GITHUB_OUTPUT}" - name: Restore conda environment id: cache From 5b2217f8f1f54714e6b5ac141c37d957ba6c8a95 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 12:42:01 +0100 Subject: [PATCH 20/29] debug seg fault --- .github/workflows/test.yml | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0a0b33f6..4635a3ad 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -70,15 +70,18 @@ jobs: id: tests run: | pytest \ - --ignore tests/deploy/ui \ - --junit-xml=test-results.xml \ - --durations=25 + --co -q + +# --ignore tests/deploy/ui \ +# --junit-xml=test-results.xml \ +# --durations=25 +# +# - name: Surface failing tests +# if: steps.tests.outcome != 'success' +# uses: pmeier/pytest-results-action@v0.3.0 +# with: +# path: test-results.xml - - name: Surface failing tests - if: steps.tests.outcome != 'success' - uses: pmeier/pytest-results-action@v0.3.0 - with: - path: test-results.xml # pytest-ui: # strategy: # matrix: From bcfe59bdce3f113905a42aaaa54c581ddd76737e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 12:46:49 +0100 Subject: [PATCH 21/29] fix tee option for macos --- 
.github/actions/setup-env/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index f0cb413c..b545a504 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -28,7 +28,7 @@ runs: shell: bash run: | DATE=$(date +'%Y%m%d') - echo "date=${DATE}" | tee --append "${GITHUB_OUTPUT}" + echo "date=${DATE}" | tee -a "${GITHUB_OUTPUT}" - name: Restore conda environment id: cache From 19ff2476c70aa97394d93a627fe3ef2fb26193f1 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 12:48:07 +0100 Subject: [PATCH 22/29] debug --- .github/workflows/test.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4635a3ad..d05c9941 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -66,6 +66,15 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Debug 1 + run: which pytest + + - name: Debug 2 + run: pytest --version + + - name: Debug 3 + run: pytest --help + - name: Run unit tests id: tests run: | From 58ead9f8200b93c1ba93b58a638016e2c3e79979 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 12:50:55 +0100 Subject: [PATCH 23/29] try conda as installer instead of mamba --- .github/actions/setup-env/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index b545a504..36fafc78 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -46,7 +46,7 @@ runs: run: | sed -i'' -e 's/python *= *[0-9.]\+/python =${{ inputs.python-version }}/g' environment-dev.yml cat environment-dev.yml - mamba env update --quiet --file environment-dev.yml + conda env update --quiet --file environment-dev.yml git checkout -- environment-dev.yml # - name: Install playwright From 
49aabc71d3ccab6e773bcd58eb548ccf52334ba1 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 13:50:36 +0100 Subject: [PATCH 24/29] Revert "try conda as installer instead of mamba" This reverts commit 58ead9f8200b93c1ba93b58a638016e2c3e79979. --- .github/actions/setup-env/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index 36fafc78..b545a504 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -46,7 +46,7 @@ runs: run: | sed -i'' -e 's/python *= *[0-9.]\+/python =${{ inputs.python-version }}/g' environment-dev.yml cat environment-dev.yml - conda env update --quiet --file environment-dev.yml + mamba env update --quiet --file environment-dev.yml git checkout -- environment-dev.yml # - name: Install playwright From 509afb793bf5906fb276d99b5e4c316d160ac24f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 13:50:47 +0100 Subject: [PATCH 25/29] more debug --- .github/workflows/test.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d05c9941..061a97a2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -67,14 +67,21 @@ jobs: python-version: ${{ matrix.python-version }} - name: Debug 1 + if: always() run: which pytest - name: Debug 2 + if: always() run: pytest --version - name: Debug 3 + if: always() run: pytest --help + - name: Debug 4 + if: always() + run: ragna --help + - name: Run unit tests id: tests run: | From 1cd64be8e3fc67325bd86237956eb744fad7ece6 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 15:16:48 +0100 Subject: [PATCH 26/29] pin pymupdf --- environment-dev.yml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/environment-dev.yml b/environment-dev.yml index 1684b030..177682d3 100644 --- a/environment-dev.yml +++ 
b/environment-dev.yml @@ -2,7 +2,7 @@ name: ragna-dev channels: - conda-forge dependencies: - - python =3.10 + - python =3.11 - pip - git-lfs - pip: diff --git a/pyproject.toml b/pyproject.toml index d7832f0e..1a4a9429 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ all = [ "ijson", "lancedb>=0.2", "pyarrow", - "pymupdf>=1.23.6", + "pymupdf>=1.23.6,<=1.24.10", "python-docx", "python-pptx", "tiktoken", From 122312f70b88ef0d6c5727c43febf00e4b1dafda Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 15:23:24 +0100 Subject: [PATCH 27/29] lint --- pyproject.toml | 2 +- ragna/core/_document.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1a4a9429..3d4641e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ all = [ "ijson", "lancedb>=0.2", "pyarrow", - "pymupdf>=1.23.6,<=1.24.10", + "pymupdf<=1.24.10,>=1.23.6", "python-docx", "python-pptx", "tiktoken", diff --git a/ragna/core/_document.py b/ragna/core/_document.py index 624115f5..4868b270 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -223,7 +223,7 @@ class PdfDocumentHandler(DocumentHandler): @classmethod def requirements(cls) -> list[Requirement]: - return [PackageRequirement("pymupdf>=1.23.6")] + return [PackageRequirement("pymupdf<=1.24.10,>=1.23.6")] @classmethod def supported_suffixes(cls) -> list[str]: From e309c030ec362a9ce5bf4e30d8ce91594145d0a2 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 15:28:56 +0100 Subject: [PATCH 28/29] revert debug --- .github/workflows/test.yml | 35 ++++++++--------------------------- 1 file changed, 8 insertions(+), 27 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 061a97a2..d197cb7c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -66,38 +66,19 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Debug 1 - if: always() - run: which pytest - - - 
name: Debug 2 - if: always() - run: pytest --version - - - name: Debug 3 - if: always() - run: pytest --help - - - name: Debug 4 - if: always() - run: ragna --help - - name: Run unit tests id: tests run: | pytest \ - --co -q - -# --ignore tests/deploy/ui \ -# --junit-xml=test-results.xml \ -# --durations=25 -# -# - name: Surface failing tests -# if: steps.tests.outcome != 'success' -# uses: pmeier/pytest-results-action@v0.3.0 -# with: -# path: test-results.xml + --ignore tests/deploy/ui \ + --junit-xml=test-results.xml \ + --durations=25 + - name: Surface failing tests + if: steps.tests.outcome != 'success' + uses: pmeier/pytest-results-action@v0.3.0 + with: + path: test-results.xml # pytest-ui: # strategy: # matrix: From 37a38122cc5bff5ed29487af594ff3691fcf807a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 Dec 2024 15:48:50 +0100 Subject: [PATCH 29/29] cleanup --- environment-dev.yml | 2 +- pyproject.toml | 1 - ragna/deploy/_auth.py | 10 - ragna/deploy/_ui/central_view.py | 13 +- ragna/deploy/_ui/components/__init__.py | 0 ragna/deploy/_ui/components/file_uploader.py | 266 ------------------- ragna/deploy/_ui/resources/upload.js | 44 --- tests/deploy/api/utils.py | 35 --- tests/deploy/{api => }/conftest.py | 0 9 files changed, 6 insertions(+), 365 deletions(-) delete mode 100644 ragna/deploy/_ui/components/__init__.py delete mode 100644 ragna/deploy/_ui/components/file_uploader.py delete mode 100644 ragna/deploy/_ui/resources/upload.js delete mode 100644 tests/deploy/api/utils.py rename tests/deploy/{api => }/conftest.py (100%) diff --git a/environment-dev.yml b/environment-dev.yml index 177682d3..1684b030 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -2,7 +2,7 @@ name: ragna-dev channels: - conda-forge dependencies: - - python =3.11 + - python =3.10 - pip - git-lfs - pip: diff --git a/pyproject.toml b/pyproject.toml index 3d4641e0..2aa8980e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,6 @@ requires-python = ">=3.10" 
dependencies = [ "aiofiles", "emoji", - "eval_type_backport; python_version<'3.10'", "fastapi", "httpx", "packaging", diff --git a/ragna/deploy/_auth.py b/ragna/deploy/_auth.py index 06492bf4..e7a39b4d 100644 --- a/ragna/deploy/_auth.py +++ b/ragna/deploy/_auth.py @@ -377,16 +377,6 @@ async def login(self, request: Request) -> Union[schemas.User, Response]: user_data = (await client.get("https://api.github.com/user")).json() - organizations_data = ( - await client.get(user_data["organizations_url"]) - ).json() - organizations = { - organization_data["login"] for organization_data in organizations_data - } - if not (organizations & {"Quansight", "Quansight-Labs"}): - # FIXME: send the login page again with a failure message - return HTMLResponse("Unauthorized!") - return schemas.User(name=user_data["login"]) diff --git a/ragna/deploy/_ui/central_view.py b/ragna/deploy/_ui/central_view.py index 81ac9072..1a908e99 100644 --- a/ragna/deploy/_ui/central_view.py +++ b/ragna/deploy/_ui/central_view.py @@ -299,15 +299,12 @@ async def chat_callback( message.clipboard_button.value = message.content_pane.object message.assistant_toolbar.visible = True - except Exception as exc: - import traceback - + except Exception: yield RagnaChatMessage( - # ( - # "Sorry, something went wrong. " - # "If this problem persists, please contact your administrator." - # ), - "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)), + ( + "Sorry, something went wrong. " + "If this problem persists, please contact your administrator." 
+ ), role="system", user=self.get_user_from_role("system"), ) diff --git a/ragna/deploy/_ui/components/__init__.py b/ragna/deploy/_ui/components/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ragna/deploy/_ui/components/file_uploader.py b/ragna/deploy/_ui/components/file_uploader.py deleted file mode 100644 index 1568bc61..00000000 --- a/ragna/deploy/_ui/components/file_uploader.py +++ /dev/null @@ -1,266 +0,0 @@ -import json -import uuid - -import param -from panel.reactive import ReactiveHTML -from panel.widgets import Widget - - -class FileUploader(ReactiveHTML, Widget): # type: ignore[misc] - allowed_documents = param.List(default=[]) - allowed_documents_str = param.String(default="") - - file_list = param.List(default=[]) - - custom_js = param.String(default="") - uploaded_documents_json = param.String(default="") - - title = param.String(default="") - - def __init__(self, allowed_documents, informations_endpoint, **params): - super().__init__(**params) - - self.informations_endpoint = informations_endpoint - - self.after_upload_callback = None - - def can_proceed_to_upload(self): - return len(self.file_list) > 0 - - @param.depends("allowed_documents", watch=True) - def update_allowed_documents_str(self): - self.allowed_documents_str = ", ".join(self.allowed_documents) - - @param.depends("uploaded_documents_json", watch=True) - async def did_finish_upload(self): - if self.after_upload_callback is not None: - await self.after_upload_callback(json.loads(self.uploaded_documents_json)) - - def perform_upload(self, event=None, after_upload_callback=None): - self.after_upload_callback = after_upload_callback - - self.loading = True - - final_callback_js = """ - var final_callback = function(uploaded_documents) { - self.get_uploaded_documents_json().innerText = JSON.stringify(uploaded_documents); - }; - """ - - # This is a hack to force the re-execution of the javascript. 
- # If the whole javascript is the same, and doesn't change, - # the panel widget is not re-renderer, and the upload function is not called. - random_id = f" var random_id = '{str(uuid.uuid4())}';" - - self.custom_js = ( - final_callback_js - + random_id - + f"""upload( self.get_upload_files(), '{self.informations_endpoint}', final_callback) """ - ) - - _child_config = { - "custom_js": "template", - "uploaded_documents_json": "template", - "allowed_documents_str": "template", - } - - _template = """ - - -
-
- - Click to upload or drag and drop.
-
- Allowed files: ${allowed_documents_str} -
- -
-
- -
-
-
-
- ${custom_js} -
-
- -
-
- """ - - _scripts = { - "after_layout": """ - self.update_layout(); - """, - "update_layout": """ - - - if (data.file_list.length > 0) { - fileUploadDropArea.classList.add("uploaded"); - } else { - fileUploadDropArea.classList.remove("uploaded"); - } - - fileListContainer.innerHTML = ""; - - data.file_list.forEach(function(f) { - var pill = document.createElement("div"); - pill.classList.add("chat_document_pill"); - var fname = document.createTextNode(f.name); - - pill.appendChild(fname); - fileListContainer.appendChild(pill); - - }); - - """, - "file_input_on_change": """ - - - var new_file_list = Array.from(fileUpload.files).map(function(f) { - new_f = { - "lastModified":f.lastModified , - "name":f.name , - "size":f.size , - "type":f.type , - }; - - return new_f; - }); - - data.file_list = new_file_list; - self.update_layout(); - - - """, - "get_upload_files": """ - - return fileUpload.files; - """, - "get_uploaded_documents_json": """ - return uploaded_documents_json_watcher; - """, - "render": """ - - var MutationObserver = window.MutationObserver || window.WebKitMutationObserver || window.MozMutationObserver; - var observer = new MutationObserver(function(mutationsList, observer) { - mutationsList.forEach(function(mutation){ - if (mutation.type == 'characterData') { - eval(custom_js_watcher.innerText); - } - - }); - - }); - observer.observe(custom_js_watcher, {characterData: true, childList: true, attributes: true, subtree: true}); - - - var MutationObserver = window.MutationObserver || window.WebKitMutationObserver || window.MozMutationObserver; - var observer = new MutationObserver(function(mutationsList, observer) { - mutationsList.forEach(function(mutation){ - data.uploaded_documents_json = uploaded_documents_json_watcher.innerText - }); - - }); - observer.observe(uploaded_documents_json_watcher, {characterData: true, childList: true, attributes: true, subtree: true}); - - - fileUpload.addEventListener("dragenter", function(event){ - 
fileUploadDropArea.classList.add("draggedOver"); - }); - - fileUpload.addEventListener("dragleave", function(event){ - fileUploadDropArea.classList.remove("draggedOver") - }); - - fileUpload.addEventListener("drop", function(event){ - fileUploadDropArea.classList.remove("draggedOver") - event.preventDefault(); - fileUpload.files = event.dataTransfer.files; - self.file_input_on_change(); - }); - - """, - "remove": """ - - """, - } diff --git a/ragna/deploy/_ui/resources/upload.js b/ragna/deploy/_ui/resources/upload.js deleted file mode 100644 index 905da833..00000000 --- a/ragna/deploy/_ui/resources/upload.js +++ /dev/null @@ -1,44 +0,0 @@ -function upload(files, informationEndpoint, final_callback) { - uploadBatches(files, informationEndpoint).then(final_callback); -} - -async function uploadBatches(files, informationEndpoint) { - const batchSize = 500; - const queue = Array.from(files); - - let uploaded = []; - - while (queue.length) { - const batch = queue.splice(0, batchSize); - await Promise.all(batch.map((file) => uploadFile(file, informationEndpoint))).then( - (results) => { - uploaded.push(...results); - }, - ); - } - - return uploaded; -} - -async function uploadFile(file, informationEndpoint) { - const response = await fetch(informationEndpoint, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ name: file.name }), - }); - const documentUpload = await response.json(); - - const parameters = documentUpload.parameters; - var body = new FormData(); - for (const [key, value] of Object.entries(parameters.data)) { - body.append(key, value); - } - body.append("file", file); - - await fetch(parameters.url, { - method: parameters.method, - body: body, - }); - - return documentUpload.document; -} diff --git a/tests/deploy/api/utils.py b/tests/deploy/api/utils.py deleted file mode 100644 index 6dd3fb78..00000000 --- a/tests/deploy/api/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -import os - -from fastapi.testclient import 
TestClient - -from ragna._utils import default_user -from ragna.deploy._core import make_app - - -def make_api_app(*, config, ignore_unavailable_components): - return make_app( - config, - api=True, - ui=False, - ignore_unavailable_components=ignore_unavailable_components, - open_browser=False, - ) - - -def authenticate(client: TestClient) -> None: - return - username = default_user() - token = ( - client.post( - "/token", - data={ - "username": username, - "password": os.environ.get( - "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username - ), - }, - ) - .raise_for_status() - .json() - ) - client.headers["Authorization"] = f"Bearer {token}" diff --git a/tests/deploy/api/conftest.py b/tests/deploy/conftest.py similarity index 100% rename from tests/deploy/api/conftest.py rename to tests/deploy/conftest.py