Commit

Merge branch 'main' into ragna-base-poc
arjxn-py authored Jun 14, 2024
2 parents 9b0f1c2 + 45d5f94 commit ba60010
Showing 26 changed files with 626 additions and 297 deletions.
4 changes: 4 additions & 0 deletions .github/actions/setup-env/action.yml
@@ -58,6 +58,10 @@ runs:
shell: bash -el {0}
run: mamba install --yes --channel conda-forge redis-server

- name: Install playwright
shell: bash -el {0}
run: playwright install

- name: Install ragna
shell: bash -el {0}
run: |
68 changes: 67 additions & 1 deletion .github/workflows/test.yml
@@ -68,10 +68,76 @@ jobs:

- name: Run unit tests
id: tests
run: pytest --junit-xml=test-results.xml --durations=25
run: |
pytest \
--ignore tests/deploy/ui \
--junit-xml=test-results.xml \
--durations=25
- name: Surface failing tests
if: steps.tests.outcome != 'success'
uses: pmeier/[email protected]
with:
path: test-results.xml

pytest-ui:
strategy:
matrix:
os:
- ubuntu-latest
- windows-latest
- macos-latest
browser:
- chromium
- firefox
python-version:
- "3.9"
- "3.10"
- "3.11"
exclude:
- python-version: "3.10"
os: windows-latest
- python-version: "3.11"
os: windows-latest
- python-version: "3.10"
os: macos-latest
- python-version: "3.11"
os: macos-latest
include:
- browser: webkit
os: macos-latest
python-version: "3.9"

fail-fast: false

runs-on: ${{ matrix.os }}

defaults:
run:
shell: bash -el {0}

steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Setup environment
uses: ./.github/actions/setup-env
with:
python-version: ${{ matrix.python-version }}

- name: Run unit tests
id: tests
run: |
pytest tests/deploy/ui \
--browser ${{ matrix.browser }} \
--video=retain-on-failure
- name: Upload playwright video
if: failure()
uses: actions/upload-artifact@v4
with:
name:
playwright-${{ matrix.os }}-${{ matrix.python-version}}-${{ github.run_id }}
path: test-results
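
The new pytest-ui job runs the Playwright-backed UI tests under tests/deploy/ui across an OS/browser/Python matrix. As a rough local equivalent of one matrix cell (a sketch, assuming the dev environment including pytest-playwright is installed and browsers were fetched via playwright install), pytest could be invoked programmatically:

```python
# Rough local equivalent of one pytest-ui matrix cell above; assumes the dev
# environment (including pytest-playwright) is installed and browsers were
# fetched beforehand with `playwright install`.
import pytest

raise SystemExit(
    pytest.main(
        [
            "tests/deploy/ui",
            "--browser",
            "chromium",  # or firefox / webkit, mirroring the matrix
            "--video=retain-on-failure",  # keep recordings only for failing tests
        ]
    )
)
```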
8 changes: 8 additions & 0 deletions docs/examples/gallery_streaming.py
@@ -31,6 +31,14 @@
# - [ragna.assistants.Gpt4][]
# - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
# - [ragna.assistants.LlamafileAssistant][]
# - [Ollama](https://ollama.com/)
# - [ragna.assistants.OllamaGemma2B][]
# - [ragna.assistants.OllamaLlama2][]
# - [ragna.assistants.OllamaLlava][]
# - [ragna.assistants.OllamaMistral][]
# - [ragna.assistants.OllamaMixtral][]
# - [ragna.assistants.OllamaOrcaMini][]
# - [ragna.assistants.OllamaPhi2][]

from ragna import assistants

8 changes: 8 additions & 0 deletions docs/tutorials/gallery_python_api.py
@@ -87,6 +87,14 @@
# - [ragna.assistants.Jurassic2Ultra][]
# - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
# - [ragna.assistants.LlamafileAssistant][]
# - [Ollama](https://ollama.com/)
# - [ragna.assistants.OllamaGemma2B][]
# - [ragna.assistants.OllamaLlama2][]
# - [ragna.assistants.OllamaLlava][]
# - [ragna.assistants.OllamaMistral][]
# - [ragna.assistants.OllamaMixtral][]
# - [ragna.assistants.OllamaOrcaMini][]
# - [ragna.assistants.OllamaPhi2][]
#
# !!! note
#
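
Both documentation galleries above now list the Ollama-backed assistants. As a minimal usage sketch (not part of this diff), assuming a local Ollama server with the corresponding model pulled, and with the document path and prompt as placeholders, the flow from the Python API tutorial might look roughly like this:

```python
# Hypothetical sketch tying the new Ollama assistants into the Python API flow
# described in the tutorial above. Assumes a local Ollama server with the
# "mistral" model pulled; the document path and prompt are placeholders.
import asyncio

from ragna import Rag, assistants, source_storages


async def main() -> None:
    chat = Rag().chat(
        documents=["ragna.txt"],  # placeholder document
        source_storage=source_storages.RagnaDemoSourceStorage,
        assistant=assistants.OllamaMistral,
    )
    await chat.prepare()  # ingest documents into the source storage
    message = await chat.answer("What is ragna?")
    print(message)


asyncio.run(main())
```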
1 change: 1 addition & 0 deletions environment-dev.yml
@@ -10,6 +10,7 @@ dependencies:
- pytest >=6
- pytest-mock
- pytest-asyncio
- pytest-playwright
- mypy ==1.10.0
- pre-commit
- types-aiofiles
16 changes: 16 additions & 0 deletions ragna/assistants/__init__.py
@@ -6,6 +6,13 @@
"CommandLight",
"GeminiPro",
"GeminiUltra",
"OllamaGemma2B",
"OllamaPhi2",
"OllamaLlama2",
"OllamaLlava",
"OllamaMistral",
"OllamaMixtral",
"OllamaOrcaMini",
"Gpt35Turbo16k",
"Gpt4",
"Jurassic2Ultra",
@@ -19,6 +26,15 @@
from ._demo import RagnaDemoAssistant
from ._google import GeminiPro, GeminiUltra
from ._llamafile import LlamafileAssistant
from ._ollama import (
OllamaGemma2B,
OllamaLlama2,
OllamaLlava,
OllamaMistral,
OllamaMixtral,
OllamaOrcaMini,
OllamaPhi2,
)
from ._openai import Gpt4, Gpt35Turbo16k

# isort: split
10 changes: 5 additions & 5 deletions ragna/assistants/_ai21labs.py
@@ -7,6 +7,7 @@

class Ai21LabsAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "AI21_API_KEY"
_STREAMING_PROTOCOL = None
_MODEL_TYPE: str

@classmethod
@@ -27,7 +28,7 @@ async def answer(
# See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
# See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
# See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response
response = await self._client.post(
async for data in self._call_api(
"POST",
f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
headers={
"accept": "application/json",
@@ -46,10 +48,8 @@ async def answer(
],
"system": self._make_system_content(sources),
},
)
await self._assert_api_call_is_success(response)

yield cast(str, response.json()["outputs"][0]["text"])
):
yield cast(str, data["outputs"][0]["text"])


# The Jurassic2Mid assistant receives a 500 internal service error from the remote
5 changes: 3 additions & 2 deletions ragna/assistants/_anthropic.py
@@ -2,11 +2,12 @@

from ragna.core import PackageRequirement, RagnaException, Requirement, Source

from ._http_api import HttpApiAssistant
from ._http_api import HttpApiAssistant, HttpStreamingProtocol


class AnthropicAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "ANTHROPIC_API_KEY"
_STREAMING_PROTOCOL = HttpStreamingProtocol.SSE
_MODEL: str

@classmethod
@@ -40,7 +41,7 @@ async def answer(
) -> AsyncIterator[str]:
# See https://docs.anthropic.com/claude/reference/messages_post
# See https://docs.anthropic.com/claude/reference/streaming
async for data in self._stream_sse(
async for data in self._call_api(
"POST",
"https://api.anthropic.com/v1/messages",
headers={
5 changes: 3 additions & 2 deletions ragna/assistants/_cohere.py
@@ -2,11 +2,12 @@

from ragna.core import RagnaException, Source

from ._http_api import HttpApiAssistant
from ._http_api import HttpApiAssistant, HttpStreamingProtocol


class CohereAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "COHERE_API_KEY"
_STREAMING_PROTOCOL = HttpStreamingProtocol.JSONL
_MODEL: str

@classmethod
@@ -29,7 +30,7 @@ async def answer(
# See https://docs.cohere.com/docs/cochat-beta
# See https://docs.cohere.com/reference/chat
# See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
async for event in self._stream_jsonl(
async for event in self._call_api(
"POST",
"https://api.cohere.ai/v1/chat",
headers={
49 changes: 11 additions & 38 deletions ragna/assistants/_google.py
@@ -1,38 +1,15 @@
from typing import AsyncIterator

from ragna._compat import anext
from ragna.core import PackageRequirement, Requirement, Source
from ragna.core import Source

from ._http_api import HttpApiAssistant


# ijson does not support reading from an (async) iterator, but only from file-like
# objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects.
# See https://github.com/ICRAR/ijson/issues/44 for details.
# ijson actually doesn't care about most of the file interface and only requires the
# read() method to be present.
class AsyncIteratorReader:
def __init__(self, ait: AsyncIterator[bytes]) -> None:
self._ait = ait

async def read(self, n: int) -> bytes:
# n is usually used to indicate how many bytes to read, but since we want to
# return a chunk as soon as it is available, we ignore the value of n. The only
# exception is n == 0, which is used by ijson to probe the return type and
# set up decoding.
if n == 0:
return b""
return await anext(self._ait, b"") # type: ignore[call-arg]
from ._http_api import HttpApiAssistant, HttpStreamingProtocol


class GoogleAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "GOOGLE_API_KEY"
_STREAMING_PROTOCOL = HttpStreamingProtocol.JSON
_MODEL: str

@classmethod
def _extra_requirements(cls) -> list[Requirement]:
return [PackageRequirement("ijson")]

@classmethod
def display_name(cls) -> str:
return f"Google/{cls._MODEL}"
@@ -51,9 +28,7 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
import ijson

async with self._client.stream(
async for chunk in self._call_api(
"POST",
f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
params={"key": self._api_key},
@@ -64,7 +39,10 @@ async def answer(
],
# https://ai.google.dev/docs/safety_setting_gemini
"safetySettings": [
{"category": f"HARM_CATEGORY_{category}", "threshold": "BLOCK_NONE"}
{
"category": f"HARM_CATEGORY_{category}",
"threshold": "BLOCK_NONE",
}
for category in [
"HARASSMENT",
"HATE_SPEECH",
@@ -78,14 +56,9 @@
"maxOutputTokens": max_new_tokens,
},
},
) as response:
await self._assert_api_call_is_success(response)

async for chunk in ijson.items(
AsyncIteratorReader(response.aiter_bytes(1024)),
"item.candidates.item.content.parts.item.text",
):
yield chunk
parse_kwargs=dict(item="item.candidates.item.content.parts.item.text"),
):
yield chunk


class GeminiPro(GoogleAssistant):
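
The assistant diffs above all follow the same refactor: each backend declares a _STREAMING_PROTOCOL (SSE, JSONL, JSON, or None) and hands the HTTP request and stream parsing to HttpApiAssistant._call_api, instead of driving httpx or ijson directly. As a hypothetical illustration of that pattern (not part of this commit), a new SSE-backed assistant might be sketched roughly as follows; the endpoint, payload shape, and response keys are placeholders:

```python
# Hypothetical sketch of the pattern used by the refactored assistants above:
# declare a streaming protocol and let HttpApiAssistant._call_api handle the
# HTTP request and stream parsing. Endpoint, payload, and response keys are
# placeholders rather than a real backend.
from typing import AsyncIterator

from ragna.core import Source

from ._http_api import HttpApiAssistant, HttpStreamingProtocol


class ExampleSseAssistant(HttpApiAssistant):
    _API_KEY_ENV_VAR = "EXAMPLE_API_KEY"
    _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE

    @classmethod
    def display_name(cls) -> str:
        return "Example/sse-assistant"

    async def answer(
        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
    ) -> AsyncIterator[str]:
        async for data in self._call_api(
            "POST",
            "https://api.example.com/v1/chat",  # placeholder endpoint
            headers={"Authorization": f"Bearer {self._api_key}"},
            json={
                "prompt": prompt,
                "context": [source.content for source in sources],
                "max_tokens": max_new_tokens,
            },
        ):
            yield data["text"]  # placeholder response shape
```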
