Skip to content
This repository has been archived by the owner on Dec 11, 2024. It is now read-only.

Commit

Permalink
Merge branch 'danswer-ai-main'
Browse files Browse the repository at this point in the history
  • Loading branch information
onimsha committed Apr 23, 2024
2 parents b4cca06 + c6d7ee2 commit a1c8feb
Show file tree
Hide file tree
Showing 64 changed files with 2,155 additions and 1,599 deletions.
10 changes: 3 additions & 7 deletions backend/danswer/chat/load_yamls.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
for prompt in all_prompts:
upsert_prompt(
user_id=None,
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
Expand All @@ -34,7 +34,6 @@ def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
shared=True,
db_session=db_session,
commit=True,
)
Expand Down Expand Up @@ -67,9 +66,7 @@ def load_personas_from_yaml(
prompts: list[PromptDBModel | None] | None = None
else:
prompts = [
get_prompt_by_name(
prompt_name, user_id=None, shared=True, db_session=db_session
)
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
Expand All @@ -80,7 +77,7 @@ def load_personas_from_yaml(

p_id = persona.get("id")
upsert_persona(
user_id=None,
user=None,
# Negative to not conflict with existing personas
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
Expand All @@ -96,7 +93,6 @@ def load_personas_from_yaml(
prompts=cast(list[PromptDBModel] | None, prompts),
document_sets=doc_sets,
default_persona=True,
shared=True,
is_public=True,
db_session=db_session,
)
Expand Down
79 changes: 38 additions & 41 deletions backend/danswer/db/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import Session

from danswer.auth.schemas import UserRole
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
Expand All @@ -27,6 +28,7 @@
from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import StarterMessage
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
Expand Down Expand Up @@ -313,13 +315,16 @@ def set_as_latest_chat_message(

def get_prompt_by_id(
prompt_id: int,
user_id: UUID | None,
user: User | None,
db_session: Session,
include_deleted: bool = False,
) -> Prompt:
stmt = select(Prompt).where(
Prompt.id == prompt_id, or_(Prompt.user_id == user_id, Prompt.user_id.is_(None))
)
stmt = select(Prompt).where(Prompt.id == prompt_id)

# if user is not specified OR they are an admin, they should
# have access to all prompts, so this where clause is not needed
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(or_(Prompt.user_id == user.id, Prompt.user_id.is_(None)))

if not include_deleted:
stmt = stmt.where(Prompt.deleted.is_(False))
Expand Down Expand Up @@ -351,14 +356,16 @@ def get_default_prompt() -> Prompt:

def get_persona_by_id(
persona_id: int,
# if user_id is `None` assume the user is an admin or auth is disabled
user_id: UUID | None,
# if user is `None` assume the user is an admin or auth is disabled
user: User | None,
db_session: Session,
include_deleted: bool = False,
) -> Persona:
stmt = select(Persona).where(Persona.id == persona_id)
if user_id is not None:
stmt = stmt.where(or_(Persona.user_id == user_id, Persona.user_id.is_(None)))

# if user is an admin, they should have access to all Personas
if user is not None and user.role != UserRole.ADMIN:
stmt = stmt.where(or_(Persona.user_id == user.id, Persona.user_id.is_(None)))

if not include_deleted:
stmt = stmt.where(Persona.deleted.is_(False))
Expand Down Expand Up @@ -397,41 +404,40 @@ def get_personas_by_ids(


def get_prompt_by_name(
prompt_name: str, user_id: UUID | None, shared: bool, db_session: Session
prompt_name: str, user: User | None, db_session: Session
) -> Prompt | None:
"""Cannot do shared and user owned simultaneously as there may be two of those"""
stmt = select(Prompt).where(Prompt.name == prompt_name)
if shared:
stmt = stmt.where(Prompt.user_id.is_(None))
else:
stmt = stmt.where(Prompt.user_id == user_id)

# if user is not specified OR they are an admin, they should
# have access to all prompts, so this where clause is not needed
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Prompt.user_id == user.id)

result = db_session.execute(stmt).scalar_one_or_none()
return result


def get_persona_by_name(
persona_name: str, user_id: UUID | None, shared: bool, db_session: Session
persona_name: str, user: User | None, db_session: Session
) -> Persona | None:
"""Cannot do shared and user owned simultaneously as there may be two of those"""
"""Admins can see all, regular users can only fetch their own.
If user is None, assume the user is an admin or auth is disabled."""
stmt = select(Persona).where(Persona.name == persona_name)
if shared:
stmt = stmt.where(Persona.user_id.is_(None))
else:
stmt = stmt.where(Persona.user_id == user_id)
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Persona.user_id == user.id)
result = db_session.execute(stmt).scalar_one_or_none()
return result


def upsert_prompt(
user_id: UUID | None,
user: User | None,
name: str,
description: str,
system_prompt: str,
task_prompt: str,
include_citations: bool,
datetime_aware: bool,
personas: list[Persona] | None,
shared: bool,
db_session: Session,
prompt_id: int | None = None,
default_prompt: bool = True,
Expand All @@ -440,9 +446,7 @@ def upsert_prompt(
if prompt_id is not None:
prompt = db_session.query(Prompt).filter_by(id=prompt_id).first()
else:
prompt = get_prompt_by_name(
prompt_name=name, user_id=user_id, shared=shared, db_session=db_session
)
prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session)

if prompt:
if not default_prompt and prompt.default_prompt:
Expand All @@ -463,7 +467,7 @@ def upsert_prompt(
else:
prompt = Prompt(
id=prompt_id,
user_id=None if shared else user_id,
user_id=user.id if user else None,
name=name,
description=description,
system_prompt=system_prompt,
Expand All @@ -485,7 +489,7 @@ def upsert_prompt(


def upsert_persona(
user_id: UUID | None,
user: User | None,
name: str,
description: str,
num_chunks: float,
Expand All @@ -496,7 +500,6 @@ def upsert_persona(
document_sets: list[DBDocumentSet] | None,
llm_model_version_override: str | None,
starter_messages: list[StarterMessage] | None,
shared: bool,
is_public: bool,
db_session: Session,
persona_id: int | None = None,
Expand All @@ -507,7 +510,7 @@ def upsert_persona(
persona = db_session.query(Persona).filter_by(id=persona_id).first()
else:
persona = get_persona_by_name(
persona_name=name, user_id=user_id, shared=shared, db_session=db_session
persona_name=name, user=user, db_session=db_session
)

if persona:
Expand Down Expand Up @@ -539,7 +542,7 @@ def upsert_persona(
else:
persona = Persona(
id=persona_id,
user_id=None if shared else user_id,
user_id=user.id if user else None,
is_public=is_public,
name=name,
description=description,
Expand All @@ -566,24 +569,20 @@ def upsert_persona(

def mark_prompt_as_deleted(
prompt_id: int,
user_id: UUID | None,
user: User | None,
db_session: Session,
) -> None:
prompt = get_prompt_by_id(
prompt_id=prompt_id, user_id=user_id, db_session=db_session
)
prompt = get_prompt_by_id(prompt_id=prompt_id, user=user, db_session=db_session)
prompt.deleted = True
db_session.commit()


def mark_persona_as_deleted(
persona_id: int,
user_id: UUID | None,
user: User | None,
db_session: Session,
) -> None:
persona = get_persona_by_id(
persona_id=persona_id, user_id=user_id, db_session=db_session
)
persona = get_persona_by_id(persona_id=persona_id, user=user, db_session=db_session)
persona.deleted = True
db_session.commit()

Expand Down Expand Up @@ -621,9 +620,7 @@ def update_persona_visibility(
is_visible: bool,
db_session: Session,
) -> None:
persona = get_persona_by_id(
persona_id=persona_id, user_id=None, db_session=db_session
)
persona = get_persona_by_id(persona_id=persona_id, user=None, db_session=db_session)
persona.is_visible = is_visible
db_session.commit()

Expand Down
4 changes: 1 addition & 3 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,6 @@ class Prompt(Base):
__tablename__ = "prompt"

id: Mapped[int] = mapped_column(primary_key=True)
# If not belong to a user, then it's shared
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
name: Mapped[str] = mapped_column(String)
description: Mapped[str] = mapped_column(String)
Expand Down Expand Up @@ -770,7 +769,6 @@ class Persona(Base):
__tablename__ = "persona"

id: Mapped[int] = mapped_column(primary_key=True)
# If not belong to a user, then it's shared
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
name: Mapped[str] = mapped_column(String)
description: Mapped[str] = mapped_column(String)
Expand Down Expand Up @@ -824,7 +822,7 @@ class Persona(Base):
back_populates="personas",
)
# Owner
user: Mapped[User] = relationship("User", back_populates="personas")
user: Mapped[User | None] = relationship("User", back_populates="personas")
# Other users with access
users: Mapped[list[User]] = relationship(
"User",
Expand Down
20 changes: 14 additions & 6 deletions backend/danswer/db/persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from danswer.db.chat import get_prompts_by_ids
from danswer.db.chat import upsert_persona
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.models import Persona__User
from danswer.db.models import User
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
Expand All @@ -21,9 +22,19 @@ def make_persona_private(
group_ids: list[int] | None,
db_session: Session,
) -> None:
if user_ids is not None:
db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
).delete(synchronize_session="fetch")

for user_uuid in user_ids:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_uuid))

db_session.commit()

# May cause error if someone switches down to MIT from EE
if user_ids or group_ids:
raise NotImplementedError("Danswer MIT does not support private Document Sets")
if group_ids:
raise NotImplementedError("Danswer MIT does not support private Personas")


def create_update_persona(
Expand All @@ -32,8 +43,6 @@ def create_update_persona(
user: User | None,
db_session: Session,
) -> PersonaSnapshot:
user_id = user.id if user is not None else None

# Permission to actually use these is checked later
document_sets = list(
get_document_sets_by_ids(
Expand All @@ -51,7 +60,7 @@ def create_update_persona(
try:
persona = upsert_persona(
persona_id=persona_id,
user_id=user_id,
user=user,
name=create_persona_request.name,
description=create_persona_request.description,
num_chunks=create_persona_request.num_chunks,
Expand All @@ -62,7 +71,6 @@ def create_update_persona(
document_sets=document_sets,
llm_model_version_override=create_persona_request.llm_model_version_override,
starter_messages=create_persona_request.starter_messages,
shared=create_persona_request.shared,
is_public=create_persona_request.is_public,
db_session=db_session,
)
Expand Down
3 changes: 1 addition & 2 deletions backend/danswer/db/slack_bot_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def create_slack_bot_persona(
# create/update persona associated with the slack bot
persona_name = _build_persona_name(channel_names)
persona = upsert_persona(
user_id=None, # Slack Bot Personas are not attached to users
user=None, # Slack Bot Personas are not attached to users
persona_id=existing_persona_id,
name=persona_name,
description="",
Expand All @@ -61,7 +61,6 @@ def create_slack_bot_persona(
document_sets=document_sets,
llm_model_version_override=None,
starter_messages=None,
shared=True,
is_public=True,
default_persona=False,
db_session=db_session,
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/one_shot_answer/answer_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def stream_answer_objects(
prompt = None
if query_req.prompt_id is not None:
prompt = get_prompt_by_id(
prompt_id=query_req.prompt_id, user_id=user_id, db_session=db_session
prompt_id=query_req.prompt_id, user=user, db_session=db_session
)
if prompt is None:
if not chat_session.persona.prompts:
Expand Down
Loading

0 comments on commit a1c8feb

Please sign in to comment.