Skip to content

Commit

Permalink
Metadata filtering for all fixes dev-261
Browse files Browse the repository at this point in the history
  • Loading branch information
VVoruganti committed Mar 14, 2024
1 parent 53ef28d commit 8cac9b7
Show file tree
Hide file tree
Showing 9 changed files with 755 additions and 87 deletions.
101 changes: 95 additions & 6 deletions api/src/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,16 @@ async def get_user_by_name(


async def get_users(
db: AsyncSession, app_id: uuid.UUID, reverse: bool = False
db: AsyncSession,
app_id: uuid.UUID,
reverse: bool = False,
filter: Optional[dict] = None,
) -> Select:
stmt = select(models.User).where(models.User.app_id == app_id)

if filter is not None:
stmt = stmt.where(models.User.h_metadata.contains(filter))

if reverse:
stmt = stmt.order_by(models.User.created_at.desc())
else:
Expand Down Expand Up @@ -180,6 +187,7 @@ async def get_sessions(
location_id: Optional[str] = None,
reverse: Optional[bool] = False,
is_active: Optional[bool] = False,
filter: Optional[dict] = None,
) -> Select:
stmt = (
select(models.Session)
Expand All @@ -191,6 +199,9 @@ async def get_sessions(
if is_active:
stmt = stmt.where(models.Session.is_active.is_(True))

if filter is not None:
stmt = stmt.where(models.Session.h_metadata.contains(filter))

if reverse:
stmt = stmt.order_by(models.Session.created_at.desc())
else:
Expand All @@ -205,9 +216,12 @@ async def get_sessions(
async def create_session(
db: AsyncSession,
session: schemas.SessionCreate,
app_id: uuid.UUID, # TODO check if app id is associated with the right user
app_id: uuid.UUID,
user_id: uuid.UUID,
) -> models.Session:
honcho_user = await get_user(db, app_id=app_id, user_id=user_id)
if honcho_user is None:
raise ValueError("User not found")
honcho_session = models.Session(
user_id=user_id,
location_id=session.location_id,
Expand Down Expand Up @@ -281,6 +295,7 @@ async def create_message(
session_id=session_id,
is_user=message.is_user,
content=message.content,
h_metadata=message.metadata,
)
db.add(honcho_message)
await db.commit()
Expand All @@ -294,6 +309,7 @@ async def get_messages(
user_id: uuid.UUID,
session_id: uuid.UUID,
reverse: Optional[bool] = False,
filter: Optional[dict] = None,
) -> Select:
stmt = (
select(models.Message)
Expand All @@ -305,6 +321,9 @@ async def get_messages(
.where(models.Message.session_id == session_id)
)

if filter is not None:
stmt = stmt.where(models.Message.h_metadata.contains(filter))

if reverse:
stmt = stmt.order_by(models.Message.created_at.desc())
else:
Expand Down Expand Up @@ -334,6 +353,28 @@ async def get_message(
return result.scalar_one_or_none()


async def update_message(
    db: AsyncSession,
    message: schemas.MessageUpdate,
    app_id: uuid.UUID,
    user_id: uuid.UUID,
    session_id: uuid.UUID,
    message_id: uuid.UUID,
) -> models.Message:
    """Update the metadata of an existing message.

    The message is looked up through ``get_message`` using the app, user,
    session and message ids, so a message outside that scope is treated as
    missing.

    Args:
        db: Database session used for the lookup, commit and refresh.
        message: Update payload; only its ``metadata`` field is read here.
        app_id: App the message must belong to.
        user_id: User the message must belong to.
        session_id: Session the message must belong to.
        message_id: Id of the message to update.

    Returns:
        The refreshed message object (not a bool, despite the previous
        annotation).

    Raises:
        ValueError: If ``get_message`` returns ``None`` for the given ids.
    """
    honcho_message = await get_message(
        db, app_id=app_id, session_id=session_id, user_id=user_id, message_id=message_id
    )
    if honcho_message is None:
        raise ValueError("Message not found or does not belong to user")
    # Metadata is overwritten only when explicitly supplied: a payload that
    # omits it (None) leaves the stored metadata untouched instead of
    # resetting it to empty.
    if message.metadata is not None:
        honcho_message.h_metadata = message.metadata
    await db.commit()
    await db.refresh(honcho_message)
    return honcho_message


########################################################
# metamessage methods
########################################################
Expand All @@ -360,6 +401,7 @@ async def create_metamessage(
message_id=metamessage.message_id,
metamessage_type=metamessage.metamessage_type,
content=metamessage.content,
h_metadata=metamessage.metadata,
)

db.add(honcho_metamessage)
Expand All @@ -375,6 +417,7 @@ async def get_metamessages(
session_id: uuid.UUID,
message_id: Optional[uuid.UUID],
metamessage_type: Optional[str] = None,
filter: Optional[dict] = None,
reverse: Optional[bool] = False,
) -> Select:
stmt = (
Expand All @@ -394,6 +437,9 @@ async def get_metamessages(
if metamessage_type is not None:
stmt = stmt.where(models.Metamessage.metamessage_type == metamessage_type)

if filter is not None:
stmt = stmt.where(models.Metamessage.h_metadata.contains(filter))

if reverse:
stmt = stmt.order_by(models.Metamessage.created_at.desc())
else:
Expand Down Expand Up @@ -426,6 +472,35 @@ async def get_metamessage(
return result.scalar_one_or_none()


async def update_metamessage(
    db: AsyncSession,
    metamessage: schemas.MetamessageUpdate,
    app_id: uuid.UUID,
    user_id: uuid.UUID,
    session_id: uuid.UUID,
    metamessage_id: uuid.UUID,
) -> models.Metamessage:
    """Update the metadata and/or type of an existing metamessage.

    The metamessage is looked up through ``get_metamessage`` using the app,
    user and session ids, the ``message_id`` carried by the update payload,
    and the metamessage id.

    Args:
        db: Database session used for the lookup, commit and refresh.
        metamessage: Update payload; ``message_id``, ``metadata`` and
            ``metamessage_type`` are read from it.
        app_id: App the metamessage must belong to.
        user_id: User the metamessage must belong to.
        session_id: Session the metamessage must belong to.
        metamessage_id: Id of the metamessage to update.

    Returns:
        The refreshed metamessage object (not a bool, despite the previous
        annotation).

    Raises:
        ValueError: If ``get_metamessage`` returns ``None`` for the given ids.
    """
    honcho_metamessage = await get_metamessage(
        db,
        app_id=app_id,
        session_id=session_id,
        user_id=user_id,
        message_id=metamessage.message_id,
        metamessage_id=metamessage_id,
    )
    if honcho_metamessage is None:
        raise ValueError("Metamessage not found or does not belong to user")
    # Each field is overwritten only when explicitly supplied: a payload that
    # omits it (None) leaves the stored value untouched.
    if metamessage.metadata is not None:
        honcho_metamessage.h_metadata = metamessage.metadata
    if metamessage.metamessage_type is not None:
        honcho_metamessage.metamessage_type = metamessage.metamessage_type
    await db.commit()
    await db.refresh(honcho_metamessage)
    return honcho_metamessage


########################################################
# collection methods
########################################################
Expand All @@ -438,6 +513,7 @@ async def get_collections(
app_id: uuid.UUID,
user_id: uuid.UUID,
reverse: Optional[bool] = False,
filter: Optional[dict] = None,
) -> Select:
"""Get a distinct list of the names of collections associated with a user"""
stmt = (
Expand All @@ -447,6 +523,9 @@ async def get_collections(
.where(models.User.id == user_id)
)

if filter is not None:
stmt = stmt.where(models.Collection.h_metadata.contains(filter))

if reverse:
stmt = stmt.order_by(models.Collection.created_at.desc())
else:
Expand Down Expand Up @@ -494,6 +573,7 @@ async def create_collection(
honcho_collection = models.Collection(
user_id=user_id,
name=collection.name,
h_metadata=collection.metadata,
)
try:
db.add(honcho_collection)
Expand All @@ -517,6 +597,8 @@ async def update_collection(
)
if honcho_collection is None:
raise ValueError("collection not found or does not belong to user")
if collection.metadata is not None:
honcho_collection.h_metadata = collection.metadata
try:
honcho_collection.name = collection.name
await db.commit()
Expand Down Expand Up @@ -563,6 +645,7 @@ async def get_documents(
user_id: uuid.UUID,
collection_id: uuid.UUID,
reverse: Optional[bool] = False,
filter: Optional[dict] = None,
) -> Select:
stmt = (
select(models.Document)
Expand All @@ -573,6 +656,9 @@ async def get_documents(
.where(models.Document.collection_id == collection_id)
)

if filter is not None:
stmt = stmt.where(models.Document.h_metadata.contains(filter))

if reverse:
stmt = stmt.order_by(models.Document.created_at.desc())
else:
Expand Down Expand Up @@ -609,6 +695,7 @@ async def query_documents(
user_id: uuid.UUID,
collection_id: uuid.UUID,
query: str,
filter: Optional[dict] = None,
top_k: int = 5,
) -> Sequence[models.Document]:
response = openai_client.embeddings.create(
Expand All @@ -622,11 +709,13 @@ async def query_documents(
.where(models.User.app_id == app_id)
.where(models.User.id == user_id)
.where(models.Document.collection_id == collection_id)
.order_by(models.Document.embedding.cosine_distance(embedding_query))
.limit(top_k)
# .limit(top_k)
)
if filter is not None:
stmt = stmt.where(models.Document.h_metadata.contains(filter))
stmt = stmt.limit(top_k).order_by(
models.Document.embedding.cosine_distance(embedding_query)
)
# if metadata is not None:
# stmt = stmt.where(models.Document.h_metadata.contains(metadata))
result = await db.execute(stmt)
return result.scalars().all()

Expand Down
Loading

0 comments on commit 8cac9b7

Please sign in to comment.