Skip to content

Commit

Permalink
introduce engine for API
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jun 25, 2024
1 parent 528b953 commit 8fcaf3b
Show file tree
Hide file tree
Showing 13 changed files with 781 additions and 639 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ ignore_missing_imports = true

[[tool.mypy.overrides]]
module = [
"ragna.deploy._api.orm",
"ragna.deploy._orm",
]
# Our ORM schema doesn't really work with mypy. There are other ways to define it
# that play well with mypy. We should switch to one of them in the future.
Expand Down
12 changes: 12 additions & 0 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import abc
import datetime
import enum
import functools
import inspect
import uuid
from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union

import pydantic
Expand Down Expand Up @@ -157,6 +159,8 @@ def __init__(
*,
role: MessageRole = MessageRole.SYSTEM,
sources: Optional[list[Source]] = None,
id: Optional[uuid.UUID] = None,
timestamp: Optional[datetime.datetime] = None,
) -> None:
if isinstance(content, str):
self._content: str = content
Expand All @@ -166,6 +170,14 @@ def __init__(
self.role = role
self.sources = sources or []

if id is None:
id = uuid.uuid4()
self.id = id

if timestamp is None:
timestamp = datetime.datetime.utcnow()
self.timestamp = timestamp

async def __aiter__(self) -> AsyncIterator[str]:
if hasattr(self, "_content"):
yield self._content
Expand Down
87 changes: 70 additions & 17 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect
import uuid
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Expand All @@ -12,21 +13,24 @@
Iterable,
Iterator,
Optional,
Type,
TypeVar,
Union,
cast,
)

import pydantic
from fastapi import status
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool

from ._components import Assistant, Component, Message, MessageRole, SourceStorage
from ._document import Document, LocalDocument
from ._utils import RagnaException, default_user, merge_models

if TYPE_CHECKING:
from ragna.deploy import Config

T = TypeVar("T")
C = TypeVar("C", bound=Component)
C = TypeVar("C", bound=Component, covariant=True)


class Rag(Generic[C]):
Expand All @@ -41,20 +45,69 @@ class Rag(Generic[C]):
```
"""

def __init__(self) -> None:
self._components: dict[Type[C], C] = {}
def __init__(
self,
*,
config: Optional[Config] = None,
ignore_unavailable_components: bool = False,
) -> None:
self._components: dict[type[C], C] = {}
self._display_name_map: dict[str, type[C]] = {}

if config is not None:
self._preload_components(
config=config,
ignore_unavailable_components=ignore_unavailable_components,
)

def _preload_components(
    self, *, config: Config, ignore_unavailable_components: bool
) -> None:
    """Load all source storages and assistants listed in the config up front.

    The source storages and the assistants are handled as two separate groups.
    Every entry of a group is passed to ``self._load_component``. Entries for
    which that call returns ``None`` are reported on stdout and skipped.

    Args:
        config: Configuration providing ``source_storages`` and ``assistants``.
        ignore_unavailable_components: Forwarded to ``self._load_component`` as
            ``ignore_unavailable``.

    Raises:
        RagnaException: If not a single component of a group could be loaded.
    """
    for group in [config.source_storages, config.assistants]:
        candidates = cast(list[type[Component]], group)

        any_loaded = False
        for candidate in candidates:
            loaded = self._load_component(
                candidate,  # type: ignore[arg-type]
                ignore_unavailable=ignore_unavailable_components,
            )
            if loaded is not None:
                any_loaded = True
                continue

            print(
                f"Ignoring {candidate.display_name()}, because it is not available."
            )

        # Each group is validated on its own: one loaded assistant does not make
        # up for a group of source storages in which nothing could be loaded.
        if any_loaded:
            continue

        raise RagnaException(
            "No component available",
            components=[candidate.display_name() for candidate in candidates],
        )

def _load_component(
self, component: Union[Type[C], C], *, ignore_unavailable: bool = False
self, component: Union[C, type[C], str], *, ignore_unavailable: bool = False
) -> Optional[C]:
cls: Type[C]
cls: type[C]
instance: Optional[C]

if isinstance(component, Component):
instance = cast(C, component)
cls = type(instance)
elif isinstance(component, type) and issubclass(component, Component):
cls = component
instance = None
elif isinstance(component, str):
try:
cls = self._display_name_map[component]
except KeyError:
raise RagnaException(
"Unknown component",
display_name=component,
help="Did you forget to create the Rag() instance with a config?",
http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
http_detail=f"Unknown component '{component}'",
) from None

instance = None
else:
raise RagnaException
Expand All @@ -71,31 +124,33 @@ def _load_component(
instance = cls()

self._components[cls] = instance
self._display_name_map[cls.display_name()] = cls

return self._components[cls]

def chat(
self,
*,
documents: Iterable[Any],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
source_storage: Union[SourceStorage, type[SourceStorage], str],
assistant: Union[Assistant, type[Assistant], str],
**params: Any,
) -> Chat:
"""Create a new [ragna.core.Chat][].
Args:
documents: Documents to use. If any item is not a [ragna.core.Document][],
[ragna.core.LocalDocument.from_path][] is invoked on it.
FIXME
source_storage: Source storage to use.
assistant: Assistant to use.
**params: Additional parameters passed to the source storage and assistant.
"""
return Chat(
self,
documents=documents,
source_storage=source_storage,
assistant=assistant,
source_storage=cast(SourceStorage, self._load_component(source_storage)), # type: ignore[arg-type]
assistant=cast(Assistant, self._load_component(assistant)), # type: ignore[arg-type]
**params,
)

Expand Down Expand Up @@ -146,17 +201,15 @@ def __init__(
rag: Rag,
*,
documents: Iterable[Any],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
source_storage: SourceStorage,
assistant: Assistant,
**params: Any,
) -> None:
self._rag = rag

self.documents = self._parse_documents(documents)
self.source_storage = cast(
SourceStorage, self._rag._load_component(source_storage)
)
self.assistant = cast(Assistant, self._rag._load_component(assistant))
self.source_storage = source_storage
self.assistant = assistant

special_params = SpecialChatParams().model_dump()
special_params.update(params)
Expand Down Expand Up @@ -306,6 +359,6 @@ async def __aenter__(self) -> Chat:
return self

async def __aexit__(
self, exc_type: Type[Exception], exc: Exception, traceback: str
self, exc_type: type[Exception], exc: Exception, traceback: str
) -> None:
pass
163 changes: 163 additions & 0 deletions ragna/deploy/_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import uuid
from typing import Annotated, AsyncIterator, cast

import aiofiles
import pydantic
from fastapi import (
APIRouter,
Body,
Depends,
Form,
HTTPException,
UploadFile,
)
from fastapi.responses import StreamingResponse

import ragna
import ragna.core
from ragna._compat import anext
from ragna.core._utils import default_user
from ragna.deploy import Config

from . import _schemas as schemas
from ._engine import Engine


def make_router(config: Config, engine: Engine) -> APIRouter:
    """Assemble the REST API router.

    The component and chat endpoints forward to the corresponding ``engine``
    methods. The document endpoints do not go through the engine and instead
    use ``engine._database`` directly.

    Args:
        config: Deployment configuration. ``config.document`` is used to create
            the upload information and to check whether local uploads are
            supported.
        engine: Engine backing the component and chat endpoints.

    Returns:
        The router with all endpoints registered under the ``"API"`` tag.
    """
    router = APIRouter(tags=["API"])

    def get_user() -> str:
        return default_user()

    # Endpoints receive the user through this dependency.
    UserDependency = Annotated[str, Depends(get_user)]

    # TODO: the document endpoints do not go through the engine, because they'll change
    #  quite drastically when the UI no longer depends on the API

    _database = engine._database

    @router.post("/document")
    async def create_document_upload_info(
        user: UserDependency,
        name: Annotated[str, Body(..., embed=True)],
    ) -> schemas.DocumentUpload:
        """Register a single document and return the information to upload it."""
        with _database.get_session() as session:
            document = schemas.Document(name=name)
            upload_info = await config.document.get_upload_info(
                config=config, user=user, id=document.id, name=document.name
            )
            metadata, parameters = upload_info
            document.metadata = metadata
            _database.add_document(
                session, user=user, document=document, metadata=metadata
            )
            return schemas.DocumentUpload(parameters=parameters, document=document)

    # TODO: Add UI support and documentation for this endpoint (#406)
    @router.post("/documents")
    async def create_documents_upload_info(
        user: UserDependency,
        names: Annotated[list[str], Body(..., embed=True)],
    ) -> list[schemas.DocumentUpload]:
        """Register several documents and return their upload information."""
        with _database.get_session() as session:
            documents_with_metadata = []
            uploads = []
            for name in names:
                document = schemas.Document(name=name)
                upload_info = await config.document.get_upload_info(
                    config=config, user=user, id=document.id, name=document.name
                )
                metadata, parameters = upload_info
                document.metadata = metadata
                documents_with_metadata.append((document, metadata))
                uploads.append(
                    schemas.DocumentUpload(parameters=parameters, document=document)
                )

            # All documents are added to the database in a single call.
            _database.add_documents(
                session,
                user=user,
                document_metadata_collection=documents_with_metadata,
            )
            return uploads

    # TODO: Add new endpoint for batch uploading documents (#407)
    @router.put("/document")
    async def upload_document(
        token: Annotated[str, Form()], file: UploadFile
    ) -> schemas.Document:
        """Store the uploaded file at the path of the document encoded in ``token``."""
        local_upload_supported = issubclass(
            config.document, ragna.core.LocalDocument
        )
        if not local_upload_supported:
            raise HTTPException(
                status_code=400,
                detail="Ragna configuration does not support local upload",
            )

        with _database.get_session() as session:
            # The user is taken from the token rather than from the user dependency.
            user, document_id = ragna.core.LocalDocument.decode_upload_token(token)
            document = _database.get_document(session, user=user, id=document_id)

            core_document = cast(
                ragna.core.LocalDocument, engine._to_core.document(document)
            )
            target_path = core_document.path
            target_path.parent.mkdir(parents=True, exist_ok=True)
            async with aiofiles.open(target_path, "wb") as document_file:
                while content := await file.read(1024):
                    await document_file.write(content)

            return document

    @router.get("/components")
    def get_components(_: UserDependency) -> schemas.Components:
        """Return the components reported by the engine."""
        return engine.get_components()

    @router.post("/chats")
    async def create_chat(
        user: UserDependency,
        chat_metadata: schemas.ChatMetadata,
    ) -> schemas.Chat:
        """Create a chat for the user from the given metadata."""
        return engine.create_chat(user=user, chat_metadata=chat_metadata)

    @router.get("/chats")
    async def get_chats(user: UserDependency) -> list[schemas.Chat]:
        """Return all chats of the user."""
        return engine.get_chats(user=user)

    @router.get("/chats/{id}")
    async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat:
        """Return a single chat of the user."""
        return engine.get_chat(user=user, id=id)

    @router.post("/chats/{id}/prepare")
    async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message:
        """Prepare the chat and return the resulting message."""
        return await engine.prepare_chat(user=user, id=id)

    @router.post("/chats/{id}/answer")
    async def answer(
        user: UserDependency,
        id: uuid.UUID,
        prompt: Annotated[str, Body(..., embed=True)],
        stream: Annotated[bool, Body(..., embed=True)] = False,
    ) -> schemas.Message:
        """Answer ``prompt`` either as one message or as a JSONL stream of messages."""
        message_stream = engine.answer_stream(user=user, chat_id=id, prompt=prompt)
        # The first message is awaited before any response object is created.
        # NOTE(review): presumably so that errors surface before a streaming
        # response has been started - confirm.
        first_message = await anext(message_stream)

        async def message_chunks() -> AsyncIterator[schemas.Message]:
            # Put the already consumed first message back in front of the rest.
            yield first_message
            async for chunk in message_stream:
                yield chunk

        async def to_jsonl(
            models: AsyncIterator[pydantic.BaseModel],
        ) -> AsyncIterator[str]:
            # One JSON document per line.
            async for model in models:
                yield f"{model.model_dump_json()}\n"

        if stream:
            return StreamingResponse(  # type: ignore[return-value]
                to_jsonl(message_chunks())
            )

        # Not streaming: append the content of all remaining chunks to the first
        # message and return it as a whole.
        remaining_content = [chunk.content async for chunk in message_stream]
        first_message.content += "".join(remaining_content)
        return first_message

    @router.delete("/chats/{id}")
    async def delete_chat(user: UserDependency, id: uuid.UUID) -> None:
        """Delete a single chat of the user."""
        engine.delete_chat(user=user, id=id)

    return router
1 change: 0 additions & 1 deletion ragna/deploy/_api/__init__.py

This file was deleted.

Loading

0 comments on commit 8fcaf3b

Please sign in to comment.