Skip to content

Commit

Permalink
fix pre-commit issues
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonroach committed Sep 5, 2024
1 parent e767c1b commit c9e1a7e
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 26 deletions.
8 changes: 4 additions & 4 deletions ragna/assistants/_ai21labs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import AsyncIterator, Union, cast

from ragna.core import Message, Source
from ragna.core import Message, MessageRole, Source

from ._http_api import HttpApiAssistant

Expand Down Expand Up @@ -33,7 +33,7 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> Union[str, list]:
messages = [Message(content=prompt, role=MessageRole.USER)]
else:
messages = prompt

messages = [
{"text": i["content"], "role": i["role"]}
for i in messages
Expand Down Expand Up @@ -76,7 +76,7 @@ async def generate(
"numResults": 1,
"temperature": 0.0,
"maxTokens": max_new_tokens,
"messages": _render_prompt(prompt),
"messages": self._render_prompt(prompt),
"system": system_prompt,
},
) as stream:
Expand All @@ -88,7 +88,7 @@ async def answer(
) -> AsyncIterator[str]:
prompt, sources = (message := messages[-1]).content, message.sources
system_prompt = self._make_system_content(sources)
yield generate(
yield self.generate(
prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens
)

Expand Down
17 changes: 12 additions & 5 deletions ragna/assistants/_anthropic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from typing import AsyncIterator, Union, cast

from ragna.core import Message, PackageRequirement, RagnaException, Requirement, Source
from ragna.core import (
Message,
MessageRole,
PackageRequirement,
RagnaException,
Requirement,
Source,
)

from ._http_api import HttpApiAssistant, HttpStreamingProtocol

Expand Down Expand Up @@ -44,9 +51,9 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> list[dict]:
ordered list of dicts with 'content' and 'role' keys
"""
if isinstance(prompt, str):
messages = [Message(content=prompt, role=MessageRole.USER)]
messages = [Message(content=prompt, role=MessageRole.USER)]
else:
messages = prompt
messages = prompt

messages = [
{"content": i["content"], "role": i["role"]}
Expand Down Expand Up @@ -88,8 +95,8 @@ async def generate(
},
json={
"model": self._MODEL,
"system": system,
"messages": _render_prompt(prompt),
"system": system_prompt,
"messages": self._render_prompt(prompt),
"max_tokens": max_new_tokens,
"temperature": 0.0,
"stream": True,
Expand Down
10 changes: 5 additions & 5 deletions ragna/assistants/_cohere.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import AsyncIterator, Union, cast

from ragna.core import Message, RagnaException, Source
from ragna.core import Message, MessageRole, RagnaException, Source

from ._http_api import HttpApiAssistant, HttpStreamingProtocol

Expand Down Expand Up @@ -35,7 +35,7 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> str:
messages = [Message(content=prompt, role=MessageRole.USER)]
else:
messages = prompt

messages = [i["content"] for i in messages if i["role"] == "user"][-1]
return messages

Expand Down Expand Up @@ -74,12 +74,12 @@ async def generate(
},
json={
"preamble_override": system_prompt,
"message": _render_prompt(prompt),
"message": self._render_prompt(prompt),
"model": self._MODEL,
"stream": True,
"temperature": 0.0,
"max_tokens": max_new_tokens,
"documents": self._make_source_documents(sources),
"documents": source_documents,
},
) as stream:
async for event in stream:
Expand All @@ -100,7 +100,7 @@ async def answer(
prompt, sources = (message := messages[-1]).content, message.sources
system_prompt = self._make_preamble()
source_documents = self._make_source_documents(sources)
yield generate(
yield self.generate(
prompt=prompt,
system_prompt=system_prompt,
source_documents=source_documents,
Expand Down
6 changes: 3 additions & 3 deletions ragna/assistants/_google.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import AsyncIterator, Union

from ragna.core import Message, Source
from ragna.core import Message, MessageRole, Source

from ._http_api import HttpApiAssistant, HttpStreamingProtocol

Expand Down Expand Up @@ -59,7 +59,7 @@ async def generate(
params={"key": self._api_key},
headers={"Content-Type": "application/json"},
json={
"contents": _render_prompt(prompt),
"contents": self._render_prompt(prompt),
# https://ai.google.dev/docs/safety_setting_gemini
"safetySettings": [
{
Expand Down Expand Up @@ -89,7 +89,7 @@ async def answer(
) -> AsyncIterator[str]:
prompt, sources = (message := messages[-1]).content, message.sources
expanded_prompt = self._instructize_prompt(prompt, sources)
yield generate(prompt=expanded_prompt, max_new_tokens=max_new_tokens)
yield self.generate(prompt=expanded_prompt, max_new_tokens=max_new_tokens)


class GeminiPro(GoogleAssistant):
Expand Down
38 changes: 29 additions & 9 deletions ragna/assistants/_openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import abc
from functools import cached_property
from typing import Any, AsyncContextManager, AsyncIterator, Optional, cast, Union
from typing import (
Any,
AsyncContextManager,
AsyncIterator,
MessageRole,
Optional,
Union,
cast,
)

from ragna.core import Message, Source

Expand All @@ -23,7 +31,9 @@ def _make_system_content(self, sources: list[Source]) -> str:
)
return instruction + "\n\n".join(source.content for source in sources)

def _render_prompt(self, prompt: Union[str,list[Message]], system_prompt: str) -> list[dict]:
def _render_prompt(
self, prompt: Union[str, list[Message]], system_prompt: str
) -> list[dict]:
"""
Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format.
Expand All @@ -34,14 +44,22 @@ def _render_prompt(self, prompt: Union[str,list[Message]], system_prompt: str) -
messages = [Message(content=prompt, role=MessageRole.USER)]
else:
messages = prompt
system_message = [{"role":"system", "content":system_prompt}]
messages = [{"role":i["role"],"content":i["content"]} for i in prompt if i["role"] != "system"]
system_message = [{"role": "system", "content": system_prompt}]
messages = [
{"role": i["role"], "content": i["content"]}
for i in prompt
if i["role"] != "system"
]
return system_message.extend(messages)

async def generate(
self, prompt: Union[str,list[Message]], *, system_prompt: str = "You are a helpful assistant.", max_new_tokens: int = 256
self,
prompt: Union[str, list[Message]],
*,
system_prompt: str = "You are a helpful assistant.",
max_new_tokens: int = 256,
) -> AsyncContextManager[AsyncIterator[dict[str, Any]]]:
"""
"""
Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer()
This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls.
Expand Down Expand Up @@ -77,8 +95,10 @@ def _call_openai_api(
) -> AsyncContextManager[AsyncIterator[dict[str, Any]]]:
system_prompt = self._make_system_content(sources)

yield self.generate(prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens)

yield self.generate(
prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens
)

async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
Expand Down

0 comments on commit c9e1a7e

Please sign in to comment.