Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Support o1 developer message format #1717

Merged
merged 5 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions cookbook/examples/streamlit/paperpal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def main() -> None:
with st.container():
search_terms_container = st.empty()
search_generator_input = {"topic": report_topic, "num_terms": num_search_terms}
search_terms = search_term_generator.run(json.dumps(search_generator_input))
search_terms = search_term_generator.run(json.dumps(search_generator_input)).content
if search_terms:
search_terms_container.json(search_terms.model_dump())
status.update(label="Search Terms Generated", state="complete", expanded=False)
Expand All @@ -88,7 +88,11 @@ def main() -> None:
exa_search_results = exa_search_agent.run(search_terms.model_dump_json(indent=4))
if isinstance(exa_search_results, str):
raise ValueError("Unexpected string response from exa_search_agent")
if exa_search_results and len(exa_search_results.content.results) > 0:
if (
exa_search_results
and exa_search_results.content
and len(exa_search_results.content.results) > 0
):
exa_content = exa_search_results.model_dump_json(indent=4)
exa_container.json(exa_search_results.content.results)
status.update(label="Exa Search Complete", state="complete", expanded=False)
Expand All @@ -102,11 +106,11 @@ def main() -> None:
with st.container():
arxiv_container = st.empty()
arxiv_search_results = arxiv_search_agent.run(search_terms.model_dump_json(indent=4))
if arxiv_search_results and arxiv_search_results.content.results:
if arxiv_search_results and arxiv_search_results.content and arxiv_search_results.content.results:
arxiv_container.json([result.model_dump() for result in arxiv_search_results.content.results])
status.update(label="ArXiv Search Complete", state="complete", expanded=False)

if arxiv_search_results and arxiv_search_results.content.results:
if arxiv_search_results and arxiv_search_results.content and arxiv_search_results.content.results:
paper_summaries = []
for result in arxiv_search_results.content.results:
summary = {
Expand Down
7 changes: 7 additions & 0 deletions cookbook/providers/openai/o1/o1_mini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from phi.agent import Agent, RunResponse  # noqa
from phi.model.openai import OpenAIChat

# Minimal example: an agent backed by OpenAI's o1-mini reasoning model.
o1_mini = OpenAIChat(id="o1-mini")
agent = Agent(model=o1_mini)

# Ask a question and print the model's answer to the terminal.
agent.print_response("What is the closest galaxy to milky way?")
8 changes: 8 additions & 0 deletions cookbook/providers/openai/o1/o1_mini_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import Iterator  # noqa
from phi.agent import Agent, RunResponse  # noqa
from phi.model.openai import OpenAIChat

# Minimal streaming example: an agent backed by OpenAI's o1-mini reasoning model.
o1_mini = OpenAIChat(id="o1-mini")
agent = Agent(model=o1_mini)

# Ask a question and stream the answer to the terminal as it arrives.
agent.print_response("What is the closest galaxy to milky way?", stream=True)
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,5 @@

agent = Agent(model=OpenAIChat(id="o1-preview"))

# Get the response in a variable
# run: RunResponse = agent.run("What is the closest galaxy to milky way?")
# print(run.content)

# Print the response in the terminal
agent.print_response("What is the closest galaxy to milky way?")
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,5 @@

agent = Agent(model=OpenAIChat(id="o1-preview"))

# Get the response in a variable
# run_response: Iterator[RunResponse] = agent.run("What is the closest galaxy to milky way?", stream=True)
# for chunk in run_response:
# print(chunk.content)

# Print the response in the terminal
agent.print_response("What is the closest galaxy to milky way?", stream=True)
16 changes: 9 additions & 7 deletions cookbook/vectordb/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from phi.agent import Agent
from phi.knowledge.pdf import PDFUrlKnowledgeBase
import os
#os.environ["OPENAI_API_KEY"] = ""

# os.environ["OPENAI_API_KEY"] = ""
from phi.vectordb.mongodb import MongoDBVector

# MongoDB Atlas connection string
# MongoDB Atlas connection string
"""
Example connection strings:
"mongodb+srv://<username>:<password>@cluster0.mongodb.net/?retryWrites=true&w=majority"
Expand All @@ -16,10 +16,12 @@

knowledge_base = PDFUrlKnowledgeBase(
urls=["https://phi-public.s3.amazonaws.com/recipes/ThaiRecipes.pdf"],
vector_db=MongoDBVector(collection_name="recipes", db_url=mdb_connection_string, wait_until_index_ready=60, wait_after_insert=300),
) #adjust wait_after_insert and wait_until_index_ready to your needs
knowledge_base.load(recreate=True)
vector_db=MongoDBVector(
collection_name="recipes", db_url=mdb_connection_string, wait_until_index_ready=60, wait_after_insert=300
),
) # adjust wait_after_insert and wait_until_index_ready to your needs
knowledge_base.load(recreate=True)

# Create and use the agent
agent = Agent(knowledge_base=knowledge_base, show_tool_calls=True)
agent.print_response("How to make Thai curry?", markdown=True)
agent.print_response("How to make Thai curry?", markdown=True)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import BaseModel
from dotenv import load_dotenv

from cookbook.examples.workflows.content_creator_workflow.config import TYPEFULLY_API_URL, HEADERS, PostType
from cookbook.workflows.content_creator_workflow.config import TYPEFULLY_API_URL, HEADERS, PostType
from phi.utils.log import logger

load_dotenv()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from phi.model.openai import OpenAIChat
from phi.tools.firecrawl import FirecrawlTools
from phi.utils.log import logger
from cookbook.examples.workflows.content_creator_workflow.scheduler import schedule
from cookbook.examples.workflows.content_creator_workflow.prompts import agents_config, tasks_config
from cookbook.examples.workflows.content_creator_workflow.config import PostType
from cookbook.workflows.content_creator_workflow.scheduler import schedule
from cookbook.workflows.content_creator_workflow.prompts import agents_config, tasks_config
from cookbook.workflows.content_creator_workflow.config import PostType

# Load environment variables
load_dotenv()
Expand All @@ -34,7 +34,7 @@ class Tweet(BaseModel):

content: str
is_hook: bool = Field(default=False, description="Marks if this tweet is the 'hook' (first tweet)")
media_urls: Optional[List[str]] = Field(default_factory=list, description="Associated media URLs, if any")
media_urls: Optional[List[str]] = Field(default_factory=list, description="Associated media URLs, if any") # type: ignore


class Thread(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion phi/model/ollama/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,6 @@ async def aresponse_stream(self, messages: List[Message]) -> Any:
yield post_tool_call_response
logger.debug("---------- Ollama Async Response End ----------")

def model_copy(self, *, update: Optional[dict[str, Any]] = None, deep: bool = False) -> "Ollama":
def model_copy(self, *, update: Optional[Mapping[str, Any]] = None, deep: bool = False) -> "Ollama":
new_model = Ollama(**self.model_dump(exclude={"client"}), client=self.client)
return new_model
7 changes: 5 additions & 2 deletions phi/model/openai/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,9 @@ def format_message(self, message: Message) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The formatted message.
"""
if message.role == "system":
message.role = "developer"

if message.role == "user":
if message.images is not None:
message = self.add_images_to_message(message=message, images=message.images)
Expand Down Expand Up @@ -599,7 +602,7 @@ def response(self, messages: List[Message]) -> ModelResponse:
# -*- Parse transcript if available
if response_audio:
if response_audio.transcript and not response_message.content:
response_message.content = response_message.audio.transcript
response_message.content = response_audio.transcript

# -*- Parse structured outputs
try:
Expand Down Expand Up @@ -677,7 +680,7 @@ async def aresponse(self, messages: List[Message]) -> ModelResponse:
# -*- Parse transcript if available
if response_audio:
if response_audio.transcript and not response_message.content:
response_message.content = response_message.audio.transcript
response_message.content = response_audio.transcript

# -*- Parse structured outputs
try:
Expand Down
29 changes: 14 additions & 15 deletions phi/vectordb/mongodb/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
except ImportError:
raise ImportError("`pymongo` not installed. Please install using `pip install pymongo`")


class MongoDBVector(VectorDb):
"""
MongoDB Vector Database implementation with elegant handling of Atlas Search index creation.
Expand All @@ -34,8 +35,8 @@ def __init__(
embedder: Embedder = OpenAIEmbedder(),
distance_metric: str = Distance.cosine,
overwrite: bool = False,
wait_until_index_ready: float = None,
wait_after_insert: float = None,
wait_until_index_ready: Optional[float] = None,
wait_after_insert: Optional[float] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -71,9 +72,9 @@ def _get_client(self) -> MongoClient:
"""Create or retrieve the MongoDB client."""
try:
logger.debug("Creating MongoDB Client")
client = MongoClient(self.connection_string, **self.kwargs)
client: MongoClient = MongoClient(self.connection_string, **self.kwargs)
# Trigger a connection to verify the client
client.admin.command('ping')
client.admin.command("ping")
logger.info("Connected to MongoDB successfully.")
return client
except errors.ConnectionFailure as e:
Expand All @@ -85,7 +86,7 @@ def _get_client(self) -> MongoClient:

def _get_or_create_collection(self) -> Collection:
"""Get or create the MongoDB collection, handling Atlas Search index creation."""

self._collection = self._db[self.collection_name]

if not self.collection_exists():
Expand Down Expand Up @@ -120,7 +121,7 @@ def _create_search_index(self, overwrite: bool = True) -> None:
"type": "vector",
"numDimensions": 1536,
"path": "embedding",
"similarity": self.distance_metric, #cosine
"similarity": self.distance_metric, # cosine
},
]
},
Expand Down Expand Up @@ -157,7 +158,7 @@ def _wait_for_index_ready(self) -> None:
break
except Exception as e:
logger.error(f"Error checking index status: {e}")
if time.time() - start_time > self.wait_until_index_ready:
if time.time() - start_time > self.wait_until_index_ready: # type: ignore
raise TimeoutError("Timeout waiting for search index to become ready.")
time.sleep(1)

Expand All @@ -171,7 +172,7 @@ def create(self) -> None:

def doc_exists(self, document: Document) -> bool:
"""Check if a document exists in the MongoDB collection based on its content."""
doc_id = md5(document.content.encode('utf-8')).hexdigest()
doc_id = md5(document.content.encode("utf-8")).hexdigest()
try:
exists = self._collection.find_one({"_id": doc_id}) is not None
logger.debug(f"Document {'exists' if exists else 'does not exist'}: {doc_id}")
Expand Down Expand Up @@ -219,7 +220,7 @@ def insert(self, documents: List[Document], filters: Optional[Dict[str, Any]] =
# lets wait for 5 minutes.... just in case
# feel free to 'optimize'... :)
if self.wait_after_insert and self.wait_after_insert > 0:
time.sleep(self.wait_after_insert)
time.sleep(self.wait_after_insert)
except errors.BulkWriteError as e:
logger.warning(f"Bulk write error while inserting documents: {e.details}")
except Exception as e:
Expand All @@ -245,9 +246,7 @@ def upsert_available(self) -> bool:
"""Indicate that upsert functionality is available."""
return True

def search(
self, query: str, limit: int = 5, filters: Optional[Dict[str, Any]] = None
) -> List[Document]:
def search(self, query: str, limit: int = 5, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
"""Search the MongoDB collection for documents relevant to the query."""
query_embedding = self.embedder.get_embedding(query)
if query_embedding is None:
Expand All @@ -268,7 +267,7 @@ def search(
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
]
pipeline.append({"$project": {"embedding": 0}})
agg = list(self._collection.aggregate(pipeline))
agg = list(self._collection.aggregate(pipeline)) # type: ignore
docs = []
for doc in agg:
docs.append(
Expand Down Expand Up @@ -366,7 +365,7 @@ def prepare_doc(self, document: Document) -> Dict[str, Any]:
raise ValueError(f"Failed to generate embedding for document: {document.id}")

cleaned_content = document.content.replace("\x00", "\ufffd")
doc_id = md5(cleaned_content.encode('utf-8')).hexdigest()
doc_id = md5(cleaned_content.encode("utf-8")).hexdigest()
doc_data = {
"_id": doc_id,
"name": document.name,
Expand All @@ -385,4 +384,4 @@ def get_count(self) -> int:
return count
except Exception as e:
logger.error(f"Error getting document count: {e}")
return 0
return 0
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ module = [
"psycopg2.*",
"pyarrow.*",
"pycountry.*",
"pymongo.*",
"pypdf.*",
"pytz.*",
"qdrant_client.*",
Expand Down
Loading