Skip to content

Commit

Permalink
Use input_fields from processors when invoking as tool
Browse files Browse the repository at this point in the history
  • Loading branch information
ajhai committed Oct 23, 2024
1 parent f7a140d commit d6b5060
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 24 deletions.
1 change: 1 addition & 0 deletions llmstack/apps/runner/agent_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ async def _process_output(self):
tool_call_args = tool_call.arguments
try:
tool_call_args = json.loads(tool_call_args)
tool_call_args["_inputs0"] = self._messages["_inputs0"]
except Exception:
pass

Expand Down
2 changes: 2 additions & 0 deletions llmstack/apps/runner/app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ def _get_actor_configs_from_processors(
"session_id": session_id,
"request_user": self._source.request_user,
"app_uuid": self._source.id,
"input_fields": processor.get("input_fields", []),
"is_tool": is_agent,
"output_template": processor.get("output_template", {"markdown": ""}) if is_agent else None,
},
dependencies=processor.get(
Expand Down
9 changes: 3 additions & 6 deletions llmstack/play/output_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@ def stitch_model_objects(obj1: Any, obj2: Any) -> Any:
)
return stitched_obj

if isinstance(obj1, BaseModel):
obj1 = obj1.model_dump()
if isinstance(obj2, BaseModel):
obj2 = obj2.model_dump()

def stitch_fields(
obj1_fields: Dict[str, Any],
obj2_fields: Dict[str, Any],
Expand Down Expand Up @@ -191,7 +186,9 @@ def finalize(
"""
Closes the output stream and returns stitched data.
"""
output = self._data if not self._output_cls else self._output_cls(**self._data)
output = (
self._data if not self._output_cls or isinstance(self._data, BaseModel) else self._output_cls(**self._data)
)
self._data = None

# Send the end message
Expand Down
44 changes: 26 additions & 18 deletions llmstack/processors/providers/api_processor_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def __init__(
request_user=None,
app_uuid=None,
coordinator_urn=None,
input_fields=[],
dependencies=[],
output_template={},
id=None,
Expand All @@ -164,6 +165,7 @@ def __init__(
self._session_id = session_id
self._request_user = request_user
self._app_uuid = app_uuid
self._input_fields = input_fields
self._is_tool = is_tool
self._session_enabled = session_enabled
self._output_template = output_template
Expand Down Expand Up @@ -311,27 +313,33 @@ def get_bookkeeping_data(self) -> BookKeepingData:
return None

def input(self, message: Any) -> Any:
# Hydrate the input and config before processing
if self._is_tool:
# NO-OP when the processor is a tool
return
try:
self._input = (
hydrate_input(
self._input_template,
message,
if self._is_tool and len(self._input_fields) == 0:
self._input = self._get_input_class()(**message)
self._config = hydrate_input(self._config_template, message)
elif self._is_tool:
hydrated_input = hydrate_input(self._input_template, message)
self._input = self._get_input_class()(
**(hydrated_input.model_dump() if isinstance(hydrated_input, BaseModel) else hydrated_input)
)
if message
else self._input_template
)
self._config = (
hydrate_input(
self._config_template,
message,
self._config = hydrate_input(self._config_template, message)
else:
self._input = (
hydrate_input(
self._input_template,
message,
)
if message
else self._input_template
)
self._config = (
hydrate_input(
self._config_template,
message,
)
if self._config and message
else self._config_template
)
if self._config and message
else self._config_template
)
output = self.process()
except Exception as e:
logger.exception("Error processing input")
Expand Down

0 comments on commit d6b5060

Please sign in to comment.