Skip to content

Commit

Permalink
Merge
Browse files Browse the repository at this point in the history
Signed-off-by: jinhai <[email protected]>
  • Loading branch information
JinHai-CN committed Nov 15, 2024
1 parent cb3b9d7 commit 33c4e09
Show file tree
Hide file tree
Showing 33 changed files with 454 additions and 413 deletions.
22 changes: 13 additions & 9 deletions agent/component/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from api.db import LLMType
from api.db.services.dialog_service import message_fit_in
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from api import settings
from agent.component.base import ComponentBase, ComponentParamBase


Expand Down Expand Up @@ -63,18 +63,20 @@ class Generate(ComponentBase):
component_name = "Generate"

def get_dependent_components(self):
cpnts = [para["component_id"] for para in self._param.parameters if para.get("component_id") and para["component_id"].lower().find("answer") < 0]
cpnts = [para["component_id"] for para in self._param.parameters if
para.get("component_id") and para["component_id"].lower().find("answer") < 0]
return cpnts

def set_cite(self, retrieval_res, answer):
retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
if "empty_response" in retrieval_res.columns:
retrieval_res["empty_response"].fillna("", inplace=True)
answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
[ck["vector"] for _, ck in retrieval_res.iterrows()],
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
self._canvas.get_embedding_model()), tkweight=0.7,
vtweight=0.3)
answer, idx = settings.retrievaler.insert_citations(answer,
[ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
[ck["vector"] for _, ck in retrieval_res.iterrows()],
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
self._canvas.get_embedding_model()), tkweight=0.7,
vtweight=0.3)
doc_ids = set([])
recall_docs = []
for i in idx:
Expand Down Expand Up @@ -127,12 +129,14 @@ def _run(self, history, **kwargs):
else:
if cpn.component_name.lower() == "retrieval":
retrieval_res.append(out)
kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
kwargs[para["key"]] = " - " + "\n - ".join(
[o if isinstance(o, str) else str(o) for o in out["content"]])
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})

if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
else: retrieval_res = pd.DataFrame([])
else:
retrieval_res = pd.DataFrame([])

for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
Expand Down
4 changes: 2 additions & 2 deletions agent/component/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from api import settings
from agent.component.base import ComponentBase, ComponentParamBase


Expand Down Expand Up @@ -67,7 +67,7 @@ def _run(self, history, **kwargs):
if self._param.rerank_id:
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)

kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
1, self._param.top_n,
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
aggs=False, rerank_mdl=rerank_mdl)
Expand Down
10 changes: 4 additions & 6 deletions api/apps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@

from flask_session import Session
from flask_login import LoginManager
from api.settings import SECRET_KEY
from api.settings import API_VERSION
from api import settings
from api.utils.api_utils import server_error_response
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer

Expand Down Expand Up @@ -78,7 +77,6 @@
app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response)


## convince for dev and debug
# app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False
Expand Down Expand Up @@ -110,7 +108,7 @@ def register_page(page_path):

page_name = page_path.stem.rstrip("_app")
module_name = ".".join(
page_path.parts[page_path.parts.index("api") : -1] + (page_name,)
page_path.parts[page_path.parts.index("api"): -1] + (page_name,)
)

spec = spec_from_file_location(module_name, page_path)
Expand All @@ -121,7 +119,7 @@ def register_page(page_path):
spec.loader.exec_module(page)
page_name = getattr(page, "page_name", page_name)
url_prefix = (
f"/api/{API_VERSION}" if "/sdk/" in path else f"/{API_VERSION}/{page_name}"
f"/api/{settings.API_VERSION}" if "/sdk/" in path else f"/{settings.API_VERSION}/{page_name}"
)

app.register_blueprint(page.manager, url_prefix=url_prefix)
Expand All @@ -141,7 +139,7 @@ def register_page(page_path):

@login_manager.request_loader
def load_user(web_request):
jwt = Serializer(secret_key=SECRET_KEY)
jwt = Serializer(secret_key=settings.SECRET_KEY)
authorization = web_request.headers.get("Authorization")
if authorization:
try:
Expand Down
64 changes: 33 additions & 31 deletions api/apps/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import queue_tasks, TaskService
from api.db.services.user_service import UserTenantService
from api.settings import RetCode, retrievaler
from api import settings
from api.utils import get_uuid, current_timestamp, datetime_format
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
generate_confirmation_token
Expand Down Expand Up @@ -141,7 +141,7 @@ def set_conversation():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
try:
if objs[0].source == "agent":
Expand Down Expand Up @@ -183,7 +183,7 @@ def completion():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
Expand Down Expand Up @@ -290,8 +290,8 @@ def sse():
API4ConversationService.append_message(conv.id, conv.to_dict())
rename_field(result)
return get_json_result(data=result)
#******************For dialog******************

# ******************For dialog******************
conv.message.append(msg[-1])
e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
Expand Down Expand Up @@ -326,7 +326,7 @@ def stream():
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp

answer = None
for ans in chat(dia, msg, **req):
answer = ans
Expand All @@ -347,8 +347,8 @@ def get(conversation_id):
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

try:
e, conv = API4ConversationService.get_by_id(conversation_id)
if not e:
Expand All @@ -357,8 +357,8 @@ def get(conversation_id):
conv = conv.to_dict()
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
code=RetCode.AUTHENTICATION_ERROR)
code=settings.RetCode.AUTHENTICATION_ERROR)

for referenct_i in conv['reference']:
if referenct_i is None or len(referenct_i) == 0:
continue
Expand All @@ -378,7 +378,7 @@ def upload():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

kb_name = request.form.get("kb_name").strip()
tenant_id = objs[0].tenant_id
Expand All @@ -394,12 +394,12 @@ def upload():

if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)

file = request.files['file']
if file.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)

root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
Expand Down Expand Up @@ -490,17 +490,17 @@ def upload_parse():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)

file_objs = request.files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)

doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
return get_json_result(data=doc_ids)
Expand All @@ -513,7 +513,7 @@ def list_chunks():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

req = request.json

Expand All @@ -531,7 +531,7 @@ def list_chunks():
)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)

res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
res = [
{
"content": res_item["content_with_weight"],
Expand All @@ -553,7 +553,7 @@ def list_kb_docs():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

req = request.json
tenant_id = objs[0].tenant_id
Expand Down Expand Up @@ -585,14 +585,15 @@ def list_kb_docs():
except Exception as e:
return server_error_response(e)


@manager.route('/document/infos', methods=['POST'])
@validate_request("doc_ids")
def docinfos():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
doc_ids = req["doc_ids"]
docs = DocumentService.get_by_ids(doc_ids)
Expand All @@ -606,7 +607,7 @@ def document_rm():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

tenant_id = objs[0].tenant_id
req = request.json
Expand Down Expand Up @@ -653,7 +654,7 @@ def document_rm():
errors += str(e)

if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)

return get_json_result(data=True)

Expand All @@ -668,7 +669,7 @@ def completion_faq():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
Expand Down Expand Up @@ -805,10 +806,10 @@ def retrieval():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

req = request.json
kb_ids = req.get("kb_id",[])
kb_ids = req.get("kb_id", [])
doc_ids = req.get("doc_ids", [])
question = req.get("question")
page = int(req.get("page", 1))
Expand All @@ -822,26 +823,27 @@ def retrieval():
embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1:
return get_json_result(
data=False, message='Knowledge bases use different embedding models or does not exist."', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Knowledge bases use different embedding models or does not exist."',
code=settings.RetCode.AUTHENTICATION_ERROR)

embd_mdl = TenantLLMService.model_instance(
kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
rerank_mdl = None
if req.get("rerank_id"):
rerank_mdl = TenantLLMService.model_instance(
kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
if req.get("keyword", False):
chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl)
ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl)
for c in ranks["chunks"]:
if "vector" in c:
del c["vector"]
return get_json_result(data=ranks)
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
code=RetCode.DATA_ERROR)
code=settings.RetCode.DATA_ERROR)
return server_error_response(e)
Loading

0 comments on commit 33c4e09

Please sign in to comment.