Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prompt information isolation #845

Merged
merged 2 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pilot/server/knowledge/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def space_delete(request: KnowledgeSpaceRequest, user_token: UserRequest = Depen
KnowledgeSpaceRequest(user_id=user_token.user_id, name=request.name))
if len(spaces) == 0:
return Result.faild(code="E000X",
msg=f"knowledge_space {request.name} cannot not be found by user {user_token.user_id}")
msg=f"knowledge_space {request.name} can not be found by user {user_token.user_id}")
print(f"/space/delete params:")
try:
return Result.succ(knowledge_space_service.delete_space(spaces[0].id))
Expand Down Expand Up @@ -97,7 +97,7 @@ def arguments_save(space_name: str, argument_request: SpaceArgumentRequest, user
def document_add(space_name: str, request: KnowledgeDocumentRequest, user_token: UserRequest = Depends(get_user_from_headers)):
spaces = knowledge_space_service.get_knowledge_space(KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name))
if len(spaces) == 0:
return Result.faild(code="E000X", msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}")
return Result.faild(code="E000X", msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}")

print(f"/document/add params: {space_name}, {request}, {user_token.user_id}")
try:
Expand All @@ -116,7 +116,7 @@ def document_list(space_name: str, query_request: DocumentQueryRequest, user_tok
KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name))
if len(spaces) == 0:
return Result.faild(code="E000X",
msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}")
msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}")
print(f"/document/list params: {space_name}, {query_request}")
try:
return Result.succ(
Expand All @@ -132,7 +132,7 @@ def document_delete(space_name: str, query_request: DocumentQueryRequest, user_t
KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name))
if len(spaces) == 0:
return Result.faild(code="E000X",
msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}")
msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}")
print(f"/document/list params: {space_name}, {query_request}")
try:
return Result.succ(
Expand All @@ -154,7 +154,7 @@ async def document_upload(
KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name))
if len(spaces) == 0:
return Result.faild(code="E000X",
msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}")
msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}")
print(f"/document/upload params: {space_name}")
try:
space_name_dir = space_name + user_token.user_id
Expand Down Expand Up @@ -194,7 +194,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest, user_token: Use
KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name))
if len(spaces) == 0:
return Result.faild(code="E000X",
msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}")
msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}")
logger.info(f"Received params: {space_name}, {request}")
try:
knowledge_space_service.sync_knowledge_document(
Expand Down
16 changes: 10 additions & 6 deletions pilot/server/prompt/api.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from fastapi import APIRouter, File, UploadFile, Form
from fastapi import APIRouter, File, UploadFile, Form, Depends

from pilot.openapi.api_view_model import Result
from pilot.server.prompt.service import PromptManageService
from pilot.server.prompt.request.request import PromptManageRequest
from pilot.user import UserRequest, get_user_from_headers

router = APIRouter()

prompt_manage_service = PromptManageService()


@router.post("/prompt/add")
def prompt_add(request: PromptManageRequest):
def prompt_add(request: PromptManageRequest, user_token: UserRequest = Depends(get_user_from_headers)):
print(f"/space/add params: {request}")
request.user_id = user_token.user_id
try:
prompt_manage_service.create_prompt(request)
return Result.succ([])
Expand All @@ -20,27 +22,29 @@ def prompt_add(request: PromptManageRequest):


@router.post("/prompt/list")
def prompt_list(request: PromptManageRequest):
def prompt_list(request: PromptManageRequest, user_token: UserRequest = Depends(get_user_from_headers)):
print(f"/prompt/list params: {request}")
try:
request.user_id = user_token.user_id
return Result.succ(prompt_manage_service.get_prompts(request))
except Exception as e:
return Result.faild(code="E010X", msg=f"prompt list error {e}")


@router.post("/prompt/update")
def prompt_update(request: PromptManageRequest):
def prompt_update(request: PromptManageRequest, user_token: UserRequest = Depends(get_user_from_headers)):
print(f"/prompt/update params: {request}")
request.user_id = user_token.user_id
try:
return Result.succ(prompt_manage_service.update_prompt(request))
except Exception as e:
return Result.faild(code="E010X", msg=f"prompt update error {e}")


@router.post("/prompt/delete")
def prompt_delete(request: PromptManageRequest):
def prompt_delete(request: PromptManageRequest, user_token: UserRequest = Depends(get_user_from_headers)):
print(f"/prompt/delete params: {request}")
try:
return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name))
return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name, user_token.user_id))
except Exception as e:
return Result.faild(code="E010X", msg=f"prompt delete error {e}")
8 changes: 7 additions & 1 deletion pilot/server/prompt/prompt_manage_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ class PromptManageEntity(Base):
user_name = Column(String(128))
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)
user_id = Column(String(100))

def __repr__(self):
return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}', user_id='{self.user_id}')"


class PromptManageDao(BaseDao):
Expand All @@ -48,6 +49,7 @@ def create_prompt(self, prompt: PromptManageRequest):
user_name=prompt.user_name,
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
user_id=prompt.user_id,
)
session.add(prompt_manage)
session.commit()
Expand All @@ -74,6 +76,10 @@ def get_prompts(self, query: PromptManageEntity):
prompts = prompts.filter(
PromptManageEntity.prompt_name == query.prompt_name
)
if query.user_id is not None:
prompts = prompts.filter(
PromptManageEntity.user_id == query.user_id
)

prompts = prompts.order_by(PromptManageEntity.gmt_created.desc())
result = prompts.all()
Expand Down
3 changes: 3 additions & 0 deletions pilot/server/prompt/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ class PromptManageRequest(BaseModel):

"""prompt_name: prompt name"""
prompt_name: str = None

"""用户id"""
user_id: str = None
8 changes: 5 additions & 3 deletions pilot/server/prompt/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self):
def create_prompt(self, request: PromptManageRequest):
query = PromptManageRequest(
prompt_name=request.prompt_name,
user_id=request.user_id,
)
prompt_name = prompt_manage_dao.get_prompts(query)
if len(prompt_name) > 0:
Expand All @@ -32,6 +33,7 @@ def get_prompts(self, request: PromptManageRequest):
prompt_type=request.prompt_type,
prompt_name=request.prompt_name,
user_name=request.user_name,
user_id=request.user_id
)
responses = []
prompts = prompt_manage_dao.get_prompts(query)
Expand All @@ -53,7 +55,7 @@ def get_prompts(self, request: PromptManageRequest):
"""update prompt"""

def update_prompt(self, request: PromptManageRequest):
query = PromptManageEntity(prompt_name=request.prompt_name)
query = PromptManageEntity(prompt_name=request.prompt_name, user_id=request.user_id)
prompts = prompt_manage_dao.get_prompts(query)
if len(prompts) != 1:
raise Exception(
Expand All @@ -70,8 +72,8 @@ def update_prompt(self, request: PromptManageRequest):

"""delete prompt"""

def delete_prompt(self, prompt_name: str):
query = PromptManageEntity(prompt_name=prompt_name)
def delete_prompt(self, prompt_name: str, user_id: str):
query = PromptManageEntity(prompt_name=prompt_name, user_id=user_id)
prompts = prompt_manage_dao.get_prompts(query)
if len(prompts) == 0:
raise Exception(f"delete error, no prompt name:{prompt_name} in database ")
Expand Down