diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml
index 7deb55c2..60684969 100644
--- a/.github/actions/setup-env/action.yml
+++ b/.github/actions/setup-env/action.yml
@@ -58,6 +58,10 @@ runs:
       shell: bash -el {0}
       run: mamba install --yes --channel conda-forge redis-server

+    - name: Install playwright
+      shell: bash -el {0}
+      run: playwright install
+
     - name: Install ragna
       shell: bash -el {0}
       run: |
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index b74f3907..5ff4147d 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -68,10 +68,76 @@ jobs:

       - name: Run unit tests
         id: tests
-        run: pytest --junit-xml=test-results.xml --durations=25
+        run: |
+          pytest \
+            --ignore tests/deploy/ui \
+            --junit-xml=test-results.xml \
+            --durations=25

       - name: Surface failing tests
         if: steps.tests.outcome != 'success'
         uses: pmeier/pytest-results-action@v0.3.0
         with:
           path: test-results.xml
+
+  pytest-ui:
+    strategy:
+      matrix:
+        os:
+          - ubuntu-latest
+          - windows-latest
+          - macos-latest
+        browser:
+          - chromium
+          - firefox
+        python-version:
+          - "3.9"
+          - "3.10"
+          - "3.11"
+        exclude:
+          - python-version: "3.10"
+            os: windows-latest
+          - python-version: "3.11"
+            os: windows-latest
+          - python-version: "3.10"
+            os: macos-latest
+          - python-version: "3.11"
+            os: macos-latest
+        include:
+          - browser: webkit
+            os: macos-latest
+            python-version: "3.9"
+
+      fail-fast: false
+
+    runs-on: ${{ matrix.os }}
+
+    defaults:
+      run:
+        shell: bash -el {0}
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Setup environment
+        uses: ./.github/actions/setup-env
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Run unit tests
+        id: tests
+        run: |
+          pytest tests/deploy/ui \
+            --browser ${{ matrix.browser }} \
+            --video=retain-on-failure
+
+      - name: Upload playwright video
+        if: failure()
+        uses: actions/upload-artifact@v4
+        with:
+          name:
+            playwright-${{ matrix.os }}-${{ matrix.python-version}}-${{ github.run_id }}
+          path: test-results
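The new `pytest-ui` job drives the Playwright-backed suite under `tests/deploy/ui` once per browser in the matrix, records a video on failure, and uploads it as an artifact. For orientation, a test in that directory might look roughly like the sketch below; the `page` fixture comes from pytest-playwright, while the address and the asserted widget are illustrative assumptions, not something this diff adds.

```python
# Hypothetical sketch of a UI test as it could appear under tests/deploy/ui/.
# The `page` fixture is provided by pytest-playwright; the address and the
# "New Chat" button are assumptions about an already running ragna UI.
from playwright.sync_api import Page, expect


def test_new_chat_button_is_visible(page: Page) -> None:
    page.goto("http://localhost:31477")  # assumed address of a running ragna UI
    expect(page.get_by_role("button", name="New Chat")).to_be_visible()
```

Running `pytest tests/deploy/ui --browser chromium --video=retain-on-failure` locally mirrors what the workflow above does in CI.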
diff --git a/docs/examples/gallery_streaming.py b/docs/examples/gallery_streaming.py
index 84d92d08..9ccdecb4 100644
--- a/docs/examples/gallery_streaming.py
+++ b/docs/examples/gallery_streaming.py
@@ -31,6 +31,14 @@
 #   - [ragna.assistants.Gpt4][]
 # - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
 #   - [ragna.assistants.LlamafileAssistant][]
+# - [Ollama](https://ollama.com/)
+#   - [ragna.assistants.OllamaGemma2B][]
+#   - [ragna.assistants.OllamaLlama2][]
+#   - [ragna.assistants.OllamaLlava][]
+#   - [ragna.assistants.OllamaMistral][]
+#   - [ragna.assistants.OllamaMixtral][]
+#   - [ragna.assistants.OllamaOrcaMini][]
+#   - [ragna.assistants.OllamaPhi2][]

 from ragna import assistants
diff --git a/docs/tutorials/gallery_python_api.py b/docs/tutorials/gallery_python_api.py
index a7d0ef63..23667a71 100644
--- a/docs/tutorials/gallery_python_api.py
+++ b/docs/tutorials/gallery_python_api.py
@@ -87,6 +87,14 @@
 #   - [ragna.assistants.Jurassic2Ultra][]
 # - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
 #   - [ragna.assistants.LlamafileAssistant][]
+# - [Ollama](https://ollama.com/)
+#   - [ragna.assistants.OllamaGemma2B][]
+#   - [ragna.assistants.OllamaLlama2][]
+#   - [ragna.assistants.OllamaLlava][]
+#   - [ragna.assistants.OllamaMistral][]
+#   - [ragna.assistants.OllamaMixtral][]
+#   - [ragna.assistants.OllamaOrcaMini][]
+#   - [ragna.assistants.OllamaPhi2][]
 #
 # !!! note
 #
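With Ollama added to both assistant galleries, any of the new classes slots into the existing Python API tutorial unchanged. A rough sketch of that flow, assuming a local Ollama server and a `document.txt` to ingest (both assumptions, not part of this diff):

```python
# Sketch only: the Python API tutorial flow with an Ollama assistant swapped in.
# Assumes an Ollama server is running locally and that document.txt exists.
import asyncio

from ragna import Rag, assistants, source_storages


async def main() -> None:
    chat = Rag().chat(
        documents=["document.txt"],
        source_storage=source_storages.RagnaDemoSourceStorage,
        assistant=assistants.OllamaMistral,
    )
    await chat.prepare()
    answer = await chat.answer("What is Ragna?")
    print(answer)


if __name__ == "__main__":
    asyncio.run(main())
```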
diff --git a/environment-dev.yml b/environment-dev.yml
index 2a7b6a03..9ee47a1b 100644
--- a/environment-dev.yml
+++ b/environment-dev.yml
@@ -10,6 +10,7 @@ dependencies:
   - pytest >=6
   - pytest-mock
   - pytest-asyncio
+  - pytest-playwright
   - mypy ==1.10.0
   - pre-commit
   - types-aiofiles
diff --git a/ragna/assistants/__init__.py b/ragna/assistants/__init__.py
index d583e7a0..bcf5ead6 100644
--- a/ragna/assistants/__init__.py
+++ b/ragna/assistants/__init__.py
@@ -6,6 +6,13 @@
     "CommandLight",
     "GeminiPro",
     "GeminiUltra",
+    "OllamaGemma2B",
+    "OllamaPhi2",
+    "OllamaLlama2",
+    "OllamaLlava",
+    "OllamaMistral",
+    "OllamaMixtral",
+    "OllamaOrcaMini",
     "Gpt35Turbo16k",
     "Gpt4",
     "Jurassic2Ultra",
@@ -19,6 +26,15 @@
 from ._demo import RagnaDemoAssistant
 from ._google import GeminiPro, GeminiUltra
 from ._llamafile import LlamafileAssistant
+from ._ollama import (
+    OllamaGemma2B,
+    OllamaLlama2,
+    OllamaLlava,
+    OllamaMistral,
+    OllamaMixtral,
+    OllamaOrcaMini,
+    OllamaPhi2,
+)
 from ._openai import Gpt4, Gpt35Turbo16k

 # isort: split
diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py
index 1c61a213..3e0c56b5 100644
--- a/ragna/assistants/_ai21labs.py
+++ b/ragna/assistants/_ai21labs.py
@@ -7,6 +7,7 @@

 class Ai21LabsAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "AI21_API_KEY"
+    _STREAMING_PROTOCOL = None
     _MODEL_TYPE: str

     @classmethod
@@ -27,7 +28,8 @@ async def answer(
         # See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
         # See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
         # See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response
-        response = await self._client.post(
+        async for data in self._call_api(
+            "POST",
             f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
             headers={
                 "accept": "application/json",
@@ -46,10 +48,8 @@ async def answer(
                 ],
                 "system": self._make_system_content(sources),
             },
-        )
-        await self._assert_api_call_is_success(response)
-
-        yield cast(str, response.json()["outputs"][0]["text"])
+        ):
+            yield cast(str, data["outputs"][0]["text"])


 # The Jurassic2Mid assistant receives a 500 internal service error from the remote
diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py
index 37f132b5..d74fc840 100644
--- a/ragna/assistants/_anthropic.py
+++ b/ragna/assistants/_anthropic.py
@@ -2,11 +2,12 @@

 from ragna.core import PackageRequirement, RagnaException, Requirement, Source

-from ._http_api import HttpApiAssistant
+from ._http_api import HttpApiAssistant, HttpStreamingProtocol


 class AnthropicAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "ANTHROPIC_API_KEY"
+    _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE
     _MODEL: str

     @classmethod
@@ -40,7 +41,7 @@ async def answer(
     ) -> AsyncIterator[str]:
         # See https://docs.anthropic.com/claude/reference/messages_post
         # See https://docs.anthropic.com/claude/reference/streaming
-        async for data in self._stream_sse(
+        async for data in self._call_api(
             "POST",
             "https://api.anthropic.com/v1/messages",
             headers={
diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py
index b47737f8..4108d31b 100644
--- a/ragna/assistants/_cohere.py
+++ b/ragna/assistants/_cohere.py
@@ -2,11 +2,12 @@

 from ragna.core import RagnaException, Source

-from ._http_api import HttpApiAssistant
+from ._http_api import HttpApiAssistant, HttpStreamingProtocol


 class CohereAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "COHERE_API_KEY"
+    _STREAMING_PROTOCOL = HttpStreamingProtocol.JSONL
     _MODEL: str

     @classmethod
@@ -29,7 +30,7 @@ async def answer(
         # See https://docs.cohere.com/docs/cochat-beta
         # See https://docs.cohere.com/reference/chat
         # See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
-        async for event in self._stream_jsonl(
+        async for event in self._call_api(
             "POST",
             "https://api.cohere.ai/v1/chat",
             headers={
diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py
index 8e1caf1e..70c82936 100644
--- a/ragna/assistants/_google.py
+++ b/ragna/assistants/_google.py
@@ -1,38 +1,15 @@
 from typing import AsyncIterator

-from ragna._compat import anext
-from ragna.core import PackageRequirement, Requirement, Source
+from ragna.core import Source

-from ._http_api import HttpApiAssistant
-
-
-# ijson does not support reading from an (async) iterator, but only from file-like
-# objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects.
-# See https://github.com/ICRAR/ijson/issues/44 for details.
-# ijson actually doesn't care about most of the file interface and only requires the
-# read() method to be present.
-class AsyncIteratorReader:
-    def __init__(self, ait: AsyncIterator[bytes]) -> None:
-        self._ait = ait
-
-    async def read(self, n: int) -> bytes:
-        # n is usually used to indicate how many bytes to read, but since we want to
-        # return a chunk as soon as it is available, we ignore the value of n. The only
-        # exception is n == 0, which is used by ijson to probe the return type and
-        # set up decoding.
-        if n == 0:
-            return b""
-        return await anext(self._ait, b"")  # type: ignore[call-arg]
+from ._http_api import HttpApiAssistant, HttpStreamingProtocol


 class GoogleAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "GOOGLE_API_KEY"
+    _STREAMING_PROTOCOL = HttpStreamingProtocol.JSON
     _MODEL: str

-    @classmethod
-    def _extra_requirements(cls) -> list[Requirement]:
-        return [PackageRequirement("ijson")]
-
     @classmethod
     def display_name(cls) -> str:
         return f"Google/{cls._MODEL}"
@@ -51,9 +28,7 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
     async def answer(
         self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
-        import ijson
-
-        async with self._client.stream(
+        async for chunk in self._call_api(
             "POST",
             f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
             params={"key": self._api_key},
@@ -64,7 +39,10 @@ async def answer(
                 ],
                 # https://ai.google.dev/docs/safety_setting_gemini
                 "safetySettings": [
-                    {"category": f"HARM_CATEGORY_{category}", "threshold": "BLOCK_NONE"}
+                    {
+                        "category": f"HARM_CATEGORY_{category}",
+                        "threshold": "BLOCK_NONE",
+                    }
                     for category in [
                         "HARASSMENT",
                         "HATE_SPEECH",
@@ -78,14 +56,9 @@ async def answer(
                     "maxOutputTokens": max_new_tokens,
                 },
             },
-        ) as response:
-            await self._assert_api_call_is_success(response)
-
-            async for chunk in ijson.items(
-                AsyncIteratorReader(response.aiter_bytes(1024)),
-                "item.candidates.item.content.parts.item.text",
-            ):
-                yield chunk
+            parse_kwargs=dict(item="item.candidates.item.content.parts.item.text"),
+        ):
+            yield chunk


 class GeminiPro(GoogleAssistant):
diff --git a/ragna/assistants/_http_api.py b/ragna/assistants/_http_api.py
index 1151a62a..d6f48a26 100644
--- a/ragna/assistants/_http_api.py
+++ b/ragna/assistants/_http_api.py
@@ -1,65 +1,83 @@
 import contextlib
+import enum
 import json
 import os
 from typing import Any, AsyncIterator, Optional

 import httpx
-from httpx import Response

 import ragna
-from ragna.core import Assistant, EnvVarRequirement, RagnaException, Requirement
+from ragna._compat import anext
+from ragna.core import (
+    Assistant,
+    EnvVarRequirement,
+    PackageRequirement,
+    RagnaException,
+    Requirement,
+)


-class HttpApiAssistant(Assistant):
-    _API_KEY_ENV_VAR: Optional[str]
+class HttpStreamingProtocol(enum.Enum):
+    SSE = enum.auto()
+    JSONL = enum.auto()
+    JSON = enum.auto()

-    @classmethod
-    def requirements(cls) -> list[Requirement]:
-        requirements: list[Requirement] = (
-            [EnvVarRequirement(cls._API_KEY_ENV_VAR)]
-            if cls._API_KEY_ENV_VAR is not None
-            else []
-        )
-        requirements.extend(cls._extra_requirements())
-        return requirements

+class HttpApiCaller:
     @classmethod
-    def _extra_requirements(cls) -> list[Requirement]:
-        return []
+    def requirements(cls, protocol: HttpStreamingProtocol) -> list[Requirement]:
+        streaming_requirements: dict[HttpStreamingProtocol, list[Requirement]] = {
+            HttpStreamingProtocol.SSE: [PackageRequirement("httpx_sse")],
+            HttpStreamingProtocol.JSON: [PackageRequirement("ijson")],
+        }
+        return streaming_requirements.get(protocol, [])

-    def __init__(self) -> None:
-        self._client = httpx.AsyncClient(
-            headers={"User-Agent": f"{ragna.__version__}/{self}"},
-            timeout=60,
-        )
-        self._api_key: Optional[str] = (
-            os.environ[self._API_KEY_ENV_VAR]
-            if self._API_KEY_ENV_VAR is not None
-            else None
-        )
-
-    async def _assert_api_call_is_success(self, response: Response) -> None:
-        if response.is_success:
-            return
+    def __init__(
+        self,
+        client: httpx.AsyncClient,
+        protocol: Optional[HttpStreamingProtocol] = None,
+    ) -> None:
+        self._client = client
+        self._protocol = protocol

-        content = await response.aread()
-        with contextlib.suppress(Exception):
-            content = json.loads(content)
+    def __call__(
+        self,
+        method: str,
+        url: str,
+        *,
+        parse_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[Any]:
+        if self._protocol is None:
+            call_method = self._no_stream
+        else:
+            call_method = {
+                HttpStreamingProtocol.SSE: self._stream_sse,
+                HttpStreamingProtocol.JSONL: self._stream_jsonl,
+                HttpStreamingProtocol.JSON: self._stream_json,
+            }[self._protocol]
+        return call_method(method, url, parse_kwargs=parse_kwargs or {}, **kwargs)

-        raise RagnaException(
-            "API call failed",
-            request_method=response.request.method,
-            request_url=str(response.request.url),
-            response_status_code=response.status_code,
-            response_content=content,
-        )
+    async def _no_stream(
+        self,
+        method: str,
+        url: str,
+        *,
+        parse_kwargs: dict[str, Any],
+        **kwargs: Any,
+    ) -> AsyncIterator[Any]:
+        response = await self._client.request(method, url, **kwargs)
+        await self._assert_api_call_is_success(response)
+        yield response.json()

     async def _stream_sse(
         self,
         method: str,
         url: str,
+        *,
+        parse_kwargs: dict[str, Any],
         **kwargs: Any,
-    ) -> AsyncIterator[dict[str, Any]]:
+    ) -> AsyncIterator[Any]:
         import httpx_sse

         async with httpx_sse.aconnect_sse(
@@ -71,10 +89,103 @@ async def _stream_sse(
             yield json.loads(sse.data)

     async def _stream_jsonl(
-        self, method: str, url: str, **kwargs: Any
-    ) -> AsyncIterator[dict[str, Any]]:
+        self,
+        method: str,
+        url: str,
+        *,
+        parse_kwargs: dict[str, Any],
+        **kwargs: Any,
+    ) -> AsyncIterator[Any]:
         async with self._client.stream(method, url, **kwargs) as response:
             await self._assert_api_call_is_success(response)

             async for chunk in response.aiter_lines():
                 yield json.loads(chunk)
+
+    # ijson does not support reading from an (async) iterator, but only from file-like
+    # objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects.
+    # See https://github.com/ICRAR/ijson/issues/44 for details.
+    # ijson actually doesn't care about most of the file interface and only requires the
+    # read() method to be present.
+    class _AsyncIteratorReader:
+        def __init__(self, ait: AsyncIterator[bytes]) -> None:
+            self._ait = ait
+
+        async def read(self, n: int) -> bytes:
+            # n is usually used to indicate how many bytes to read, but since we want to
+            # return a chunk as soon as it is available, we ignore the value of n. The
+            # only exception is n == 0, which is used by ijson to probe the return type
+            # and set up decoding.
+            if n == 0:
+                return b""
+            return await anext(self._ait, b"")  # type: ignore[call-arg]
+
+    async def _stream_json(
+        self,
+        method: str,
+        url: str,
+        *,
+        parse_kwargs: dict[str, Any],
+        **kwargs: Any,
+    ) -> AsyncIterator[Any]:
+        import ijson
+
+        item = parse_kwargs["item"]
+        chunk_size = parse_kwargs.get("chunk_size", 16)
+
+        async with self._client.stream(method, url, **kwargs) as response:
+            await self._assert_api_call_is_success(response)
+
+            async for chunk in ijson.items(
+                self._AsyncIteratorReader(response.aiter_bytes(chunk_size)), item
+            ):
+                yield chunk
+
+    async def _assert_api_call_is_success(self, response: httpx.Response) -> None:
+        if response.is_success:
+            return
+
+        content = await response.aread()
+        with contextlib.suppress(Exception):
+            content = json.loads(content)
+
+        raise RagnaException(
+            "API call failed",
+            request_method=response.request.method,
+            request_url=str(response.request.url),
+            response_status_code=response.status_code,
+            response_content=content,
+        )
+
+
+class HttpApiAssistant(Assistant):
+    _API_KEY_ENV_VAR: Optional[str]
+    _STREAMING_PROTOCOL: Optional[HttpStreamingProtocol]
+
+    @classmethod
+    def requirements(cls) -> list[Requirement]:
+        requirements: list[Requirement] = (
+            [EnvVarRequirement(cls._API_KEY_ENV_VAR)]
+            if cls._API_KEY_ENV_VAR is not None
+            else []
+        )
+        if cls._STREAMING_PROTOCOL is not None:
+            requirements.extend(HttpApiCaller.requirements(cls._STREAMING_PROTOCOL))
+        requirements.extend(cls._extra_requirements())
+        return requirements
+
+    @classmethod
+    def _extra_requirements(cls) -> list[Requirement]:
+        return []
+
+    def __init__(self) -> None:
+        self._client = httpx.AsyncClient(
+            headers={"User-Agent": f"{ragna.__version__}/{self}"},
+            timeout=60,
+        )
+        self._api_key: Optional[str] = (
+            os.environ[self._API_KEY_ENV_VAR]
+            if self._API_KEY_ENV_VAR is not None
+            else None
+        )
+        self._call_api = HttpApiCaller(self._client, self._STREAMING_PROTOCOL)
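After this refactor, the per-assistant modules no longer pick a streaming helper by hand: they declare a `_STREAMING_PROTOCOL` and go through `self._call_api`, which dispatches to the matching `HttpApiCaller` method and handles requirements and error reporting. A minimal sketch of what a new assistant module built on this base could look like; the vendor URL, payload shape, and response keys are invented for illustration, only the `HttpApiAssistant`/`_call_api` usage mirrors this diff.

```python
# Hypothetical assistant module sketch. Everything about the remote API
# (URL, payload, response layout) is made up; only the base-class usage
# follows the refactored _http_api module above.
from typing import AsyncIterator

from ragna.core import Source

from ._http_api import HttpApiAssistant, HttpStreamingProtocol


class ExampleSseAssistant(HttpApiAssistant):
    _API_KEY_ENV_VAR = "EXAMPLE_API_KEY"
    _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE

    @classmethod
    def display_name(cls) -> str:
        return "Example/example-model"

    async def answer(
        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
    ) -> AsyncIterator[str]:
        async for event in self._call_api(
            "POST",
            "https://api.example.com/v1/chat",  # placeholder endpoint
            headers={"Authorization": f"Bearer {self._api_key}"},
            json={
                "prompt": prompt,
                "context": [source.content for source in sources],
                "max_tokens": max_new_tokens,
            },
        ):
            # Assumed response shape: each SSE event carries a text delta.
            yield event["delta"]
```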
diff --git a/ragna/assistants/_llamafile.py b/ragna/assistants/_llamafile.py
index 3e78a625..5c7cc1da 100644
--- a/ragna/assistants/_llamafile.py
+++ b/ragna/assistants/_llamafile.py
@@ -1,9 +1,11 @@
 import os
+from functools import cached_property

-from ._openai import OpenaiCompliantHttpApiAssistant
+from ._http_api import HttpStreamingProtocol
+from ._openai import OpenaiLikeHttpApiAssistant


-class LlamafileAssistant(OpenaiCompliantHttpApiAssistant):
+class LlamafileAssistant(OpenaiLikeHttpApiAssistant):
     """[llamafile](https://github.com/Mozilla-Ocho/llamafile)

     To use this assistant, start the llamafile server manually. By default, the server
@@ -16,10 +18,14 @@ class LlamafileAssistant(OpenaiCompliantHttpApiAssistant):
     """

     _API_KEY_ENV_VAR = None
-    _STREAMING_METHOD = "sse"
+    _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE
     _MODEL = None

-    @property
+    @classmethod
+    def display_name(cls) -> str:
+        return "llamafile"
+
+    @cached_property
     def _url(self) -> str:
         base_url = os.environ.get("RAGNA_LLAMAFILE_BASE_URL", "http://localhost:8080")
         return f"{base_url}/v1/chat/completions"
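The docstring spells out the contract: the assistant never launches a server itself, it only sends OpenAI-style chat-completions requests to a llamafile instance that is already running, at `http://localhost:8080` unless overridden. A small sketch of pointing it elsewhere; the port is an arbitrary example.

```python
# Sketch: point LlamafileAssistant at a llamafile server on a non-default port.
# The llamafile server must already be running; the port is an arbitrary example.
import os

from ragna.assistants import LlamafileAssistant

os.environ["RAGNA_LLAMAFILE_BASE_URL"] = "http://localhost:8081"

assistant = LlamafileAssistant()  # requests go to {base_url}/v1/chat/completions
```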
diff --git a/ragna/assistants/_ollama.py b/ragna/assistants/_ollama.py
new file mode 100644
index 00000000..3bb23c9f
--- /dev/null
+++ b/ragna/assistants/_ollama.py
@@ -0,0 +1,83 @@
+import os
+from functools import cached_property
+from typing import AsyncIterator, cast
+
+from ragna.core import RagnaException, Source
+
+from ._http_api import HttpStreamingProtocol
+from ._openai import OpenaiLikeHttpApiAssistant
+
+
+class OllamaAssistant(OpenaiLikeHttpApiAssistant):
+    """[Ollama](https://ollama.com/)
+
+    To use this assistant, start the Ollama server manually. By default, the server
+    is expected at `http://localhost:11434`. This can be changed with the
+    `RAGNA_OLLAMA_BASE_URL` environment variable.
+    """
+
+    _API_KEY_ENV_VAR = None
+    _STREAMING_PROTOCOL = HttpStreamingProtocol.JSONL
+    _MODEL: str
+
+    @classmethod
+    def display_name(cls) -> str:
+        return f"Ollama/{cls._MODEL}"
+
+    @cached_property
+    def _url(self) -> str:
+        base_url = os.environ.get("RAGNA_OLLAMA_BASE_URL", "http://localhost:11434")
+        return f"{base_url}/api/chat"
+
+    async def answer(
+        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+    ) -> AsyncIterator[str]:
+        async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens):
+            # Modeled after
+            # https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62
+            if "error" in data:
+                raise RagnaException(data["error"])
+            if not data["done"]:
+                yield cast(str, data["message"]["content"])
+
+
+class OllamaGemma2B(OllamaAssistant):
+    """[Gemma:2B](https://ollama.com/library/gemma)"""
+
+    _MODEL = "gemma:2b"
+
+
+class OllamaLlama2(OllamaAssistant):
+    """[Llama 2](https://ollama.com/library/llama2)"""
+
+    _MODEL = "llama2"
+
+
+class OllamaLlava(OllamaAssistant):
+    """[Llava](https://ollama.com/library/llava)"""
+
+    _MODEL = "llava"
+
+
+class OllamaMistral(OllamaAssistant):
+    """[Mistral](https://ollama.com/library/mistral)"""
+
+    _MODEL = "mistral"
+
+
+class OllamaMixtral(OllamaAssistant):
+    """[Mixtral](https://ollama.com/library/mixtral)"""
+
+    _MODEL = "mixtral"
+
+
+class OllamaOrcaMini(OllamaAssistant):
+    """[Orca Mini](https://ollama.com/library/orca-mini)"""
+
+    _MODEL = "orca-mini"
+
+
+class OllamaPhi2(OllamaAssistant):
+    """[Phi-2](https://ollama.com/library/phi)"""
+
+    _MODEL = "phi"
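`OllamaAssistant` reuses the OpenAI-style request building from `OpenaiLikeHttpApiAssistant._stream` but parses Ollama's JSONL chat responses, so `answer()` yields plain text chunks. A small sketch of driving one of the concrete classes directly, assuming `ollama serve` is running locally with the `mistral` model already pulled:

```python
# Sketch: stream an answer straight from the assistant, outside a full Rag chat.
# Assumes an Ollama server is running locally and the mistral model is available.
import asyncio

from ragna.assistants import OllamaMistral


async def main() -> None:
    assistant = OllamaMistral()
    async for chunk in assistant.answer("What is Ragna?", sources=[]):
        print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())
```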
diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py
index 37957be2..0f51d6d9 100644
--- a/ragna/assistants/_openai.py
+++ b/ragna/assistants/_openai.py
@@ -1,23 +1,15 @@
 import abc
-from typing import Any, AsyncIterator, Literal, Optional, cast
+from functools import cached_property
+from typing import Any, AsyncIterator, Optional, cast

-from ragna.core import PackageRequirement, RagnaException, Requirement, Source
+from ragna.core import Source

-from ._http_api import HttpApiAssistant
+from ._http_api import HttpApiAssistant, HttpStreamingProtocol


-class OpenaiCompliantHttpApiAssistant(HttpApiAssistant):
-    _STREAMING_METHOD: Literal["sse", "jsonl"]
+class OpenaiLikeHttpApiAssistant(HttpApiAssistant):
     _MODEL: Optional[str]

-    @classmethod
-    def requirements(cls) -> list[Requirement]:
-        requirements = super().requirements()
-        requirements.extend(
-            {"sse": [PackageRequirement("httpx_sse")]}.get(cls._STREAMING_METHOD, [])
-        )
-        return requirements
-
     @property
     @abc.abstractmethod
     def _url(self) -> str: ...
@@ -32,23 +24,8 @@ def _make_system_content(self, sources: list[Source]) -> str:
         return instruction + "\n\n".join(source.content for source in sources)

     def _stream(
-        self,
-        method: str,
-        url: str,
-        **kwargs: Any,
+        self, prompt: str, sources: list[Source], *, max_new_tokens: int
     ) -> AsyncIterator[dict[str, Any]]:
-        stream = {
-            "sse": self._stream_sse,
-            "jsonl": self._stream_jsonl,
-        }.get(self._STREAMING_METHOD)
-        if stream is None:
-            raise RagnaException
-
-        return stream(method, url, **kwargs)
-
-    async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
-    ) -> AsyncIterator[str]:
         # See https://platform.openai.com/docs/api-reference/chat/create
         # and https://platform.openai.com/docs/api-reference/chat/streaming
         headers = {
@@ -75,7 +52,12 @@ async def answer(
         if self._MODEL is not None:
             json_["model"] = self._MODEL

-        async for data in self._stream("POST", self._url, headers=headers, json=json_):
+        return self._call_api("POST", self._url, headers=headers, json=json_)
+
+    async def answer(
+        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+    ) -> AsyncIterator[str]:
+        async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens):
             choice = data["choices"][0]
             if choice["finish_reason"] is not None:
                 break
@@ -83,15 +65,15 @@ async def answer(
             yield cast(str, choice["delta"]["content"])


-class OpenaiAssistant(OpenaiCompliantHttpApiAssistant):
+class OpenaiAssistant(OpenaiLikeHttpApiAssistant):
     _API_KEY_ENV_VAR = "OPENAI_API_KEY"
-    _STREAMING_METHOD = "sse"
+    _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE

     @classmethod
     def display_name(cls) -> str:
         return f"OpenAI/{cls._MODEL}"

-    @property
+    @cached_property
     def _url(self) -> str:
         return "https://api.openai.com/v1/chat/completions"
diff --git a/ragna/deploy/_api/core.py b/ragna/deploy/_api/core.py
index 5346b048..2a95ae96 100644
--- a/ragna/deploy/_api/core.py
+++ b/ragna/deploy/_api/core.py
@@ -233,11 +233,7 @@ def schema_to_core_chat(
             chat_name=chat.metadata.name,
             **chat.metadata.params,
         )
-        # FIXME: We need to reconstruct the previous messages here. Right now this is
-        # not needed, because the chat itself never accesses past messages. However,
-        # if we implement a chat history feature, i.e. passing past messages to
-        # the assistant, this becomes crucial.
-        core_chat._messages = []
+        core_chat._messages = [message.to_core() for message in chat.messages]
         core_chat._prepared = chat.prepared

         return core_chat
diff --git a/ragna/deploy/_api/schemas.py b/ragna/deploy/_api/schemas.py
index 53957a74..37471c69 100644
--- a/ragna/deploy/_api/schemas.py
+++ b/ragna/deploy/_api/schemas.py
@@ -26,6 +26,16 @@ def from_core(cls, document: ragna.core.Document) -> Document:
             name=document.name,
         )

+    def to_core(self) -> ragna.core.Document:
+        return ragna.core.LocalDocument(
+            id=self.id,
+            name=self.name,
+            # TEMP: setting an empty metadata dict for now.
+            # Will be resolved as part of the "managed ragna" work:
+            # https://github.com/Quansight/ragna/issues/256
+            metadata={},
+        )
+

 class DocumentUpload(BaseModel):
     parameters: ragna.core.DocumentUploadParameters
@@ -50,6 +60,15 @@ def from_core(cls, source: ragna.core.Source) -> Source:
             num_tokens=source.num_tokens,
         )

+    def to_core(self) -> ragna.core.Source:
+        return ragna.core.Source(
+            id=self.id,
+            document=self.document.to_core(),
+            location=self.location,
+            content=self.content,
+            num_tokens=self.num_tokens,
+        )
+

 class Message(BaseModel):
     id: uuid.UUID = Field(default_factory=uuid.uuid4)
@@ -66,6 +85,13 @@ def from_core(cls, message: ragna.core.Message) -> Message:
             sources=[Source.from_core(source) for source in message.sources],
         )

+    def to_core(self) -> ragna.core.Message:
+        return ragna.core.Message(
+            content=self.content,
+            role=self.role,
+            sources=[source.to_core() for source in self.sources],
+        )
+

 class ChatMetadata(BaseModel):
     name: str
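Together with the `core.py` change above, these `to_core` counterparts to the existing `from_core` constructors let the API layer rebuild `ragna.core.Message` objects, sources included, when `schema_to_core_chat` restores a chat, which is what allows handing a real message history to the core chat. A minimal round-trip sketch under that assumption; the private `schemas` import path is taken from this diff, and metadata stays empty, as the TEMP comment notes.

```python
# Sketch: the schema -> core round trip enabled by the new to_core methods.
import ragna.core
from ragna.deploy._api import schemas

core_message = ragna.core.Message(
    content="Hello from the assistant.",
    role=ragna.core.MessageRole.ASSISTANT,
    sources=[],
)

api_message = schemas.Message.from_core(core_message)
rebuilt = api_message.to_core()  # a ragna.core.Message again

assert rebuilt.content == core_message.content
assert rebuilt.role is core_message.role
```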
diff --git a/ragna/deploy/_ui/css/modal_welcome/button.css b/ragna/deploy/_ui/css/modal_welcome/button.css
deleted file mode 100644
index 5c98c041..00000000
--- a/ragna/deploy/_ui/css/modal_welcome/button.css
+++ /dev/null
@@ -1,4 +0,0 @@
-:host(.modal_welcome_close_button) {
-  width: 35%;
-  margin-left: 60%;
-}
diff --git a/ragna/deploy/_ui/main_page.py b/ragna/deploy/_ui/main_page.py
index c8610e7b..4ba5ba94 100644
--- a/ragna/deploy/_ui/main_page.py
+++ b/ragna/deploy/_ui/main_page.py
@@ -3,12 +3,9 @@
 import panel as pn
 import param

-from . import js
-from . import styles as ui
 from .central_view import CentralView
 from .left_sidebar import LeftSidebar
 from .modal_configuration import ModalConfiguration
-from .modal_welcome import ModalWelcome
 from .right_sidebar import RightSidebar


@@ -71,14 +68,6 @@ def open_modal(self):
         self.template.modal.objects[0].objects = [self.modal]
         self.template.open_modal()

-    def open_welcome_modal(self, event):
-        self.modal = ModalWelcome(
-            close_button_callback=lambda: self.template.close_modal(),
-        )
-
-        self.template.modal.objects[0].objects = [self.modal]
-        self.template.open_modal()
-
     async def open_new_chat(self, new_chat_id):
         # called after creating a new chat.
         self.current_chat_id = new_chat_id
@@ -111,59 +100,9 @@ def update_subviews_current_chat_id(self, avoid_senders=[]):
     def __panel__(self):
         asyncio.ensure_future(self.refresh_data())

-        objects = [self.left_sidebar, self.central_view, self.right_sidebar]
-
-        if self.chats is not None and len(self.chats) == 0:
-            """I haven't found a better way to open the modal when the pages load,
-            than simulating a click on the "New chat" button.
-            - calling self.template.open_modal() doesn't work
-            - calling self.on_click_new_chat doesn't work either
-            - trying to schedule a call to on_click_new_chat with pn.state.schedule_task
-              could have worked but my tests were yielding an unstable result.
-            """
-
-            new_chat_button_name = "open welcome modal"
-            open_welcome_modal = pn.widgets.Button(
-                name=new_chat_button_name,
-                button_type="primary",
-            )
-            open_welcome_modal.on_click(self.open_welcome_modal)
-
-            hack_open_modal = pn.pane.HTML(
-                """
-
-                """.replace(
-                    "{new_chat_btn_name}", new_chat_button_name
-                ).strip(),
-                # This is not really styling per say, it's just a way to hide from the page the HTML item of this hack.
-                # It's not worth moving this to a separate file.
-                stylesheets=[
-                    ui.css(
-                        ":host",
-                        {"position": "absolute", "z-index": "-999"},
-                    )
-                ],
-            )
-
-            objects.append(
-                pn.Row(
-                    open_welcome_modal,
-                    pn.pane.HTML(js.SHADOWROOT_INDEXING),
-                    hack_open_modal,
-                    visible=False,
-                )
-            )
-
-        main_page = pn.Row(
-            *objects,
+        return pn.Row(
+            self.left_sidebar,
+            self.central_view,
+            self.right_sidebar,
             css_classes=["main_page_main_row"],
         )
-
-        return main_page
diff --git a/ragna/deploy/_ui/modal_welcome.py b/ragna/deploy/_ui/modal_welcome.py
deleted file mode 100644
index 71b6ad7f..00000000
--- a/ragna/deploy/_ui/modal_welcome.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import panel as pn
-import param
-
-from . import js
-from . import styles as ui
-
-
-class ModalWelcome(pn.viewable.Viewer):
-    close_button_callback = param.Callable()
-
-    def __init__(self, **params):
-        super().__init__(**params)
-
-    def did_click_on_close_button(self, event):
-        if self.close_button_callback is not None:
-            self.close_button_callback()
-
-    def __panel__(self):
-        close_button = pn.widgets.Button(
-            name="Okay, let's go",
-            button_type="primary",
-            css_classes=["modal_welcome_close_button"],
-        )
-        close_button.on_click(self.did_click_on_close_button)
-
-        return pn.Column(
-            pn.pane.HTML(
-                f""""""
-                + """