Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Variable extract - Global variable throughout the conversation #5087

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion agent/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self, dsl: str, tenant_id=None):
self.history = []
self.messages = []
self.answer = []
self.variables = {}
self.components = {}
self.dsl = json.loads(dsl) if dsl else {
"components": {
Expand All @@ -94,14 +95,16 @@ def __init__(self, dsl: str, tenant_id=None):
"messages": [],
"reference": [],
"path": [],
"answer": []
"answer": [],
"variables":{}
}
self._tenant_id = tenant_id
self._embed_id = ""
self.load()

def load(self):
self.components = self.dsl["components"]

cpn_nms = set([])
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
Expand All @@ -126,6 +129,8 @@ def load(self):
self.answer = self.dsl["answer"]
self.reference = self.dsl["reference"]
self._embed_id = self.dsl.get("embed_id", "")
self.variables = self.dsl.get("variables", {})


def __str__(self):
self.dsl["path"] = self.path
Expand All @@ -134,6 +139,7 @@ def __str__(self):
self.dsl["answer"] = self.answer
self.dsl["reference"] = self.reference
self.dsl["embed_id"] = self._embed_id
self.dsl["variables"] = self.variables
dsl = {
"components": {}
}
Expand All @@ -157,10 +163,12 @@ def reset(self):
self.history = []
self.messages = []
self.answer = []
self.variables = {}
self.reference = []
for k, cpn in self.components.items():
self.components[k]["obj"].reset()
self._embed_id = ""
self.variables = {}

def get_compnent_name(self, cid):
for n in self.dsl["graph"]["nodes"]:
Expand Down Expand Up @@ -309,6 +317,12 @@ def get_history(self, window_size):
else:
convs.append({"role": role, "content": str(obj)})
return convs
def update_variables(self, variables):
for key, value in variables.items():
if not self.variables.get(key):
self.variables[key] = ""
if value:
self.variables[key] = value

def add_user_input(self, question):
self.history.append(("user", question))
Expand Down Expand Up @@ -351,6 +365,9 @@ def _find_loop(self, max_loops=6):

def get_prologue(self):
return self.components["begin"]["obj"]._param.prologue

def get_variables(self):
return self.variables

def set_global_param(self, **kwargs):
for k, v in kwargs.items():
Expand Down
5 changes: 4 additions & 1 deletion agent/component/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .relevant import Relevant, RelevantParam
from .message import Message, MessageParam
from .rewrite import RewriteQuestion, RewriteQuestionParam
from .variable import VariableExtract, VariableExtractParam
from .keyword import KeywordExtract, KeywordExtractParam
from .concentrator import Concentrator, ConcentratorParam
from .baidu import Baidu, BaiduParam
Expand Down Expand Up @@ -76,7 +77,9 @@ def component_class(class_name):
"Message",
"MessageParam",
"RewriteQuestion",
"RewriteQuestionParam",
"RewriteQuestionParam",
"VariableExtract",
"VariableExtractParam",
"KeywordExtract",
"KeywordExtractParam",
"Concentrator",
Expand Down
21 changes: 17 additions & 4 deletions agent/component/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,18 @@ def check_defined_type(param, descr, types):
raise ValueError(
descr + " {} not supported, should be one of {}".format(param, types)
)

@staticmethod
def check_json(param, descr=""):
if type(param).__name__ != "str":
raise ValueError(
descr + " {} not supported, should be string type".format(param)
)
try:
json.loads(param)
except json.JSONDecodeError:
raise ValueError(
descr + " {} not supported, should be json string".format(param)
)
@staticmethod
def check_and_change_lower(param, valid_list, descr=""):
if type(param).__name__ != "str":
Expand Down Expand Up @@ -467,6 +478,7 @@ def get_input(self):
if self._param.query:
self._param.inputs = []
outs = []
vars =self._canvas.get_variables()
for q in self._param.query:
if q.get("component_id"):
if q["component_id"].split("@")[0].lower().find("begin") >= 0:
Expand All @@ -480,7 +492,6 @@ def get_input(self):
else:
assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue

if q["component_id"].lower().find("answer") == 0:
txt = []
for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]:
Expand All @@ -489,8 +500,10 @@ def get_input(self):
self._param.inputs.append({"content": txt, "component_id": q["component_id"]})
outs.append(pd.DataFrame([{"content": txt}]))
continue

outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
if q["component_id"] in vars.keys():
outs.append(pd.DataFrame([{"content": vars[q["component_id"]]}]))
else:
outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
self._param.inputs.append({"component_id": q["component_id"],
"content": "\n".join(
[str(d["content"]) for d in outs[-1].to_dict('records')])})
Expand Down
5 changes: 5 additions & 0 deletions agent/component/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ def _run(self, history, **kwargs):
kwargs[para["key"]] = " - " + "\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
self._param.inputs.append({"component_id": para["key"], "content": kwargs[para["key"]]})

# Replace variables in the prompt
for var_key, var_value in self._canvas.get_variables().items():
if var_value:
prompt = prompt.replace(f"{{{var_key}}}", str(var_value))

if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
else:
Expand Down
10 changes: 7 additions & 3 deletions agent/component/invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@
from deepdoc.parser import HtmlParser
from agent.component.base import ComponentBase, ComponentParamBase


class InvokeParam(ComponentParamBase):
"""
Define the Crawler component parameters.
Define the Invoke component parameters.
"""

def __init__(self):
Expand All @@ -37,7 +36,7 @@ def __init__(self):
self.clean_html = False

def check(self):
self.check_valid_value(self.method.lower(), "Type of content from the crawler", ['get', 'post', 'put'])
self.check_valid_value(self.method.lower(), "Type of content from the invoke", ['get', 'post', 'put'])
self.check_empty(self.url, "End point URL")
self.check_positive_integer(self.timeout, "Timeout time in second")
self.check_boolean(self.clean_html, "Clean HTML")
Expand All @@ -48,6 +47,8 @@ class Invoke(ComponentBase, ABC):

def _run(self, history, **kwargs):
args = {}
vars =self._canvas.get_variables()

for para in self._param.variables:
if para.get("component_id"):
if '@' in para["component_id"]:
Expand All @@ -58,6 +59,8 @@ def _run(self, history, **kwargs):
if param["key"] == field:
if "value" in param:
args[para["key"]] = param["value"]
elif para.get("component_id") in vars.keys():
args[para["key"]] = vars[para.get("component_id")]
else:
cpn = self._canvas.get_component(para["component_id"])["obj"]
if cpn.component_name.lower() == "answer":
Expand All @@ -66,6 +69,7 @@ def _run(self, history, **kwargs):
_, out = cpn.output(allow_partial=False)
if not out.empty:
args[para["key"]] = "\n".join(out["content"])

else:
args[para["key"]] = para["value"]

Expand Down
4 changes: 3 additions & 1 deletion agent/component/switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from abc import ABC
from agent.component.base import ComponentBase, ComponentParamBase


class SwitchParam(ComponentParamBase):
"""
Define the Switch component parameters.
Expand Down Expand Up @@ -67,13 +66,16 @@ def _run(self, history, **kwargs):
for item in cond["items"]:
if not item["cpn_id"]:
continue
vars =self._canvas.get_variables()
cid = item["cpn_id"].split("@")[0]
if item["cpn_id"].find("@") > 0:
cpn_id, key = item["cpn_id"].split("@")
for p in self._canvas.get_component(cid)["obj"]._param.query:
if p["key"] == key:
res.append(self.process_operator(p.get("value",""), item["operator"], item.get("value", "")))
break
elif item["cpn_id"] in vars.keys():
res.append(self.process_operator(vars[item["cpn_id"]], item["operator"], item.get("value", "")))
else:
out = self._canvas.get_component(cid)["obj"].output()[1]
cpn_input = "" if "content" not in out.columns else " ".join([str(s) for s in out["content"]])
Expand Down
7 changes: 4 additions & 3 deletions agent/component/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from agent.component.base import ComponentBase, ComponentParamBase
from jinja2 import Template as Jinja2Template


class TemplateParam(ComponentParamBase):
"""
Define the Generate component parameters.
Expand Down Expand Up @@ -92,15 +91,17 @@ def _run(self, history, **kwargs):
continue

_, out = cpn.output(allow_partial=False)

result = ""
if "content" in out.columns:
result = "\n".join(
[o if isinstance(o, str) else str(o) for o in out["content"]]
)

self.make_kwargs(para, kwargs, result)

# Replace variables in the content
for var_key, var_value in self._canvas.get_variables().items():
if var_value:
content = content.replace(f"{{{var_key}}}", str(var_value))
template = Jinja2Template(content)

try:
Expand Down
125 changes: 125 additions & 0 deletions agent/component/variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
from abc import ABC
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from agent.component import GenerateParam, Generate
import json


class VariableExtractParam(GenerateParam):
"""
Define the VariableExtract component parameters.
"""

def __init__(self):
super().__init__()
self.temperature = 0.9
self.variables = {}
self.prompt = ""


def check(self):
super().check()
self.check_json(self.variables , "JSON format")

def get_prompt(self, conv, variables):
prompt = f"""
You are a data expert extracting information. DON'T generate anything except the information extracted by the template.
######################################
Example
######################################
# Example 1
REQUEST: Get 'UserName', 'Address' from the conversation.
## Conversation
-ASSISTANT: What is your name?
-USER: My name is Jennifer, I live in Washington.

## Output template:
```json
{{
"UserName":"Jennifer",
"Address":"Washington"
}}
```
###########
# Example 2
REQUEST: Get 'UserCode', 'Department' from the conversation.
## Conversation
-USER: My employee code is 39211.
-ASSISTANT: What department are you in?
-USER: I am in HR department.
## Output template:
```json
{{
"UserCode":"39211",
"Department":"HR"
}}
```
###################
# Real Data
REQUEST: Get '{", ".join(variables.keys())}' from the conversation.

## Conversation
{conv}
######################################
"""
logging.info(prompt)
return prompt


class VariableExtract(Generate, ABC):
component_name = "VariableExtract"

def _run(self, history, **kwargs):

variables = {}
if self._param.variables:
variables = json.loads(self._param.variables)
self._canvas.update_variables(variables)

hist = self._canvas.get_history(self._param.message_history_window_size)
conv = []
for m in hist:
if m["role"] not in ["user"]:
continue
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
conv = "\n".join(conv)
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
ans = chat_mdl.chat(self._param.get_prompt(conv, variables),
[{"role": "user", "content": "Output template:"}], self._param.gen_conf())
match = re.search(r"```json\s*(.*?)\s*```", ans, re.DOTALL)
if match:
ans = match.group(1)
logging.debug(ans)
if not ans:
logging.debug(ans)
return VariableExtract.be_output("JSON not found!")


logging.info(f"ans: {ans}")
try:
ans_json = json.loads(ans)
self._canvas.update_variables(ans_json)
return VariableExtract.be_output(ans)
except json.JSONDecodeError:
logging.warning(f"VariableExtract: LLM returned non-JSON output: {ans}")
return VariableExtract.be_output("non-JSON")

def debug(self, **kwargs):
return self._run([], **kwargs)
4 changes: 3 additions & 1 deletion api/apps/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ def set_conversation():
"dialog_id": cvs.id,
"user_id": request.args.get("user_id", ""),
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent"
"source": "agent",
"variables":canvas.get_variables(),

}
API4ConversationService.save(**conv)
return get_json_result(data=conv)
Expand Down
Loading