diff --git a/api/src/crud.py b/api/src/crud.py index 27dc56f..2b342e5 100644 --- a/api/src/crud.py +++ b/api/src/crud.py @@ -114,9 +114,16 @@ async def get_user_by_name( async def get_users( - db: AsyncSession, app_id: uuid.UUID, reverse: bool = False + db: AsyncSession, + app_id: uuid.UUID, + reverse: bool = False, + filter: Optional[dict] = None, ) -> Select: stmt = select(models.User).where(models.User.app_id == app_id) + + if filter is not None: + stmt = stmt.where(models.User.h_metadata.contains(filter)) + if reverse: stmt = stmt.order_by(models.User.created_at.desc()) else: @@ -180,6 +187,7 @@ async def get_sessions( location_id: Optional[str] = None, reverse: Optional[bool] = False, is_active: Optional[bool] = False, + filter: Optional[dict] = None, ) -> Select: stmt = ( select(models.Session) @@ -191,6 +199,9 @@ async def get_sessions( if is_active: stmt = stmt.where(models.Session.is_active.is_(True)) + if filter is not None: + stmt = stmt.where(models.Session.h_metadata.contains(filter)) + if reverse: stmt = stmt.order_by(models.Session.created_at.desc()) else: @@ -205,9 +216,12 @@ async def get_sessions( async def create_session( db: AsyncSession, session: schemas.SessionCreate, - app_id: uuid.UUID, # TODO check if app id is associated with the right user + app_id: uuid.UUID, user_id: uuid.UUID, ) -> models.Session: + honcho_user = await get_user(db, app_id=app_id, user_id=user_id) + if honcho_user is None: + raise ValueError("User not found") honcho_session = models.Session( user_id=user_id, location_id=session.location_id, @@ -281,6 +295,7 @@ async def create_message( session_id=session_id, is_user=message.is_user, content=message.content, + h_metadata=message.metadata, ) db.add(honcho_message) await db.commit() @@ -294,6 +309,7 @@ async def get_messages( user_id: uuid.UUID, session_id: uuid.UUID, reverse: Optional[bool] = False, + filter: Optional[dict] = None, ) -> Select: stmt = ( select(models.Message) @@ -305,6 +321,9 @@ async def get_messages( .where(models.Message.session_id == session_id) ) + if filter is not None: + stmt = stmt.where(models.Message.h_metadata.contains(filter)) + if reverse: stmt = stmt.order_by(models.Message.created_at.desc()) else: @@ -334,6 +353,28 @@ async def get_message( return result.scalar_one_or_none() +async def update_message( + db: AsyncSession, + message: schemas.MessageUpdate, + app_id: uuid.UUID, + user_id: uuid.UUID, + session_id: uuid.UUID, + message_id: uuid.UUID, +) -> bool: + honcho_message = await get_message( + db, app_id=app_id, session_id=session_id, user_id=user_id, message_id=message_id + ) + if honcho_message is None: + raise ValueError("Message not found or does not belong to user") + if ( + message.metadata is not None + ): # Need to explicitly be there won't make it empty by default + honcho_message.h_metadata = message.metadata + await db.commit() + await db.refresh(honcho_message) + return honcho_message + + ######################################################## # metamessage methods ######################################################## @@ -360,6 +401,7 @@ async def create_metamessage( message_id=metamessage.message_id, metamessage_type=metamessage.metamessage_type, content=metamessage.content, + h_metadata=metamessage.metadata, ) db.add(honcho_metamessage) @@ -375,6 +417,7 @@ async def get_metamessages( session_id: uuid.UUID, message_id: Optional[uuid.UUID], metamessage_type: Optional[str] = None, + filter: Optional[dict] = None, reverse: Optional[bool] = False, ) -> Select: stmt = ( @@ -394,6 +437,9 @@ async def get_metamessages( if metamessage_type is not None: stmt = stmt.where(models.Metamessage.metamessage_type == metamessage_type) + if filter is not None: + stmt = stmt.where(models.Metamessage.h_metadata.contains(filter)) + if reverse: stmt = stmt.order_by(models.Metamessage.created_at.desc()) else: @@ -426,6 +472,35 @@ async def get_metamessage( return result.scalar_one_or_none() +async def update_metamessage( + db: AsyncSession, + metamessage: schemas.MetamessageUpdate, + app_id: uuid.UUID, + user_id: uuid.UUID, + session_id: uuid.UUID, + metamessage_id: uuid.UUID, +) -> bool: + honcho_metamessage = await get_metamessage( + db, + app_id=app_id, + session_id=session_id, + user_id=user_id, + message_id=metamessage.message_id, + metamessage_id=metamessage_id, + ) + if honcho_metamessage is None: + raise ValueError("Metamessage not found or does not belong to user") + if ( + metamessage.metadata is not None + ): # Need to explicitly be there won't make it empty by default + honcho_metamessage.h_metadata = metamessage.metadata + if metamessage.metamessage_type is not None: + honcho_metamessage.metamessage_type = metamessage.metamessage_type + await db.commit() + await db.refresh(honcho_metamessage) + return honcho_metamessage + + ######################################################## # collection methods ######################################################## @@ -438,6 +513,7 @@ async def get_collections( app_id: uuid.UUID, user_id: uuid.UUID, reverse: Optional[bool] = False, + filter: Optional[dict] = None, ) -> Select: """Get a distinct list of the names of collections associated with a user""" stmt = ( @@ -447,6 +523,9 @@ async def get_collections( .where(models.User.id == user_id) ) + if filter is not None: + stmt = stmt.where(models.Collection.h_metadata.contains(filter)) + if reverse: stmt = stmt.order_by(models.Collection.created_at.desc()) else: @@ -494,6 +573,7 @@ async def create_collection( honcho_collection = models.Collection( user_id=user_id, name=collection.name, + h_metadata=collection.metadata, ) try: db.add(honcho_collection) @@ -517,6 +597,8 @@ async def update_collection( ) if honcho_collection is None: raise ValueError("collection not found or does not belong to user") + if collection.metadata is not None: + honcho_collection.h_metadata = collection.metadata try: honcho_collection.name = collection.name await db.commit() @@ -563,6 +645,7 @@ async def get_documents( user_id: uuid.UUID, collection_id: uuid.UUID, reverse: Optional[bool] = False, + filter: Optional[dict] = None, ) -> Select: stmt = ( select(models.Document) @@ -573,6 +656,9 @@ async def get_documents( .where(models.Document.collection_id == collection_id) ) + if filter is not None: + stmt = stmt.where(models.Document.h_metadata.contains(filter)) + if reverse: stmt = stmt.order_by(models.Document.created_at.desc()) else: @@ -609,6 +695,7 @@ async def query_documents( user_id: uuid.UUID, collection_id: uuid.UUID, query: str, + filter: Optional[dict] = None, top_k: int = 5, ) -> Sequence[models.Document]: response = openai_client.embeddings.create( @@ -622,11 +709,13 @@ async def query_documents( .where(models.User.app_id == app_id) .where(models.User.id == user_id) .where(models.Document.collection_id == collection_id) - .order_by(models.Document.embedding.cosine_distance(embedding_query)) - .limit(top_k) + # .limit(top_k) + ) + if filter is not None: + stmt = stmt.where(models.Document.h_metadata.contains(filter)) + stmt = stmt.limit(top_k).order_by( + models.Document.embedding.cosine_distance(embedding_query) ) - # if metadata is not None: - # stmt = stmt.where(models.Document.h_metadata.contains(metadata)) result = await db.execute(stmt) return result.scalars().all() diff --git a/api/src/main.py b/api/src/main.py index 19269dc..9dcf7db 100644 --- a/api/src/main.py +++ b/api/src/main.py @@ -384,6 +384,7 @@ async def get_users( request: Request, app_id: uuid.UUID, reverse: bool = False, + filter: Optional[str] = None, db: AsyncSession = Depends(get_db), ): """Get All Users for an App @@ -396,7 +397,13 @@ async def get_users( list[schemas.User]: List of User objects """ - return await paginate(db, await crud.get_users(db, app_id=app_id, reverse=reverse)) + data = None + if filter is not None: + data = json.loads(filter) + + return await paginate( + db, await crud.get_users(db, app_id=app_id, reverse=reverse, filter=data) + ) @app.get("/apps/{app_id}/users/{name}", response_model=schemas.User) @@ -479,6 +486,7 @@ async def get_sessions( location_id: Optional[str] = None, is_active: Optional[bool] = False, reverse: Optional[bool] = False, + filter: Optional[str] = None, db: AsyncSession = Depends(get_db), ): """Get All Sessions for a User @@ -494,6 +502,11 @@ async def get_sessions( list[schemas.Session]: List of Session objects """ + + data = None + if filter is not None: + data = json.loads(filter) + return await paginate( db, await crud.get_sessions( @@ -503,6 +516,7 @@ async def get_sessions( location_id=location_id, reverse=reverse, is_active=is_active, + filter=data, ), ) @@ -557,9 +571,7 @@ async def update_session( """ if session.metadata is None: - raise HTTPException( - status_code=400, detail="Session metadata cannot be empty" - ) # TODO TEST if I can set the metadata to be blank with this + raise HTTPException(status_code=400, detail="Session metadata cannot be empty") try: return await crud.update_session( db, app_id=app_id, user_id=user_id, session_id=session_id, session=session @@ -676,6 +688,7 @@ async def get_messages( user_id: uuid.UUID, session_id: uuid.UUID, reverse: Optional[bool] = False, + filter: Optional[str] = None, db: AsyncSession = Depends(get_db), ): """Get all messages for a session @@ -695,6 +708,9 @@ async def get_messages( """ try: + data = None + if filter is not None: + data = json.loads(filter) return await paginate( db, await crud.get_messages( @@ -702,6 +718,7 @@ async def get_messages( app_id=app_id, user_id=user_id, session_id=session_id, + filter=data, reverse=reverse, ), ) @@ -729,6 +746,34 @@ async def get_message( return honcho_message +@router.put( + "sessions/{session_id}/messages/{message_id}", response_model=schemas.Message +) +async def update_message( + request: Request, + app_id: uuid.UUID, + user_id: uuid.UUID, + session_id: uuid.UUID, + message_id: uuid.UUID, + message: schemas.MessageUpdate, + db: AsyncSession = Depends(get_db), +): + """Update's the metadata of a message""" + if message.metadata is None: + raise HTTPException(status_code=400, detail="Message metadata cannot be empty") + try: + return await crud.update_message( + db, + message=message, + app_id=app_id, + user_id=user_id, + session_id=session_id, + message_id=message_id, + ) + except ValueError: + raise HTTPException(status_code=404, detail="Session not found") from None + + ######################################################## # metamessage routes ######################################################## @@ -783,6 +828,7 @@ async def get_metamessages( message_id: Optional[uuid.UUID] = None, metamessage_type: Optional[str] = None, reverse: Optional[bool] = False, + filter: Optional[str] = None, db: AsyncSession = Depends(get_db), ): """Get all messages for a session @@ -802,6 +848,9 @@ async def get_metamessages( """ try: + data = None + if filter is not None: + data = json.loads(filter) return await paginate( db, await crud.get_metamessages( @@ -811,6 +860,7 @@ async def get_metamessages( session_id=session_id, message_id=message_id, metamessage_type=metamessage_type, + filter=data, reverse=reverse, ), ) @@ -831,7 +881,7 @@ async def get_metamessage( metamessage_id: uuid.UUID, db: AsyncSession = Depends(get_db), ): - """Get a specific session for a user by ID + """Get a specific Metamessage by ID Args: app_id (uuid.UUID): The ID of the app representing the client application using @@ -858,6 +908,37 @@ async def get_metamessage( return honcho_metamessage +@router.put( + "sessions/{session_id}/metamessages/{metamessage_id}", + response_model=schemas.Metamessage, +) +async def update_metamessage( + request: Request, + app_id: uuid.UUID, + user_id: uuid.UUID, + session_id: uuid.UUID, + metamessage_id: uuid.UUID, + metamessage: schemas.MetamessageUpdate, + db: AsyncSession = Depends(get_db), +): + """Update's the metadata of a metamessage""" + if metamessage.metadata is None: + raise HTTPException( + status_code=400, detail="Metamessage metadata cannot be empty" + ) + try: + return await crud.update_metamessage( + db, + metamessage=metamessage, + app_id=app_id, + user_id=user_id, + session_id=session_id, + metamessage_id=metamessage_id, + ) + except ValueError: + raise HTTPException(status_code=404, detail="Session not found") from None + + ######################################################## # collection routes ######################################################## @@ -869,11 +950,28 @@ async def get_collections( app_id: uuid.UUID, user_id: uuid.UUID, reverse: Optional[bool] = False, + filter: Optional[str] = None, db: AsyncSession = Depends(get_db), ): + """Get All Collections for a User + + Args: + app_id (uuid.UUID): The ID of the app representing the client + application using honcho + user_id (uuid.UUID): The User ID representing the user, managed by the user + + Returns: + list[schemas.Collection]: List of Collection objects + + """ + data = None + if filter is not None: + data = json.loads(filter) return await paginate( db, - await crud.get_collections(db, app_id=app_id, user_id=user_id, reverse=reverse), + await crud.get_collections( + db, app_id=app_id, user_id=user_id, filter=data, reverse=reverse + ), ) @@ -994,9 +1092,13 @@ async def get_documents( user_id: uuid.UUID, collection_id: uuid.UUID, reverse: Optional[bool] = False, + filter: Optional[str] = None, db: AsyncSession = Depends(get_db), ): try: + data = None + if filter is not None: + data = json.loads(filter) return await paginate( db, await crud.get_documents( @@ -1004,6 +1106,7 @@ async def get_documents( app_id=app_id, user_id=user_id, collection_id=collection_id, + filter=data, reverse=reverse, ), ) @@ -1053,16 +1156,21 @@ async def query_documents( collection_id: uuid.UUID, query: str, top_k: int = 5, + filter: Optional[str] = None, db: AsyncSession = Depends(get_db), ): if top_k is not None and top_k > 50: top_k = 50 # TODO see if we need to paginate this + data = None + if filter is not None: + data = json.loads(filter) return await crud.query_documents( db=db, app_id=app_id, user_id=user_id, collection_id=collection_id, query=query, + filter=data, top_k=top_k, ) diff --git a/api/src/models.py b/api/src/models.py index ebd7cdd..0790b73 100644 --- a/api/src/models.py +++ b/api/src/models.py @@ -87,6 +87,7 @@ class Message(Base): session_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("sessions.id"), index=True) is_user: Mapped[bool] content: Mapped[str] = mapped_column(String(65535)) + h_metadata: Mapped[dict] = mapped_column("metadata", ColumnType, default={}) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), default=datetime.datetime.utcnow @@ -111,6 +112,7 @@ class Metamessage(Base): created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), default=datetime.datetime.utcnow ) + h_metadata: Mapped[dict] = mapped_column("metadata", ColumnType, default={}) def __repr__(self) -> str: return f"Metamessages(id={self.id}, message_id={self.message_id}, metamessage_type={self.metamessage_type}, content={self.content[10:]})" @@ -125,6 +127,7 @@ class Collection(Base): created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), default=datetime.datetime.utcnow ) + h_metadata: Mapped[dict] = mapped_column("metadata", ColumnType, default={}) documents = relationship( "Document", back_populates="collection", cascade="all, delete, delete-orphan" ) diff --git a/api/src/schemas.py b/api/src/schemas.py index a0a779c..eb09713 100644 --- a/api/src/schemas.py +++ b/api/src/schemas.py @@ -1,7 +1,8 @@ -from pydantic import BaseModel, validator import datetime import uuid +from pydantic import BaseModel, validator + class AppBase(BaseModel): pass @@ -68,21 +69,37 @@ class Config: class MessageBase(BaseModel): + pass + + +class MessageCreate(MessageBase): content: str is_user: bool + metadata: dict | None = {} -class MessageCreate(MessageBase): - pass +class MessageUpdate(MessageBase): + metadata: dict | None = None class Message(MessageBase): + content: str + is_user: bool session_id: uuid.UUID id: uuid.UUID + h_metadata: dict + metadata: dict created_at: datetime.datetime + @validator("metadata", pre=True, allow_reuse=True) + def fetch_h_metadata(cls, value, values): + if "h_metadata" in values: + return values["h_metadata"] + return {} + class Config: from_attributes = True + schema_extra = {"exclude": ["h_metadata"]} class SessionBase(BaseModel): @@ -120,21 +137,40 @@ class Config: class MetamessageBase(BaseModel): + pass + + +class MetamessageCreate(MetamessageBase): metamessage_type: str content: str + message_id: uuid.UUID + metadata: dict | None = {} -class MetamessageCreate(MetamessageBase): +class MetamessageUpdate(MetamessageBase): message_id: uuid.UUID + metamessage_type: str | None = None + metadata: dict | None = None class Metamessage(MetamessageBase): + metamessage_type: str + content: str id: uuid.UUID message_id: uuid.UUID + h_metadata: dict + metadata: dict created_at: datetime.datetime + @validator("metadata", pre=True, allow_reuse=True) + def fetch_h_metadata(cls, value, values): + if "h_metadata" in values: + return values["h_metadata"] + return {} + class Config: from_attributes = True + schema_extra = {"exclude": ["h_metadata"]} class CollectionBase(BaseModel): @@ -143,20 +179,31 @@ class CollectionBase(BaseModel): class CollectionCreate(CollectionBase): name: str + metadata: dict | None = {} class CollectionUpdate(CollectionBase): name: str + metadata: dict | None = None class Collection(CollectionBase): id: uuid.UUID name: str user_id: uuid.UUID + h_metadata: dict + metadata: dict created_at: datetime.datetime + @validator("metadata", pre=True, allow_reuse=True) + def fetch_h_metadata(cls, value, values): + if "h_metadata" in values: + return values["h_metadata"] + return {} + class Config: from_attributes = True + schema_extra = {"exclude": ["h_metadata"]} class DocumentBase(BaseModel): diff --git a/sdk/honcho/client.py b/sdk/honcho/client.py index 24659cc..1bc7262 100644 --- a/sdk/honcho/client.py +++ b/sdk/honcho/client.py @@ -1,6 +1,7 @@ from __future__ import annotations import datetime +import json import uuid from typing import Optional @@ -32,7 +33,13 @@ async def next(self): class AsyncGetUserPage(AsyncGetPage): """Paginated Results for Get User Requests""" - def __init__(self, response: dict, honcho: AsyncHoncho, reverse: bool): + def __init__( + self, + response: dict, + honcho: AsyncHoncho, + filter: Optional[dict], + reverse: bool, + ): """Constructor for Page Result from User Get Request Args: @@ -42,6 +49,7 @@ def __init__(self, response: dict, honcho: AsyncHoncho, reverse: bool): """ super().__init__(response) self.honcho = honcho + self.filter = filter self.reverse = reverse self.items = [ AsyncUser( @@ -57,7 +65,10 @@ async def next(self): if self.page >= self.pages: return None return await self.honcho.get_users( - page=(self.page + 1), page_size=self.page_size, reverse=self.reverse + filter=self.filter, + page=(self.page + 1), + page_size=self.page_size, + reverse=self.reverse, ) @@ -70,6 +81,7 @@ def __init__( user: AsyncUser, reverse: bool, location_id: Optional[str], + filter: Optional[dict], is_active: bool, ): """Constructor for Page Result from Session Get Request @@ -85,6 +97,7 @@ def __init__( self.location_id = location_id self.reverse = reverse self.is_active = is_active + self.filter = filter self.items = [ AsyncSession( user=user, @@ -107,6 +120,7 @@ async def next(self): return None return await self.user.get_sessions( location_id=self.location_id, + filter=self.filter, page=(self.page + 1), page_size=self.page_size, reverse=self.reverse, @@ -117,7 +131,13 @@ async def next(self): class AsyncGetMessagePage(AsyncGetPage): """Paginated Results for Get Session Requests""" - def __init__(self, response: dict, session: AsyncSession, reverse: bool): + def __init__( + self, + response: dict, + session: AsyncSession, + filter: Optional[dict], + reverse: bool, + ): """Constructor for Page Result from Session Get Request Args: @@ -127,6 +147,7 @@ def __init__(self, response: dict, session: AsyncSession, reverse: bool): """ super().__init__(response) self.session = session + self.filter = filter self.reverse = reverse self.items = [ Message( @@ -134,6 +155,7 @@ def __init__(self, response: dict, session: AsyncSession, reverse: bool): id=message["id"], is_user=message["is_user"], content=message["content"], + metadata=message["metadata"], created_at=message["created_at"], ) for message in response["items"] @@ -149,7 +171,7 @@ async def next(self): if self.page >= self.pages: return None return await self.session.get_messages( - (self.page + 1), self.page_size, self.reverse + self.filter, (self.page + 1), self.page_size, self.reverse ) @@ -158,6 +180,7 @@ def __init__( self, response: dict, session, + filter: Optional[dict], reverse: bool, message_id: Optional[uuid.UUID], metamessage_type: Optional[str], @@ -176,11 +199,13 @@ def __init__( self.session = session self.message_id = message_id self.metamessage_type = metamessage_type + self.filter = filter self.reverse = reverse self.items = [ Metamessage( id=metamessage["id"], message_id=metamessage["message_id"], + metadata=metamessage["metadata"], metamessage_type=metamessage["metamessage_type"], content=metamessage["content"], created_at=metamessage["created_at"], @@ -199,6 +224,7 @@ async def next(self): return None return await self.session.get_metamessages( metamessage_type=self.metamessage_type, + filter=self.filter, message=self.message_id, page=(self.page + 1), page_size=self.page_size, @@ -209,7 +235,9 @@ async def next(self): class AsyncGetDocumentPage(AsyncGetPage): """Paginated results for Get Document requests""" - def __init__(self, response: dict, collection, reverse: bool) -> None: + def __init__( + self, response: dict, collection, filter: Optional[dict], reverse: bool + ) -> None: """Constructor for Page Result from Document Get Request Args: @@ -220,6 +248,7 @@ def __init__(self, response: dict, collection, reverse: bool) -> None: """ super().__init__(response) self.collection = collection + self.filter = filter self.reverse = reverse self.items = [ Document( @@ -242,14 +271,19 @@ async def next(self): if self.page >= self.pages: return None return await self.collection.get_documents( - page=self.page + 1, page_size=self.page_size, reverse=self.reverse + filter=self.filter, + page=self.page + 1, + page_size=self.page_size, + reverse=self.reverse, ) class AsyncGetCollectionPage(AsyncGetPage): """Paginated results for Get Collection requests""" - def __init__(self, response: dict, user: AsyncUser, reverse: bool): + def __init__( + self, response: dict, user: AsyncUser, filter: Optional[dict], reverse: bool + ): """Constructor for page result from Get Collection Request Args: @@ -259,12 +293,14 @@ def __init__(self, response: dict, user: AsyncUser, reverse: bool): """ super().__init__(response) self.user = user + self.filter = filter self.reverse = reverse self.items = [ AsyncCollection( user=user, id=collection["id"], name=collection["name"], + metadata=collection["metadata"], created_at=collection["created_at"], ) for collection in response["items"] @@ -280,6 +316,7 @@ async def next(self): if self.page >= self.pages: return None return await self.user.get_collections( + filter=self.filter, page=self.page + 1, page_size=self.page_size, reverse=self.reverse, @@ -394,7 +431,11 @@ async def get_or_create_user(self, name: str): ) async def get_users( - self, page: int = 1, page_size: int = 50, reverse: bool = False + self, + filter: Optional[dict] = None, + page: int = 1, + page_size: int = 50, + reverse: bool = False, ): """Get Paginated list of users @@ -413,13 +454,17 @@ async def get_users( "size": page_size, "reverse": reverse, } + if filter is not None: + json_filter = json.dumps(filter) + params["filter"] = json_filter response = await self.client.get(url, params=params) response.raise_for_status() data = response.json() - return AsyncGetUserPage(data, self, reverse) + return AsyncGetUserPage(data, self, filter, reverse) async def get_users_generator( self, + filter: Optional[dict] = None, reverse: bool = False, ): """Shortcut Generator for get_users. Generator to iterate through @@ -434,7 +479,7 @@ async def get_users_generator( """ page = 1 page_size = 50 - get_user_response = await self.get_users(page, page_size, reverse) + get_user_response = await self.get_users(filter, page, page_size, reverse) while True: for session in get_user_response.items: yield session @@ -533,6 +578,7 @@ async def get_session(self, session_id: uuid.UUID): async def get_sessions( self, location_id: Optional[str] = None, + filter: Optional[dict] = None, page: int = 1, page_size: int = 50, reverse: bool = False, @@ -565,16 +611,20 @@ async def get_sessions( } if location_id: params["location_id"] = location_id + if filter is not None: + json_filter = json.dumps(filter) + params["filter"] = json_filter response = await self.honcho.client.get(url, params=params) response.raise_for_status() data = response.json() - return AsyncGetSessionPage(data, self, reverse, location_id, is_active) + return AsyncGetSessionPage(data, self, reverse, location_id, filter, is_active) async def get_sessions_generator( self, location_id: Optional[str] = None, reverse: bool = False, is_active: bool = False, + filter: Optional[dict] = None, ): """Shortcut Generator for get_sessions. Generator to iterate through all sessions for a user in an app @@ -592,7 +642,7 @@ async def get_sessions_generator( page = 1 page_size = 50 get_session_response = await self.get_sessions( - location_id, page, page_size, reverse, is_active + location_id, filter, page, page_size, reverse, is_active ) while True: for session in get_session_response.items: @@ -637,6 +687,7 @@ async def create_session( async def create_collection( self, name: str, + metadata: Optional[dict] = None, ): """Create a collection for a user @@ -647,7 +698,9 @@ async def create_collection( AsyncCollection: The Collection object of the new Collection """ - data = {"name": name} + if metadata is None: + metadata = {} + data = {"name": name, "metadata": metadata} url = f"{self.base_url}/collections" response = await self.honcho.client.post(url, json=data) response.raise_for_status() @@ -656,6 +709,7 @@ async def create_collection( self, id=data["id"], name=name, + metadata=metadata, created_at=data["created_at"], ) @@ -677,11 +731,16 @@ async def get_collection(self, name: str): user=self, id=data["id"], name=data["name"], + metadata=data["metadata"], created_at=data["created_at"], ) async def get_collections( - self, page: int = 1, page_size: int = 50, reverse: bool = False + self, + filter: Optional[dict] = None, + page: int = 1, + page_size: int = 50, + reverse: bool = False, ): """Return collections associated with a user paginated @@ -701,12 +760,17 @@ async def get_collections( "size": page_size, "reverse": reverse, } + if filter is not None: + json_filter = json.dumps(filter) + params["filter"] = json_filter response = await self.honcho.client.get(url, params=params) response.raise_for_status() data = response.json() - return AsyncGetCollectionPage(data, self, reverse) + return AsyncGetCollectionPage(data, self, filter, reverse) - async def get_collections_generator(self, reverse: bool = False): + async def get_collections_generator( + self, filter: Optional[dict] = None, reverse: bool = False + ): """Shortcut Generator for get_sessions. Generator to iterate through all sessions for a user in an app @@ -719,7 +783,9 @@ async def get_collections_generator(self, reverse: bool = False): """ page = 1 page_size = 50 - get_collection_response = await self.get_collections(page, page_size, reverse) + get_collection_response = await self.get_collections( + filter, page, page_size, reverse + ) while True: for collection in get_collection_response.items: yield collection @@ -765,7 +831,9 @@ def is_active(self): """Returns whether the session is active - made property to prevent tampering""" return self._is_active - async def create_message(self, is_user: bool, content: str): + async def create_message( + self, is_user: bool, content: str, metadata: Optional[dict] = None + ): """Adds a message to the session Args: @@ -778,7 +846,9 @@ async def create_message(self, is_user: bool, content: str): """ if not self.is_active: raise Exception("Session is inactive") - data = {"is_user": is_user, "content": content} + if metadata is None: + metadata = {} + data = {"is_user": is_user, "content": content, "metadata": metadata} url = f"{self.base_url}/messages" response = await self.user.honcho.client.post(url, json=data) response.raise_for_status() @@ -788,6 +858,7 @@ async def create_message(self, is_user: bool, content: str): id=data["id"], is_user=is_user, content=content, + metadata=metadata, created_at=data["created_at"], ) @@ -810,11 +881,16 @@ async def get_message(self, message_id: uuid.UUID) -> Message: id=data["id"], is_user=data["is_user"], content=data["content"], + metadata=data["metadata"], created_at=data["created_at"], ) async def get_messages( - self, page: int = 1, page_size: int = 50, reverse: bool = False + self, + filter: Optional[dict] = None, + page: int = 1, + page_size: int = 50, + reverse: bool = False, ) -> AsyncGetMessagePage: """Get all messages for a session @@ -833,12 +909,17 @@ async def get_messages( "size": page_size, "reverse": reverse, } + if filter is not None: + json_filter = json.dumps(filter) + params["filter"] = json_filter response = await self.user.honcho.client.get(url, params=params) response.raise_for_status() data = response.json() - return AsyncGetMessagePage(data, self, reverse) + return AsyncGetMessagePage(data, self, filter, reverse) - async def get_messages_generator(self, reverse: bool = False): + async def get_messages_generator( + self, filter: Optional[dict] = None, reverse: bool = False + ): """Shortcut Generator for get_messages. Generator to iterate through all messages for a session in an app @@ -848,7 +929,7 @@ async def get_messages_generator(self, reverse: bool = False): """ page = 1 page_size = 50 - get_messages_page = await self.get_messages(page, page_size, reverse) + get_messages_page = await self.get_messages(filter, page, page_size, reverse) while True: for message in get_messages_page.items: yield message @@ -860,7 +941,11 @@ async def get_messages_generator(self, reverse: bool = False): get_messages_page = new_messages async def create_metamessage( - self, message: Message, metamessage_type: str, content: str + self, + message: Message, + metamessage_type: str, + content: str, + metadata: Optional[dict] = None, ): """Adds a metamessage to a session and links it to a specific message @@ -875,10 +960,13 @@ async def create_metamessage( """ if not self.is_active: raise Exception("Session is inactive") + if metadata is None: + metadata = {} data = { "metamessage_type": metamessage_type, "content": content, "message_id": message.id, + "metadata": metadata, } url = f"{self.base_url}/metamessages" response = await self.user.honcho.client.post(url, json=data) @@ -889,6 +977,7 @@ async def create_metamessage( message_id=message.id, metamessage_type=metamessage_type, content=content, + metadata=metadata, created_at=data["created_at"], ) @@ -911,6 +1000,7 @@ async def get_metamessage(self, metamessage_id: uuid.UUID) -> Metamessage: message_id=data["message_id"], metamessage_type=data["metamessage_type"], content=data["content"], + metadata=data["metadata"], created_at=data["created_at"], ) @@ -918,6 +1008,7 @@ async def get_metamessages( self, metamessage_type: Optional[str] = None, message: Optional[Message] = None, + filter: Optional[dict] = None, page: int = 1, page_size: int = 50, reverse: bool = False, @@ -948,18 +1039,22 @@ async def get_metamessages( if message: # url += f"&message_id={message.id}" params["message_id"] = message.id + if filter is not None: + json_metadata = json.dumps(filter) + params["filter"] = json_metadata response = await self.user.honcho.client.get(url, params=params) response.raise_for_status() data = response.json() message_id = message.id if message else None return AsyncGetMetamessagePage( - data, self, reverse, message_id, metamessage_type + data, self, filter, reverse, message_id, metamessage_type ) async def get_metamessages_generator( self, metamessage_type: Optional[str] = None, message: Optional[Message] = None, + filter: Optional[dict] = None, reverse: bool = False, ): """Shortcut Generator for get_metamessages. Generator to iterate @@ -978,6 +1073,7 @@ async def get_metamessages_generator( get_metamessages_page = await self.get_metamessages( metamessage_type=metamessage_type, message=message, + filter=filter, page=page, page_size=page_size, reverse=reverse, @@ -1008,6 +1104,50 @@ async def update(self, metadata: dict): self.metadata = metadata return success + async def update_message(self, message: Message, metadata: dict): + """Update the metadata of a message + + Args: + message (Message): The message to update + metadata (dict): The new metadata for the message + + Returns: + boolean: Whether the message was successfully updated + """ + info = {"metadata": metadata} + url = f"{self.base_url}/messages/{message.id}" + response = await self.user.honcho.client.put(url, json=info) + success = response.status_code < 400 + message.metadata = metadata + return success + + async def update_metamessage( + self, + metamessage: Metamessage, + metamessage_type: Optional[str], + metadata: Optional[dict], + ): + """Update the metadata of a metamessage + + Args: + metamessage (Metamessage): The metamessage to update + metadata (dict): The new metadata for the metamessage + + Returns: + boolean: Whether the metamessage was successfully updated + """ + if metadata is None and metamessage_type is None: + raise ValueError("metadata and metamessage_type cannot both be None") + info = {"metamessage_type": metamessage_type, "metadata": metadata} + url = f"{self.base_url}/metamessages/{metamessage.id}" + response = await self.user.honcho.client.put(url, json=info) + success = response.status_code < 400 + if metamessage_type is not None: + metamessage.metamessage_type = metamessage_type + if metadata is not None: + metamessage.metadata = metadata + return success + async def close(self): """Closes a session by marking it as inactive""" url = f"{self.base_url}" @@ -1024,12 +1164,14 @@ def __init__( user: AsyncUser, id: uuid.UUID, name: str, + metadata: dict, created_at: datetime.datetime, ): """Constructor for Collection""" self.user = user self.id: uuid.UUID = id self.name: str = name + self.metadata: dict = metadata self.created_at: datetime.datetime = created_at @property @@ -1041,7 +1183,7 @@ def __str__(self): """String representation of Collection""" return f"AsyncCollection(id={self.id}, name={self.name}, created_at={self.created_at})" # noqa: E501 - async def update(self, name: str): + async def update(self, name: Optional[str] = None, metadata: Optional[dict] = None): """Update the name of the collection Args: @@ -1050,12 +1192,17 @@ async def update(self, name: str): Returns: boolean: Whether the session was successfully updated """ - info = {"name": name} + if metadata is None and name is None: + raise ValueError("metadata and name cannot both be None") + info = {"name": name, "metadata": metadata} url = f"{self.base_url}" response = await self.user.honcho.client.put(url, json=info) response.raise_for_status() success = response.status_code < 400 - self.name = name + if name is not None: + self.name = name + if metadata is not None: + self.metadata = metadata return success async def delete(self): @@ -1113,7 +1260,11 @@ async def get_document(self, document_id: uuid.UUID) -> Document: ) async def get_documents( - self, page: int = 1, page_size: int = 50, reverse: bool = False + self, + filter: Optional[dict] = None, + page: int = 1, + page_size: int = 50, + reverse: bool = False, ) -> AsyncGetDocumentPage: """Get all documents for a collection @@ -1132,12 +1283,17 @@ async def get_documents( "size": page_size, "reverse": reverse, } + if filter is not None: + json_filter = json.dumps(filter) + params["filter"] = json_filter response = await self.user.honcho.client.get(url, params=params) response.raise_for_status() data = response.json() - return AsyncGetDocumentPage(data, self, reverse) + return AsyncGetDocumentPage(data, self, filter, reverse) - async def get_documents_generator(self, reverse: bool = False): + async def get_documents_generator( + self, filter: Optional[dict] = None, reverse: bool = False + ): """Shortcut Generator for get_documents. Generator to iterate through all documents for a collection in an app @@ -1147,7 +1303,7 @@ async def get_documents_generator(self, reverse: bool = False): """ page = 1 page_size = 50 - get_documents_page = await self.get_documents(page, page_size, reverse) + get_documents_page = await self.get_documents(filter, page, page_size, reverse) while True: for document in get_documents_page.items: yield document @@ -1185,7 +1341,10 @@ async def query(self, query: str, top_k: int = 5) -> list[Document]: return data async def update_document( - self, document: Document, content: Optional[str], metadata: Optional[dict] + self, + document: Document, + content: Optional[str] = None, + metadata: Optional[dict] = None, ) -> Document: """Update a document in the collection diff --git a/sdk/honcho/schemas.py b/sdk/honcho/schemas.py index b5f74d2..1a0a2fe 100644 --- a/sdk/honcho/schemas.py +++ b/sdk/honcho/schemas.py @@ -1,32 +1,60 @@ import uuid import datetime + class Message: - def __init__(self, session_id: uuid.UUID, id: uuid.UUID, is_user: bool, content: str, created_at: datetime.datetime): + def __init__( + self, + session_id: uuid.UUID, + id: uuid.UUID, + is_user: bool, + content: str, + metadata: dict, + created_at: datetime.datetime, + ): """Constructor for Message""" self.session_id = session_id self.id = id self.is_user = is_user self.content = content + self.metadata = metadata self.created_at = created_at def __str__(self): return f"Message(id={self.id}, is_user={self.is_user}, content={self.content})" + class Metamessage: - def __init__(self, id: uuid.UUID, message_id: uuid.UUID, metamessage_type: str, content: str, created_at: datetime.datetime): + def __init__( + self, + id: uuid.UUID, + message_id: uuid.UUID, + metamessage_type: str, + content: str, + metadata: dict, + created_at: datetime.datetime, + ): """Constructor for Metamessage""" self.id = id self.message_id = message_id self.metamessage_type = metamessage_type self.content = content + self.metadata = metadata self.created_at = created_at def __str__(self): return f"Metamessage(id={self.id}, message_id={self.message_id}, metamessage_type={self.metamessage_type}, content={self.content})" - + + class Document: - def __init__(self, id: uuid.UUID, collection_id: uuid.UUID, content: str, metadata: dict, created_at: datetime.datetime): + def __init__( + self, + id: uuid.UUID, + collection_id: uuid.UUID, + content: str, + metadata: dict, + created_at: datetime.datetime, + ): """Constructor for Document""" self.collection_id = collection_id self.id = id @@ -36,4 +64,3 @@ def __init__(self, id: uuid.UUID, collection_id: uuid.UUID, content: str, metada def __str__(self) -> str: return f"Document(id={self.id}, metadata={self.metadata}, content={self.content}, created_at={self.created_at})" - diff --git a/sdk/honcho/sync_client.py b/sdk/honcho/sync_client.py index 2d7b260..58fe0eb 100644 --- a/sdk/honcho/sync_client.py +++ b/sdk/honcho/sync_client.py @@ -1,6 +1,7 @@ from __future__ import annotations import datetime +import json import uuid from typing import Optional @@ -32,7 +33,13 @@ def next(self): class GetUserPage(GetPage): """Paginated Results for Get User Requests""" - def __init__(self, response: dict, honcho: Honcho, reverse: bool): + def __init__( + self, + response: dict, + honcho: Honcho, + filter: Optional[dict], + reverse: bool, + ): """Constructor for Page Result from User Get Request Args: @@ -42,6 +49,7 @@ def __init__(self, response: dict, honcho: Honcho, reverse: bool): """ super().__init__(response) self.honcho = honcho + self.filter = filter self.reverse = reverse self.items = [ User( @@ -57,7 +65,10 @@ def next(self): if self.page >= self.pages: return None return self.honcho.get_users( - page=(self.page + 1), page_size=self.page_size, reverse=self.reverse + filter=self.filter, + page=(self.page + 1), + page_size=self.page_size, + reverse=self.reverse, ) @@ -70,6 +81,7 @@ def __init__( user: User, reverse: bool, location_id: Optional[str], + filter: Optional[dict], is_active: bool, ): """Constructor for Page Result from Session Get Request @@ -85,6 +97,7 @@ def __init__( self.location_id = location_id self.reverse = reverse self.is_active = is_active + self.filter = filter self.items = [ Session( user=user, @@ -107,6 +120,7 @@ def next(self): return None return self.user.get_sessions( location_id=self.location_id, + filter=self.filter, page=(self.page + 1), page_size=self.page_size, reverse=self.reverse, @@ -117,7 +131,13 @@ def next(self): class GetMessagePage(GetPage): """Paginated Results for Get Session Requests""" - def __init__(self, response: dict, session: Session, reverse: bool): + def __init__( + self, + response: dict, + session: Session, + filter: Optional[dict], + reverse: bool, + ): """Constructor for Page Result from Session Get Request Args: @@ -127,6 +147,7 @@ def __init__(self, response: dict, session: Session, reverse: bool): """ super().__init__(response) self.session = session + self.filter = filter self.reverse = reverse self.items = [ Message( @@ -134,6 +155,7 @@ def __init__(self, response: dict, session: Session, reverse: bool): id=message["id"], is_user=message["is_user"], content=message["content"], + metadata=message["metadata"], created_at=message["created_at"], ) for message in response["items"] @@ -149,7 +171,7 @@ def next(self): if self.page >= self.pages: return None return self.session.get_messages( - (self.page + 1), self.page_size, self.reverse + self.filter, (self.page + 1), self.page_size, self.reverse ) @@ -158,6 +180,7 @@ def __init__( self, response: dict, session, + filter: Optional[dict], reverse: bool, message_id: Optional[uuid.UUID], metamessage_type: Optional[str], @@ -176,11 +199,13 @@ def __init__( self.session = session self.message_id = message_id self.metamessage_type = metamessage_type + self.filter = filter self.reverse = reverse self.items = [ Metamessage( id=metamessage["id"], message_id=metamessage["message_id"], + metadata=metamessage["metadata"], metamessage_type=metamessage["metamessage_type"], content=metamessage["content"], created_at=metamessage["created_at"], @@ -199,6 +224,7 @@ def next(self): return None return self.session.get_metamessages( metamessage_type=self.metamessage_type, + filter=self.filter, message=self.message_id, page=(self.page + 1), page_size=self.page_size, @@ -209,7 +235,9 @@ def next(self): class GetDocumentPage(GetPage): """Paginated results for Get Document requests""" - def __init__(self, response: dict, collection, reverse: bool) -> None: + def __init__( + self, response: dict, collection, filter: Optional[dict], reverse: bool + ) -> None: """Constructor for Page Result from Document Get Request Args: @@ -220,6 +248,7 @@ def __init__(self, response: dict, collection, reverse: bool) -> None: """ super().__init__(response) self.collection = collection + self.filter = filter self.reverse = reverse self.items = [ Document( @@ -242,14 +271,19 @@ def next(self): if self.page >= self.pages: return None return self.collection.get_documents( - page=self.page + 1, page_size=self.page_size, reverse=self.reverse + filter=self.filter, + page=self.page + 1, + page_size=self.page_size, + reverse=self.reverse, ) class GetCollectionPage(GetPage): """Paginated results for Get Collection requests""" - def __init__(self, response: dict, user: User, reverse: bool): + def __init__( + self, response: dict, user: User, filter: Optional[dict], reverse: bool + ): """Constructor for page result from Get Collection Request Args: @@ -259,12 +293,14 @@ def __init__(self, response: dict, user: User, reverse: bool): """ super().__init__(response) self.user = user + self.filter = filter self.reverse = reverse self.items = [ Collection( user=user, id=collection["id"], name=collection["name"], + metadata=collection["metadata"], created_at=collection["created_at"], ) for collection in response["items"] @@ -280,6 +316,7 @@ def next(self): if self.page >= self.pages: return None return self.user.get_collections( + filter=self.filter, page=self.page + 1, page_size=self.page_size, reverse=self.reverse, @@ -394,7 +431,11 @@ def get_or_create_user(self, name: str): ) def get_users( - self, page: int = 1, page_size: int = 50, reverse: bool = False + self, + filter: Optional[dict] = None, + page: int = 1, + page_size: int = 50, + reverse: bool = False, ): """Get Paginated list of users @@ -413,13 +454,17 @@ def get_users( "size": page_size, "reverse": reverse, } + if filter is not None: + json_filter = json.dumps(filter) + params["filter"] = json_filter response = self.client.get(url, params=params) response.raise_for_status() data = response.json() - return GetUserPage(data, self, reverse) + return GetUserPage(data, self, filter, reverse) def get_users_generator( self, + filter: Optional[dict] = None, reverse: bool = False, ): """Shortcut Generator for get_users. Generator to iterate through @@ -434,7 +479,7 @@ def get_users_generator( """ page = 1 page_size = 50 - get_user_response = self.get_users(page, page_size, reverse) + get_user_response = self.get_users(filter, page, page_size, reverse) while True: for session in get_user_response.items: yield session @@ -533,6 +578,7 @@ def get_session(self, session_id: uuid.UUID): def get_sessions( self, location_id: Optional[str] = None, + filter: Optional[dict] = None, page: int = 1, page_size: int = 50, reverse: bool = False, @@ -565,16 +611,20 @@ def get_sessions( } if location_id: params["location_id"] = location_id + if filter is not None: + json_filter = json.dumps(filter) + params["filter"] = json_filter response = self.honcho.client.get(url, params=params) response.raise_for_status() data = response.json() - return GetSessionPage(data, self, reverse, location_id, is_active) + return GetSessionPage(data, self, reverse, location_id, filter, is_active) def get_sessions_generator( self, location_id: Optional[str] = None, reverse: bool = False, is_active: bool = False, + filter: Optional[dict] = None, ): """Shortcut Generator for get_sessions. Generator to iterate through all sessions for a user in an app @@ -592,7 +642,7 @@ def get_sessions_generator( page = 1 page_size = 50 get_session_response = self.get_sessions( - location_id, page, page_size, reverse, is_active + location_id, filter, page, page_size, reverse, is_active ) while True: for session in get_session_response.items: @@ -637,6 +687,7 @@ def create_session( def create_collection( self, name: str, + metadata: Optional[dict] = None, ): """Create a collection for a user @@ -647,7 +698,9 @@ def create_collection( Collection: The Collection object of the new Collection """ - data = {"name": name} + if metadata is None: + metadata = {} + data = {"name": name, "metadata": metadata} url = f"{self.base_url}/collections" response = self.honcho.client.post(url, json=data) response.raise_for_status() @@ -656,6 +709,7 @@ def create_collection( self, id=data["id"], name=name, + metadata=metadata, created_at=data["created_at"], ) @@ -677,11 +731,16 @@ def get_collection(self, name: str): user=self, id=data["id"], name=data["name"], + metadata=data["metadata"], created_at=data["created_at"], ) def get_collections( - self, page: int = 1, page_size: int = 50, reverse: bool = False + self, + filter: Optional[dict] = None, + page: int = 1, + page_size: int = 50, + reverse: bool = False, ): """Return collections associated with a user paginated @@ -701,12 +760,17 @@ def get_collections( "size": page_size, "reverse": reverse, } + if filter is not None: + json_filter = json.dumps(filter) + params["filter"] = json_filter response = self.honcho.client.get(url, params=params) response.raise_for_status() data = response.json() - return GetCollectionPage(data, self, reverse) + return GetCollectionPage(data, self, filter, reverse) - def get_collections_generator(self, reverse: bool = False): + def get_collections_generator( + self, filter: Optional[dict] = None, reverse: bool = False + ): """Shortcut Generator for get_sessions. Generator to iterate through all sessions for a user in an app @@ -719,7 +783,9 @@ def get_collections_generator(self, reverse: bool = False): """ page = 1 page_size = 50 - get_collection_response = self.get_collections(page, page_size, reverse) + get_collection_response = self.get_collections( + filter, page, page_size, reverse + ) while True: for collection in get_collection_response.items: yield collection @@ -765,7 +831,9 @@ def is_active(self): """Returns whether the session is active - made property to prevent tampering""" return self._is_active - def create_message(self, is_user: bool, content: str): + def create_message( + self, is_user: bool, content: str, metadata: Optional[dict] = None + ): """Adds a message to the session Args: @@ -778,7 +846,9 @@ def create_message(self, is_user: bool, content: str): """ if not self.is_active: raise Exception("Session is inactive") - data = {"is_user": is_user, "content": content} + if metadata is None: + metadata = {} + data = {"is_user": is_user, "content": content, "metadata": metadata} url = f"{self.base_url}/messages" response = self.user.honcho.client.post(url, json=data) response.raise_for_status() @@ -788,6 +858,7 @@ def create_message(self, is_user: bool, content: str): id=data["id"], is_user=is_user, content=content, + metadata=metadata, created_at=data["created_at"], ) @@ -810,11 +881,16 @@ def get_message(self, message_id: uuid.UUID) -> Message: id=data["id"], is_user=data["is_user"], content=data["content"], + metadata=data["metadata"], created_at=data["created_at"], ) def get_messages( - self, page: int = 1, page_size: int = 50, reverse: bool = False + self, + filter: Optional[dict] = None, + page: int = 1, + page_size: int = 50, + reverse: bool = False, ) -> GetMessagePage: """Get all messages for a session @@ -833,12 +909,17 @@ def get_messages( "size": page_size, "reverse": reverse, } + if filter is not None: + json_filter = json.dumps(filter) + params["filter"] = json_filter response = self.user.honcho.client.get(url, params=params) response.raise_for_status() data = response.json() - return GetMessagePage(data, self, reverse) + return GetMessagePage(data, self, filter, reverse) - def get_messages_generator(self, reverse: bool = False): + def get_messages_generator( + self, filter: Optional[dict] = None, reverse: bool = False + ): """Shortcut Generator for get_messages. Generator to iterate through all messages for a session in an app @@ -848,7 +929,7 @@ def get_messages_generator(self, reverse: bool = False): """ page = 1 page_size = 50 - get_messages_page = self.get_messages(page, page_size, reverse) + get_messages_page = self.get_messages(filter, page, page_size, reverse) while True: for message in get_messages_page.items: yield message @@ -860,7 +941,11 @@ def get_messages_generator(self, reverse: bool = False): get_messages_page = new_messages def create_metamessage( - self, message: Message, metamessage_type: str, content: str + self, + message: Message, + metamessage_type: str, + content: str, + metadata: Optional[dict] = None, ): """Adds a metamessage to a session and links it to a specific message @@ -875,10 +960,13 @@ def create_metamessage( """ if not self.is_active: raise Exception("Session is inactive") + if metadata is None: + metadata = {} data = { "metamessage_type": metamessage_type, "content": content, "message_id": message.id, + "metadata": metadata, } url = f"{self.base_url}/metamessages" response = self.user.honcho.client.post(url, json=data) @@ -889,6 +977,7 @@ def create_metamessage( message_id=message.id, metamessage_type=metamessage_type, content=content, + metadata=metadata, created_at=data["created_at"], ) @@ -911,6 +1000,7 @@ def get_metamessage(self, metamessage_id: uuid.UUID) -> Metamessage: message_id=data["message_id"], metamessage_type=data["metamessage_type"], content=data["content"], + metadata=data["metadata"], created_at=data["created_at"], ) @@ -918,6 +1008,7 @@ def get_metamessages( self, metamessage_type: Optional[str] = None, message: Optional[Message] = None, + filter: Optional[dict] = None, page: int = 1, page_size: int = 50, reverse: bool = False, @@ -948,18 +1039,22 @@ def get_metamessages( if message: # url += f"&message_id={message.id}" params["message_id"] = message.id + if filter is not None: + json_metadata = json.dumps(filter) + params["filter"] = json_metadata response = self.user.honcho.client.get(url, params=params) response.raise_for_status() data = response.json() message_id = message.id if message else None return GetMetamessagePage( - data, self, reverse, message_id, metamessage_type + data, self, filter, reverse, message_id, metamessage_type ) def get_metamessages_generator( self, metamessage_type: Optional[str] = None, message: Optional[Message] = None, + filter: Optional[dict] = None, reverse: bool = False, ): """Shortcut Generator for get_metamessages. Generator to iterate @@ -978,6 +1073,7 @@ def get_metamessages_generator( get_metamessages_page = self.get_metamessages( metamessage_type=metamessage_type, message=message, + filter=filter, page=page, page_size=page_size, reverse=reverse, @@ -1008,6 +1104,50 @@ def update(self, metadata: dict): self.metadata = metadata return success + def update_message(self, message: Message, metadata: dict): + """Update the metadata of a message + + Args: + message (Message): The message to update + metadata (dict): The new metadata for the message + + Returns: + boolean: Whether the message was successfully updated + """ + info = {"metadata": metadata} + url = f"{self.base_url}/messages/{message.id}" + response = self.user.honcho.client.put(url, json=info) + success = response.status_code < 400 + message.metadata = metadata + return success + + def update_metamessage( + self, + metamessage: Metamessage, + metamessage_type: Optional[str], + metadata: Optional[dict], + ): + """Update the metadata of a metamessage + + Args: + metamessage (Metamessage): The metamessage to update + metadata (dict): The new metadata for the metamessage + + Returns: + boolean: Whether the metamessage was successfully updated + """ + if metadata is None and metamessage_type is None: + raise ValueError("metadata and metamessage_type cannot both be None") + info = {"metamessage_type": metamessage_type, "metadata": metadata} + url = f"{self.base_url}/metamessages/{metamessage.id}" + response = self.user.honcho.client.put(url, json=info) + success = response.status_code < 400 + if metamessage_type is not None: + metamessage.metamessage_type = metamessage_type + if metadata is not None: + metamessage.metadata = metadata + return success + def close(self): """Closes a session by marking it as inactive""" url = f"{self.base_url}" @@ -1024,12 +1164,14 @@ def __init__( user: User, id: uuid.UUID, name: str, + metadata: dict, created_at: datetime.datetime, ): """Constructor for Collection""" self.user = user self.id: uuid.UUID = id self.name: str = name + self.metadata: dict = metadata self.created_at: datetime.datetime = created_at @property @@ -1041,7 +1183,7 @@ def __str__(self): """String representation of Collection""" return f"Collection(id={self.id}, name={self.name}, created_at={self.created_at})" # noqa: E501 - def update(self, name: str): + def update(self, name: Optional[str] = None, metadata: Optional[dict] = None): """Update the name of the collection Args: @@ -1050,12 +1192,17 @@ def update(self, name: str): Returns: boolean: Whether the session was successfully updated """ - info = {"name": name} + if metadata is None and name is None: + raise ValueError("metadata and name cannot both be None") + info = {"name": name, "metadata": metadata} url = f"{self.base_url}" response = self.user.honcho.client.put(url, json=info) response.raise_for_status() success = response.status_code < 400 - self.name = name + if name is not None: + self.name = name + if metadata is not None: + self.metadata = metadata return success def delete(self): @@ -1113,7 +1260,11 @@ def get_document(self, document_id: uuid.UUID) -> Document: ) def get_documents( - self, page: int = 1, page_size: int = 50, reverse: bool = False + self, + filter: Optional[dict] = None, + page: int = 1, + page_size: int = 50, + reverse: bool = False, ) -> GetDocumentPage: """Get all documents for a collection @@ -1132,12 +1283,17 @@ def get_documents( "size": page_size, "reverse": reverse, } + if filter is not None: + json_filter = json.dumps(filter) + params["filter"] = json_filter response = self.user.honcho.client.get(url, params=params) response.raise_for_status() data = response.json() - return GetDocumentPage(data, self, reverse) + return GetDocumentPage(data, self, filter, reverse) - def get_documents_generator(self, reverse: bool = False): + def get_documents_generator( + self, filter: Optional[dict] = None, reverse: bool = False + ): """Shortcut Generator for get_documents. Generator to iterate through all documents for a collection in an app @@ -1147,7 +1303,7 @@ def get_documents_generator(self, reverse: bool = False): """ page = 1 page_size = 50 - get_documents_page = self.get_documents(page, page_size, reverse) + get_documents_page = self.get_documents(filter, page, page_size, reverse) while True: for document in get_documents_page.items: yield document @@ -1185,7 +1341,10 @@ def query(self, query: str, top_k: int = 5) -> list[Document]: return data def update_document( - self, document: Document, content: Optional[str], metadata: Optional[dict] + self, + document: Document, + content: Optional[str] = None, + metadata: Optional[dict] = None, ) -> Document: """Update a document in the collection diff --git a/sdk/tests/test_async.py b/sdk/tests/test_async.py index 04ffc4a..f18ac99 100644 --- a/sdk/tests/test_async.py +++ b/sdk/tests/test_async.py @@ -15,6 +15,45 @@ from honcho import AsyncHoncho as Honcho +@pytest.mark.asyncio +async def test_session_metadata_filter(): + app_name = str(uuid1()) + user_name = str(uuid1()) + honcho = Honcho(app_name, "http://localhost:8000") + await honcho.initialize() + user = await honcho.create_user(user_name) + await user.create_session() + await user.create_session(metadata={"foo": "bar"}) + await user.create_session(metadata={"foo": "bar"}) + + response = await user.get_sessions(filter={"foo": "bar"}) + retrieved_sessions = response.items + + assert len(retrieved_sessions) == 2 + + response = await user.get_sessions() + + assert len(response.items) == 3 + + +@pytest.mark.asyncio +async def test_delete_session_metadata(): + app_name = str(uuid1()) + user_name = str(uuid1()) + honcho = Honcho(app_name, "http://localhost:8000") + await honcho.initialize() + user = await honcho.create_user(user_name) + retrieved_session = await user.create_session(metadata={"foo": "bar"}) + + assert retrieved_session.metadata == {"foo": "bar"} + + await retrieved_session.update(metadata={}) + + session_copy = await user.get_session(retrieved_session.id) + + assert session_copy.metadata == {} + + @pytest.mark.asyncio async def test_user_update(): user_name = str(uuid1()) diff --git a/sdk/tests/test_sync.py b/sdk/tests/test_sync.py index fd92234..0077ff0 100644 --- a/sdk/tests/test_sync.py +++ b/sdk/tests/test_sync.py @@ -15,6 +15,43 @@ from honcho import Honcho as Honcho +def test_session_metadata_filter(): + app_name = str(uuid1()) + user_name = str(uuid1()) + honcho = Honcho(app_name, "http://localhost:8000") + honcho.initialize() + user = honcho.create_user(user_name) + user.create_session() + user.create_session(metadata={"foo": "bar"}) + user.create_session(metadata={"foo": "bar"}) + + response = user.get_sessions(filter={"foo": "bar"}) + retrieved_sessions = response.items + + assert len(retrieved_sessions) == 2 + + response = user.get_sessions() + + assert len(response.items) == 3 + + +def test_delete_session_metadata(): + app_name = str(uuid1()) + user_name = str(uuid1()) + honcho = Honcho(app_name, "http://localhost:8000") + honcho.initialize() + user = honcho.create_user(user_name) + retrieved_session = user.create_session(metadata={"foo": "bar"}) + + assert retrieved_session.metadata == {"foo": "bar"} + + retrieved_session.update(metadata={}) + + session_copy = user.get_session(retrieved_session.id) + + assert session_copy.metadata == {} + + def test_user_update(): user_name = str(uuid1()) app_name = str(uuid1())