Commit
modify config files, schema based on new node repo
Showing 8 changed files with 139 additions and 124 deletions.
@@ -0,0 +1,2 @@
+OPENAI_API_KEY=
+NODE_URL=http://localhost:7001
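For local runs these values are loaded from the environment at startup. A minimal sketch of how they can be consumed, assuming python-dotenv (the NODE_URL fallback below is illustrative, matching the example value):

import os
from dotenv import load_dotenv

load_dotenv()  # reads OPENAI_API_KEY and NODE_URL from the .env file
api_key = os.getenv("OPENAI_API_KEY")
node_url = os.getenv("NODE_URL", "http://localhost:7001")  # illustrative default, same as the example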
This file was deleted.
babyagi_task_finalizer/configs/agent_deployments.json
@@ -0,0 +1,17 @@
+[
+    {
+        "name": "agent_deployment_1",
+        "module": {"name": "babyagi_task_finalizer"},
+        "worker_node_url": "http://localhost:7001",
+        "agent_config": {
+            "config_name": "TaskFinalizerAgentConfig",
+            "llm_config": {"config_name": "model_2"},
+            "persona_module": null,
+            "system_prompt": {
+                "role": "You are a helpful AI assistant.",
+                "persona": ""
+            },
+            "user_message_template": "You are given the following task: {{task}}. The task is to accomplish the following objective: {{objective}}."
+        }
+    }
+]
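The {{task}} and {{objective}} placeholders in user_message_template are filled by plain string replacement, mirroring execute_task in the run.py diff further down. An illustrative substitution (the task text here is made up):

template = "You are given the following task: {{task}}. The task is to accomplish the following objective: {{objective}}."
user_prompt = template.replace(
    "{{task}}", "Summarize the collected weather data"  # hypothetical task
).replace(
    "{{objective}}", "Write a blog post about the weather in London."
)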
@@ -0,0 +1,18 @@
+[
+    {
+        "config_name": "model_1",
+        "client": "ollama",
+        "model": "ollama/phi",
+        "temperature": 0.7,
+        "max_tokens": 1000,
+        "api_base": "http://localhost:11434"
+    },
+    {
+        "config_name": "model_2",
+        "client": "openai",
+        "model": "gpt-4o-mini",
+        "temperature": 0.7,
+        "max_tokens": 1000,
+        "api_base": "https://api.openai.com/v1"
+    }
+]
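The deployment above references this table through "llm_config": {"config_name": "model_2"}. The actual resolution happens inside naptha_sdk; a minimal sketch of the lookup, assuming a hypothetical configs/llm_configs.json path alongside agent_deployments.json:

import json

# hypothetical path, inferred from the configs directory used in run.py
with open("babyagi_task_finalizer/configs/llm_configs.json") as f:
    llm_configs = {c["config_name"]: c for c in json.load(f)}

llm_config = llm_configs["model_2"]  # resolves to the gpt-4o-mini entry for agent_deployment_1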
@@ -1,73 +1,97 @@
 #!/usr/bin/env python
 from dotenv import load_dotenv
+from babyagi_task_finalizer.schemas import InputSchema, TaskExecutorPromptSchema, TaskFinalizerAgentConfig, TaskFinalizer
 import json
+from litellm import completion
 import os
-import yaml
-import instructor
-from litellm import Router
-from babyagi_task_finalizer.schemas import InputSchema, TaskFinalizer
-from babyagi_task_finalizer.utils import get_logger
+
+from naptha_sdk.schemas import AgentDeployment, AgentRunInput
+from naptha_sdk.utils import get_logger
 
 load_dotenv()
 logger = get_logger(__name__)
 
-client = instructor.patch(
-    Router(
-        model_list=
-        [
-            {
-                "model_name": "gpt-3.5-turbo",
-                "litellm_params": {
-                    "model": "openai/gpt-3.5-turbo",
-                    "api_key": os.getenv("OPENAI_API_KEY"),
-                },
-            }
-        ],
-        # default_litellm_params={"acompletion": True},
-    )
-)
-
-def llm_call(messages, response_model=None):
-    if response_model:
-        response = client.chat.completions.create(
-            model="gpt-3.5-turbo",
-            response_model=response_model,
+class TaskFinalizerAgent:
+    def __init__(self, agent_deployment: AgentDeployment):
+        self.agent_deployment = agent_deployment
+
+    def execute_task(self, inputs: InputSchema):
+        if isinstance(self.agent_deployment.agent_config, dict):
+            self.agent_deployment.agent_config = TaskFinalizerAgentConfig(**self.agent_deployment.agent_config)
+
+        user_prompt = self.agent_deployment.agent_config.user_message_template.replace(
+            "{{task}}", inputs.tool_input_data.task
+        ).replace(
+            "{{objective}}", inputs.tool_input_data.objective
+        )
+
+        messages = [
+            {"role": "system", "content": json.dumps(self.agent_deployment.agent_config.system_prompt)},
+            {"role": "user", "content": user_prompt}
+        ]
+
+        api_key = None if self.agent_deployment.agent_config.llm_config.client == "ollama" else (
+            "EMPTY" if self.agent_deployment.agent_config.llm_config.client == "vllm" else os.getenv("OPENAI_API_KEY")
+        )
+
+        response = completion(
+            model=self.agent_deployment.agent_config.llm_config.model,
             messages=messages,
-            temperature=0.0,
-            max_tokens=1000,
+            temperature=self.agent_deployment.agent_config.llm_config.temperature,
+            max_tokens=self.agent_deployment.agent_config.llm_config.max_tokens,
+            api_base=self.agent_deployment.agent_config.llm_config.api_base,
+            api_key=api_key
         )
-        return response
 
-def run(inputs: InputSchema, *args, **kwargs):
-    logger.info(f"Running with inputs {inputs.objective}")
-    logger.info(f"Running with inputs {inputs.task}")
-    cfg = kwargs["cfg"]
-
-    user_prompt = cfg["inputs"]["user_message_template"].replace("{{task}}", inputs.task).replace("{{objective}}", inputs.objective)
-
-    messages = [
-        {"role": "system", "content": cfg["inputs"]["system_message"]},
-        {"role": "user", "content": user_prompt}
-    ]
-
-    response = llm_call(messages, response_model=TaskFinalizer)
-
-    logger.info(f"Result: {response}")
-
-    return response.model_dump_json()
-
-
+        # Parse the response into the TaskFinalizer model
+        response_content = response.choices[0].message.content
+
+        try:
+            # Attempt to parse the response as JSON
+            parsed_response = json.loads(response_content)
+            task_finalizer = TaskFinalizer(**parsed_response)
+        except (json.JSONDecodeError, TypeError):
+            # If parsing fails, create a TaskFinalizer with the raw content
+            task_finalizer = TaskFinalizer(
+                final_report=response_content,
+                new_tasks=[],
+                objective_met=False
+            )
+
+        logger.info(f"Response: {task_finalizer}")
+        return task_finalizer.model_dump_json()
+
+def run(agent_run: AgentRunInput, *args, **kwargs):
+    logger.info(f"Running with inputs {agent_run.inputs.tool_input_data}")
+    task_finalizer_agent = TaskFinalizerAgent(agent_run.agent_deployment)
+    method = getattr(task_finalizer_agent, agent_run.inputs.tool_name, None)
+    return method(agent_run.inputs)
+
 if __name__ == "__main__":
-    with open("babyagi_task_finalizer/component.yaml", "r") as f:
-        cfg = yaml.safe_load(f)
-
-    inputs = InputSchema(
-        task="Weather pattern between year 1900 and 2000 was cool andry",
-        objective="Write a blog post about the weather in London."
-    )
-
-    r = run(inputs, cfg=cfg)
-    logger.info(f"Result: {type(r)}")
-
-    import json
-    t = TaskFinalizer(**json.loads(r))
-    logger.info(f"Final report: {type(t)}")
+    from naptha_sdk.client.naptha import Naptha
+    from naptha_sdk.configs import load_agent_deployments
+
+    naptha = Naptha()
+
+    # Configs
+    agent_deployments = load_agent_deployments(
+        "babyagi_task_finalizer/configs/agent_deployments.json",
+        load_persona_data=False,
+        load_persona_schema=False
+    )
+
+    input_params = InputSchema(
+        tool_name="execute_task",
+        tool_input_data=TaskExecutorPromptSchema(
+            task="Weather pattern between year 1900 and 2000",
+            objective="Write a blog post about the weather in London."
+        ),
+    )
+
+    agent_run = AgentRunInput(
+        inputs=input_params,
+        agent_deployment=agent_deployments[0],
+        consumer_id=naptha.user.id,
+    )
+
+    response = run(agent_run)
+    print(response)
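The try/except in execute_task guards against models that return free text rather than JSON: parseable output populates TaskFinalizer directly, anything else becomes the final report verbatim. A standalone illustration of the same fallback (field names from the diff, sample payload made up):

import json
from babyagi_task_finalizer.schemas import TaskFinalizer

raw = '{"final_report": "Report drafted.", "new_tasks": [], "objective_met": true}'  # sample model output
try:
    task_finalizer = TaskFinalizer(**json.loads(raw))
except (json.JSONDecodeError, TypeError):
    # non-JSON output becomes the final report verbatim
    task_finalizer = TaskFinalizer(final_report=raw, new_tasks=[], objective_met=False)
print(task_finalizer.model_dump_json())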
This file was deleted.