Skip to content

Commit

Permalink
prompt information isolation (#845)
Browse files Browse the repository at this point in the history
Add prompt user information isolation.
  • Loading branch information
xuyuan23 authored Nov 27, 2023
2 parents a0c94c2 + fde5c11 commit 60fa7bf
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 17 deletions.
12 changes: 6 additions & 6 deletions pilot/server/knowledge/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def space_delete(request: KnowledgeSpaceRequest, user_token: UserRequest = Depen
KnowledgeSpaceRequest(user_id=user_token.user_id, name=request.name))
if len(spaces) == 0:
return Result.faild(code="E000X",
msg=f"knowledge_space {request.name} cannot not be found by user {user_token.user_id}")
msg=f"knowledge_space {request.name} can not be found by user {user_token.user_id}")
print(f"/space/delete params:")
try:
return Result.succ(knowledge_space_service.delete_space(spaces[0].id))
Expand Down Expand Up @@ -97,7 +97,7 @@ def arguments_save(space_name: str, argument_request: SpaceArgumentRequest, user
def document_add(space_name: str, request: KnowledgeDocumentRequest, user_token: UserRequest = Depends(get_user_from_headers)):
spaces = knowledge_space_service.get_knowledge_space(KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name))
if len(spaces) == 0:
return Result.faild(code="E000X", msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}")
return Result.faild(code="E000X", msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}")

print(f"/document/add params: {space_name}, {request}, {user_token.user_id}")
try:
Expand All @@ -116,7 +116,7 @@ def document_list(space_name: str, query_request: DocumentQueryRequest, user_tok
KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name))
if len(spaces) == 0:
return Result.faild(code="E000X",
msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}")
msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}")
print(f"/document/list params: {space_name}, {query_request}")
try:
return Result.succ(
Expand All @@ -132,7 +132,7 @@ def document_delete(space_name: str, query_request: DocumentQueryRequest, user_t
KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name))
if len(spaces) == 0:
return Result.faild(code="E000X",
msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}")
msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}")
print(f"/document/list params: {space_name}, {query_request}")
try:
return Result.succ(
Expand All @@ -154,7 +154,7 @@ async def document_upload(
KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name))
if len(spaces) == 0:
return Result.faild(code="E000X",
msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}")
msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}")
print(f"/document/upload params: {space_name}")
try:
space_name_dir = space_name + user_token.user_id
Expand Down Expand Up @@ -194,7 +194,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest, user_token: Use
KnowledgeSpaceRequest(user_id=user_token.user_id, name=space_name))
if len(spaces) == 0:
return Result.faild(code="E000X",
msg=f"knowledge_space {space_name} cannot not be found by user {user_token.user_id}")
msg=f"knowledge_space {space_name} can not be found by user {user_token.user_id}")
logger.info(f"Received params: {space_name}, {request}")
try:
knowledge_space_service.sync_knowledge_document(
Expand Down
2 changes: 1 addition & 1 deletion pilot/server/knowledge/space_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_knowledge_space(self, query: KnowledgeSpaceEntity):
knowledge_spaces = session.query(KnowledgeSpaceEntity)
if query.user_id is not None:
knowledge_spaces = knowledge_spaces.filter(
KnowledgeSpaceEntity.user_id == query.user_id or KnowledgeSpaceEntity.user_id is None or KnowledgeSpaceEntity.user_id == ''
KnowledgeSpaceEntity.user_id == query.user_id
)
if query.id is not None:
knowledge_spaces = knowledge_spaces.filter(
Expand Down
16 changes: 10 additions & 6 deletions pilot/server/prompt/api.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from fastapi import APIRouter, File, UploadFile, Form
from fastapi import APIRouter, File, UploadFile, Form, Depends

from pilot.openapi.api_view_model import Result
from pilot.server.prompt.service import PromptManageService
from pilot.server.prompt.request.request import PromptManageRequest
from pilot.user import UserRequest, get_user_from_headers

router = APIRouter()

prompt_manage_service = PromptManageService()


@router.post("/prompt/add")
def prompt_add(request: PromptManageRequest):
def prompt_add(request: PromptManageRequest, user_token: UserRequest = Depends(get_user_from_headers)):
print(f"/space/add params: {request}")
request.user_id = user_token.user_id
try:
prompt_manage_service.create_prompt(request)
return Result.succ([])
Expand All @@ -20,27 +22,29 @@ def prompt_add(request: PromptManageRequest):


@router.post("/prompt/list")
def prompt_list(request: PromptManageRequest):
def prompt_list(request: PromptManageRequest, user_token: UserRequest = Depends(get_user_from_headers)):
print(f"/prompt/list params: {request}")
try:
request.user_id = user_token.user_id
return Result.succ(prompt_manage_service.get_prompts(request))
except Exception as e:
return Result.faild(code="E010X", msg=f"prompt list error {e}")


@router.post("/prompt/update")
def prompt_update(request: PromptManageRequest):
def prompt_update(request: PromptManageRequest, user_token: UserRequest = Depends(get_user_from_headers)):
print(f"/prompt/update params: {request}")
request.user_id = user_token.user_id
try:
return Result.succ(prompt_manage_service.update_prompt(request))
except Exception as e:
return Result.faild(code="E010X", msg=f"prompt update error {e}")


@router.post("/prompt/delete")
def prompt_delete(request: PromptManageRequest):
def prompt_delete(request: PromptManageRequest, user_token: UserRequest = Depends(get_user_from_headers)):
print(f"/prompt/delete params: {request}")
try:
return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name))
return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name, user_token.user_id))
except Exception as e:
return Result.faild(code="E010X", msg=f"prompt delete error {e}")
8 changes: 7 additions & 1 deletion pilot/server/prompt/prompt_manage_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ class PromptManageEntity(Base):
user_name = Column(String(128))
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)
user_id = Column(String(100))

def __repr__(self):
return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}', user_id='{self.user_id}')"


class PromptManageDao(BaseDao):
Expand All @@ -48,6 +49,7 @@ def create_prompt(self, prompt: PromptManageRequest):
user_name=prompt.user_name,
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
user_id=prompt.user_id,
)
session.add(prompt_manage)
session.commit()
Expand All @@ -74,6 +76,10 @@ def get_prompts(self, query: PromptManageEntity):
prompts = prompts.filter(
PromptManageEntity.prompt_name == query.prompt_name
)
if query.user_id is not None:
prompts = prompts.filter(
PromptManageEntity.user_id == query.user_id
)

prompts = prompts.order_by(PromptManageEntity.gmt_created.desc())
result = prompts.all()
Expand Down
3 changes: 3 additions & 0 deletions pilot/server/prompt/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ class PromptManageRequest(BaseModel):

"""prompt_name: prompt name"""
prompt_name: str = None

"""用户id"""
user_id: str = None
8 changes: 5 additions & 3 deletions pilot/server/prompt/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self):
def create_prompt(self, request: PromptManageRequest):
query = PromptManageRequest(
prompt_name=request.prompt_name,
user_id=request.user_id,
)
prompt_name = prompt_manage_dao.get_prompts(query)
if len(prompt_name) > 0:
Expand All @@ -32,6 +33,7 @@ def get_prompts(self, request: PromptManageRequest):
prompt_type=request.prompt_type,
prompt_name=request.prompt_name,
user_name=request.user_name,
user_id=request.user_id
)
responses = []
prompts = prompt_manage_dao.get_prompts(query)
Expand All @@ -53,7 +55,7 @@ def get_prompts(self, request: PromptManageRequest):
"""update prompt"""

def update_prompt(self, request: PromptManageRequest):
query = PromptManageEntity(prompt_name=request.prompt_name)
query = PromptManageEntity(prompt_name=request.prompt_name, user_id=request.user_id)
prompts = prompt_manage_dao.get_prompts(query)
if len(prompts) != 1:
raise Exception(
Expand All @@ -70,8 +72,8 @@ def update_prompt(self, request: PromptManageRequest):

"""delete prompt"""

def delete_prompt(self, prompt_name: str):
query = PromptManageEntity(prompt_name=prompt_name)
def delete_prompt(self, prompt_name: str, user_id: str):
query = PromptManageEntity(prompt_name=prompt_name, user_id=user_id)
prompts = prompt_manage_dao.get_prompts(query)
if len(prompts) == 0:
raise Exception(f"delete error, no prompt name:{prompt_name} in database ")
Expand Down

0 comments on commit 60fa7bf

Please sign in to comment.