diff --git a/pilot/server/knowledge/api.py b/pilot/server/knowledge/api.py index 51f9d5098..c0966cb62 100644 --- a/pilot/server/knowledge/api.py +++ b/pilot/server/knowledge/api.py @@ -65,7 +65,7 @@ def space_delete(request: KnowledgeSpaceRequest, user_token: UserRequest = Depen KnowledgeSpaceRequest(user_id=user_token.user_id, name=request.name)) if len(spaces) == 0: return Result.faild(code="E000X", - msg=f"knowledge_space {request.name} cannot not be found by user {user_token.user_id}") + msg=f"knowledge_space {request.name} can not be found by user {user_token.user_id}") print(f"/space/delete params:") try: return Result.succ(knowledge_space_service.delete_space(spaces[0].id)) @@ -97,7 +97,7 @@ def arguments_save(space_name: str, argument_request: SpaceArgumentRequest, user def document_add(space_name: str, request: KnowledgeDocumentRequest, user_token: UserRequest = Depends(get_user_from_headers)): spaces = knowledge_space_service.get_knowledge_space(KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name)) if len(spaces) == 0: - return Result.faild(code="E000X", msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}") + return Result.faild(code="E000X", msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}") print(f"/document/add params: {space_name}, {request}, {user_token.user_id}") try: @@ -116,7 +116,7 @@ def document_list(space_name: str, query_request: DocumentQueryRequest, user_tok KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name)) if len(spaces) == 0: return Result.faild(code="E000X", - msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}") + msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}") print(f"/document/list params: {space_name}, {query_request}") try: return Result.succ( @@ -132,7 +132,7 @@ def document_delete(space_name: str, query_request: DocumentQueryRequest, user_t KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name)) if len(spaces) == 0: return Result.faild(code="E000X", - msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}") + msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}") print(f"/document/list params: {space_name}, {query_request}") try: return Result.succ( @@ -154,7 +154,7 @@ async def document_upload( KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name)) if len(spaces) == 0: return Result.faild(code="E000X", - msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}") + msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}") print(f"/document/upload params: {space_name}") try: space_name_dir = space_name + user_token.user_id @@ -194,7 +194,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest, user_token: Use KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name)) if len(spaces) == 0: return Result.faild(code="E000X", - msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}") + msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}") logger.info(f"Received params: {space_name}, {request}") try: knowledge_space_service.sync_knowledge_document( diff --git a/pilot/server/knowledge/space_db.py b/pilot/server/knowledge/space_db.py index f76b2f6ae..3f8ad957d 100644 --- a/pilot/server/knowledge/space_db.py +++ b/pilot/server/knowledge/space_db.py @@ -56,7 +56,7 @@ def get_knowledge_space(self, query: KnowledgeSpaceEntity): knowledge_spaces = session.query(KnowledgeSpaceEntity) if query.user_id is not None: knowledge_spaces = knowledge_spaces.filter( - KnowledgeSpaceEntity.user_id == query.user_id or KnowledgeSpaceEntity.user_id is None or KnowledgeSpaceEntity.user_id == '' + KnowledgeSpaceEntity.user_id == query.user_id ) if query.id is not None: knowledge_spaces = knowledge_spaces.filter( diff --git a/pilot/server/prompt/api.py b/pilot/server/prompt/api.py index b94546891..18189f6e1 100644 --- a/pilot/server/prompt/api.py +++ b/pilot/server/prompt/api.py @@ -1,8 +1,9 @@ -from fastapi import APIRouter, File, UploadFile, Form +from fastapi import APIRouter, File, UploadFile, Form, Depends from pilot.openapi.api_view_model import Result from pilot.server.prompt.service import PromptManageService from pilot.server.prompt.request.request import PromptManageRequest +from pilot.user import UserRequest, get_user_from_headers router = APIRouter() @@ -10,8 +11,9 @@ @router.post("/prompt/add") -def prompt_add(request: PromptManageRequest): +def prompt_add(request: PromptManageRequest, user_token: UserRequest = Depends(get_user_from_headers)): print(f"/space/add params: {request}") + request.user_id = user_token.user_id try: prompt_manage_service.create_prompt(request) return Result.succ([]) @@ -20,17 +22,19 @@ def prompt_add(request: PromptManageRequest): @router.post("/prompt/list") -def prompt_list(request: PromptManageRequest): +def prompt_list(request: PromptManageRequest, user_token: UserRequest = Depends(get_user_from_headers)): print(f"/prompt/list params: {request}") try: + request.user_id = user_token.user_id return Result.succ(prompt_manage_service.get_prompts(request)) except Exception as e: return Result.faild(code="E010X", msg=f"prompt list error {e}") @router.post("/prompt/update") -def prompt_update(request: PromptManageRequest): +def prompt_update(request: PromptManageRequest, user_token: UserRequest = Depends(get_user_from_headers)): print(f"/prompt/update params: {request}") + request.user_id = user_token.user_id try: return Result.succ(prompt_manage_service.update_prompt(request)) except Exception as e: @@ -38,9 +42,9 @@ def prompt_update(request: PromptManageRequest): @router.post("/prompt/delete") -def prompt_delete(request: PromptManageRequest): +def prompt_delete(request: PromptManageRequest, user_token: UserRequest = Depends(get_user_from_headers)): print(f"/prompt/delete params: {request}") try: - return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name)) + return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name, user_token.user_id)) except Exception as e: return Result.faild(code="E010X", msg=f"prompt delete error {e}") diff --git a/pilot/server/prompt/prompt_manage_db.py b/pilot/server/prompt/prompt_manage_db.py index ea482a3bb..28828e43a 100644 --- a/pilot/server/prompt/prompt_manage_db.py +++ b/pilot/server/prompt/prompt_manage_db.py @@ -26,9 +26,10 @@ class PromptManageEntity(Base): user_name = Column(String(128)) gmt_created = Column(DateTime) gmt_modified = Column(DateTime) + user_id = Column(String(100)) def __repr__(self): - return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}', user_id='{self.user_id}')" class PromptManageDao(BaseDao): @@ -48,6 +49,7 @@ def create_prompt(self, prompt: PromptManageRequest): user_name=prompt.user_name, gmt_created=datetime.now(), gmt_modified=datetime.now(), + user_id=prompt.user_id, ) session.add(prompt_manage) session.commit() @@ -74,6 +76,10 @@ def get_prompts(self, query: PromptManageEntity): prompts = prompts.filter( PromptManageEntity.prompt_name == query.prompt_name ) + if query.user_id is not None: + prompts = prompts.filter( + PromptManageEntity.user_id == query.user_id + ) prompts = prompts.order_by(PromptManageEntity.gmt_created.desc()) result = prompts.all() diff --git a/pilot/server/prompt/request/request.py b/pilot/server/prompt/request/request.py index c1b0683ec..2a1dbe08f 100644 --- a/pilot/server/prompt/request/request.py +++ b/pilot/server/prompt/request/request.py @@ -22,3 +22,6 @@ class PromptManageRequest(BaseModel): """prompt_name: prompt name""" prompt_name: str = None + + """用户id""" + user_id: str = None diff --git a/pilot/server/prompt/service.py b/pilot/server/prompt/service.py index c108d8b88..8eb522796 100644 --- a/pilot/server/prompt/service.py +++ b/pilot/server/prompt/service.py @@ -16,6 +16,7 @@ def __init__(self): def create_prompt(self, request: PromptManageRequest): query = PromptManageRequest( prompt_name=request.prompt_name, + user_id=request.user_id, ) prompt_name = prompt_manage_dao.get_prompts(query) if len(prompt_name) > 0: @@ -32,6 +33,7 @@ def get_prompts(self, request: PromptManageRequest): prompt_type=request.prompt_type, prompt_name=request.prompt_name, user_name=request.user_name, + user_id=request.user_id ) responses = [] prompts = prompt_manage_dao.get_prompts(query) @@ -53,7 +55,7 @@ def get_prompts(self, request: PromptManageRequest): """update prompt""" def update_prompt(self, request: PromptManageRequest): - query = PromptManageEntity(prompt_name=request.prompt_name) + query = PromptManageEntity(prompt_name=request.prompt_name, user_id=request.user_id) prompts = prompt_manage_dao.get_prompts(query) if len(prompts) != 1: raise Exception( @@ -70,8 +72,8 @@ def update_prompt(self, request: PromptManageRequest): """delete prompt""" - def delete_prompt(self, prompt_name: str): - query = PromptManageEntity(prompt_name=prompt_name) + def delete_prompt(self, prompt_name: str, user_id: str): + query = PromptManageEntity(prompt_name=prompt_name, user_id=user_id) prompts = prompt_manage_dao.get_prompts(query) if len(prompts) == 0: raise Exception(f"delete error, no prompt name:{prompt_name} in database ")