diff --git a/examples/dynamic_task_ordering_example.py b/examples/dynamic_task_ordering_example.py new file mode 100644 index 0000000000..ebff99885c --- /dev/null +++ b/examples/dynamic_task_ordering_example.py @@ -0,0 +1,113 @@ +""" +Example demonstrating dynamic task ordering in CrewAI. + +This example shows how to use the task_ordering_callback to dynamically +determine the execution order of tasks based on runtime conditions. +""" + +from crewai import Agent, Crew, Task +from crewai.process import Process + + +def priority_based_ordering(all_tasks, completed_outputs, current_index): + """ + Order tasks by priority (lower number = higher priority). + + Args: + all_tasks: List of all tasks in the crew + completed_outputs: List of TaskOutput objects for completed tasks + current_index: Current task index (for default ordering) + + Returns: + int: Index of next task to execute + Task: Task object to execute next + None: Use default ordering + """ + completed_tasks = {id(task) for task in all_tasks if task.output is not None} + + remaining_tasks = [ + (i, task) for i, task in enumerate(all_tasks) + if id(task) not in completed_tasks + ] + + if not remaining_tasks: + return None + + remaining_tasks.sort(key=lambda x: getattr(x[1], 'priority', 999)) + + return remaining_tasks[0][0] + + +def conditional_ordering(all_tasks, completed_outputs, current_index): + """ + Order tasks based on previous task outputs. + + This example shows how to make task ordering decisions based on + the results of previously completed tasks. + """ + if len(completed_outputs) == 0: + return 0 + + last_output = completed_outputs[-1] + + if "urgent" in last_output.raw.lower(): + completed_tasks = {id(task) for task in all_tasks if task.output is not None} + for i, task in enumerate(all_tasks): + if (hasattr(task, 'priority') and task.priority == 1 and + id(task) not in completed_tasks): + return i + + return None + + +researcher = Agent( + role="Research Analyst", + goal="Gather and analyze information", + backstory="Expert at finding and synthesizing information" +) + +writer = Agent( + role="Content Writer", + goal="Create compelling content", + backstory="Skilled at crafting engaging narratives" +) + +reviewer = Agent( + role="Quality Reviewer", + goal="Ensure content quality", + backstory="Meticulous attention to detail" +) + +research_task = Task( + description="Research the latest trends in AI", + expected_output="Comprehensive research report", + agent=researcher +) +research_task.priority = 2 + +urgent_task = Task( + description="Write urgent press release", + expected_output="Press release draft", + agent=writer +) +urgent_task.priority = 1 + +review_task = Task( + description="Review and edit content", + expected_output="Polished final content", + agent=reviewer +) +review_task.priority = 3 + +crew = Crew( + agents=[researcher, writer, reviewer], + tasks=[research_task, urgent_task, review_task], + process=Process.sequential, + task_ordering_callback=priority_based_ordering, + verbose=True +) + +if __name__ == "__main__": + print("Starting crew with dynamic task ordering...") + result = crew.kickoff() + print(f"Completed {len(result.tasks_output)} tasks") diff --git a/src/crewai/crew.py b/src/crewai/crew.py index ed9479bdc9..7664349ef6 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -1,4 +1,5 @@ import asyncio +import inspect import json import re import uuid @@ -113,6 +114,9 @@ class Crew(FlowTrackable, BaseModel): execution. step_callback: Callback to be executed after each step for every agents execution. + task_ordering_callback: Callback to determine the next task to execute + dynamically. Receives (all_tasks, completed_outputs, current_index) + and returns next task index, Task object, or None for default ordering. share_crew: Whether you want to share the complete crew information and execution with crewAI to make the library better, and allow us to train models. @@ -213,6 +217,12 @@ class Crew(FlowTrackable, BaseModel): "It may be used to adjust the output of the crew." ), ) + task_ordering_callback: Callable[ + [list[Task], list[TaskOutput], int], int | Task | None + ] | None = Field( + default=None, + description="Callback to determine the next task to execute. Receives (all_tasks, completed_outputs, current_index) and returns next task index, Task object, or None for default ordering.", + ) max_rpm: int | None = Field( default=None, description=( @@ -535,6 +545,25 @@ def validate_context_no_future_tasks(self): ) return self + @model_validator(mode="after") + def validate_task_ordering_callback(self): + """Validates that the task ordering callback has the correct signature.""" + if self.task_ordering_callback is not None: + if not callable(self.task_ordering_callback): + raise ValueError("task_ordering_callback must be callable") + + try: + sig = inspect.signature(self.task_ordering_callback) + except (ValueError, TypeError): + pass + else: + if len(sig.parameters) != 3: + raise ValueError( + "task_ordering_callback must accept exactly 3 parameters: (tasks, outputs, current_index)" + ) + + return self + @property def key(self) -> str: source: list[str] = [agent.key for agent in self.agents] + [ @@ -847,12 +876,12 @@ def _execute_tasks( start_index: int | None = 0, was_replayed: bool = False, ) -> CrewOutput: - """Executes tasks sequentially and returns the final output. + """Executes tasks with optional dynamic ordering and returns the final output. Args: tasks (List[Task]): List of tasks to execute - manager (Optional[BaseAgent], optional): Manager agent to use for - delegation. Defaults to None. + start_index (int | None): Starting index for task execution + was_replayed (bool): Whether this is a replay execution Returns: CrewOutput: Final output of the crew @@ -861,7 +890,8 @@ def _execute_tasks( task_outputs: list[TaskOutput] = [] futures: list[tuple[Task, Future[TaskOutput], int]] = [] last_sync_output: TaskOutput | None = None - + executed_task_indices: set[int] = set() + for task_index, task in enumerate(tasks): if start_index is not None and task_index < start_index: if task.output: @@ -870,7 +900,66 @@ def _execute_tasks( else: task_outputs = [task.output] last_sync_output = task.output - continue + executed_task_indices.add(task_index) + + while len(executed_task_indices) < len(tasks): + # Find next task to execute + if self.task_ordering_callback: + try: + next_task_result = self.task_ordering_callback( + tasks, task_outputs, len(executed_task_indices) + ) + + if next_task_result is None: + task_index = next(i for i in range(len(tasks)) if i not in executed_task_indices) + elif isinstance(next_task_result, int): + if 0 <= next_task_result < len(tasks) and next_task_result not in executed_task_indices: + task_index = next_task_result + else: + self._logger.log( + "warning", + f"Invalid or already executed task index {next_task_result} from ordering callback, using default", + color="yellow" + ) + task_index = next(i for i in range(len(tasks)) if i not in executed_task_indices) + elif isinstance(next_task_result, Task): + try: + candidate_index = tasks.index(next_task_result) + if candidate_index not in executed_task_indices: + task_index = candidate_index + else: + self._logger.log( + "warning", + "Task from ordering callback already executed, using default", + color="yellow" + ) + task_index = next(i for i in range(len(tasks)) if i not in executed_task_indices) + except ValueError: + self._logger.log( + "warning", + "Task from ordering callback not found in tasks list, using default", + color="yellow" + ) + task_index = next(i for i in range(len(tasks)) if i not in executed_task_indices) + else: + self._logger.log( + "warning", + f"Invalid return type from ordering callback: {type(next_task_result)}, using default", + color="yellow" + ) + task_index = next(i for i in range(len(tasks)) if i not in executed_task_indices) + except Exception as e: + self._logger.log( + "warning", + f"Error in task ordering callback: {e}, using default ordering", + color="yellow" + ) + task_index = next(i for i in range(len(tasks)) if i not in executed_task_indices) + else: + task_index = next(i for i in range(len(tasks)) if i not in executed_task_indices) + + task = tasks[task_index] + executed_task_indices.add(task_index) agent_to_use = self._get_agent_to_use(task) if agent_to_use is None: @@ -880,9 +969,7 @@ def _execute_tasks( f"or a manager agent is provided." ) - # Determine which tools to use - task tools take precedence over agent tools tools_for_task = task.tools or agent_to_use.tools or [] - # Prepare tools and ensure they're compatible with task execution tools_for_task = self._prepare_tools( agent_to_use, task, @@ -923,6 +1010,7 @@ def _execute_tasks( task_outputs.append(task_output) self._process_task_result(task, task_output) self._store_execution_log(task, task_output, task_index, was_replayed) + last_sync_output = task_output if futures: task_outputs = self._process_async_tasks(futures, was_replayed) diff --git a/tests/test_dynamic_task_ordering.py b/tests/test_dynamic_task_ordering.py new file mode 100644 index 0000000000..941b840eff --- /dev/null +++ b/tests/test_dynamic_task_ordering.py @@ -0,0 +1,302 @@ +import pytest +from unittest.mock import Mock + +from crewai import Agent, Crew, Task +from crewai.process import Process +from crewai.task import TaskOutput + + +@pytest.fixture +def agents(): + return [ + Agent(role="Agent 1", goal="Goal 1", backstory="Backstory 1"), + Agent(role="Agent 2", goal="Goal 2", backstory="Backstory 2"), + Agent(role="Agent 3", goal="Goal 3", backstory="Backstory 3"), + ] + + +@pytest.fixture +def tasks(agents): + return [ + Task(description="Task 1", expected_output="Output 1", agent=agents[0]), + Task(description="Task 2", expected_output="Output 2", agent=agents[1]), + Task(description="Task 3", expected_output="Output 3", agent=agents[2]), + ] + + +def test_sequential_process_with_reverse_ordering(agents, tasks): + """Test sequential process with reverse task ordering.""" + execution_order = [] + + def reverse_ordering_callback(all_tasks, completed_outputs, current_index): + completed_tasks = {id(task) for task in all_tasks if task.output is not None} + remaining_indices = [i for i in range(len(all_tasks)) + if id(all_tasks[i]) not in completed_tasks] + if remaining_indices: + next_index = max(remaining_indices) + execution_order.append(next_index) + return next_index + return None + + crew = Crew( + agents=agents, + tasks=tasks, + process=Process.sequential, + task_ordering_callback=reverse_ordering_callback, + verbose=False + ) + + result = crew.kickoff() + + assert len(result.tasks_output) == 3 + assert execution_order == [2, 1, 0] + + +def test_hierarchical_process_with_priority_ordering(agents, tasks): + """Test hierarchical process with priority-based task ordering.""" + + tasks[0].priority = 3 + tasks[1].priority = 1 + tasks[2].priority = 2 + + execution_order = [] + + def priority_ordering_callback(all_tasks, completed_outputs, current_index): + completed_tasks = {id(task) for task in all_tasks if task.output is not None} + remaining_tasks = [ + (i, task) for i, task in enumerate(all_tasks) + if id(task) not in completed_tasks + ] + + if remaining_tasks: + remaining_tasks.sort(key=lambda x: getattr(x[1], 'priority', 999)) + next_index = remaining_tasks[0][0] + execution_order.append(next_index) + return next_index + + return None + + crew = Crew( + agents=agents, + tasks=tasks, + process=Process.hierarchical, + manager_llm="gpt-4o", + task_ordering_callback=priority_ordering_callback, + verbose=False + ) + + result = crew.kickoff() + assert len(result.tasks_output) == 3 + assert execution_order == [1, 2, 0] + + +def test_task_ordering_callback_with_task_object_return(): + """Test callback returning Task object instead of index.""" + + agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")] + tasks = [ + Task(description="Task A", expected_output="Output A", agent=agents[0]), + Task(description="Task B", expected_output="Output B", agent=agents[0]), + ] + + execution_order = [] + + def task_object_callback(all_tasks, completed_outputs, current_index): + if len(completed_outputs) == 0: + execution_order.append(1) + return all_tasks[1] + elif len(completed_outputs) == 1: + execution_order.append(0) + return all_tasks[0] + return None + + crew = Crew( + agents=agents, + tasks=tasks, + process=Process.sequential, + task_ordering_callback=task_object_callback, + verbose=False + ) + + result = crew.kickoff() + assert len(result.tasks_output) == 2 + assert execution_order == [1, 0] + + +def test_invalid_task_ordering_callback_index(): + """Test handling of invalid task index from callback.""" + + agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")] + tasks = [Task(description="Task", expected_output="Output", agent=agents[0])] + + def invalid_callback(all_tasks, completed_outputs, current_index): + return 999 + + crew = Crew( + agents=agents, + tasks=tasks, + process=Process.sequential, + task_ordering_callback=invalid_callback, + verbose=False + ) + + result = crew.kickoff() + assert len(result.tasks_output) == 1 + + +def test_task_ordering_callback_exception_handling(): + """Test handling of exceptions in task ordering callback.""" + + agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")] + tasks = [Task(description="Task", expected_output="Output", agent=agents[0])] + + def failing_callback(all_tasks, completed_outputs, current_index): + raise ValueError("Callback error") + + crew = Crew( + agents=agents, + tasks=tasks, + process=Process.sequential, + task_ordering_callback=failing_callback, + verbose=False + ) + + result = crew.kickoff() + assert len(result.tasks_output) == 1 + + +def test_task_ordering_callback_validation(): + """Test validation of task ordering callback signature.""" + + agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")] + tasks = [Task(description="Task", expected_output="Output", agent=agents[0])] + + def invalid_signature_callback(only_one_param): + return 0 + + with pytest.raises(ValueError, match="task_ordering_callback must accept exactly 3 parameters"): + Crew( + agents=agents, + tasks=tasks, + process=Process.sequential, + task_ordering_callback=invalid_signature_callback + ) + + +def test_no_task_ordering_callback_default_behavior(): + """Test that default behavior is unchanged when no callback is provided.""" + + agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")] + tasks = [ + Task(description="Task 1", expected_output="Output 1", agent=agents[0]), + Task(description="Task 2", expected_output="Output 2", agent=agents[0]), + ] + + crew = Crew( + agents=agents, + tasks=tasks, + process=Process.sequential, + verbose=False + ) + + result = crew.kickoff() + assert len(result.tasks_output) == 2 + + +def test_task_ordering_callback_with_none_return(): + """Test callback returning None for default ordering.""" + + agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")] + tasks = [ + Task(description="Task 1", expected_output="Output 1", agent=agents[0]), + Task(description="Task 2", expected_output="Output 2", agent=agents[0]), + ] + + def none_callback(all_tasks, completed_outputs, current_index): + return None + + crew = Crew( + agents=agents, + tasks=tasks, + process=Process.sequential, + task_ordering_callback=none_callback, + verbose=False + ) + + result = crew.kickoff() + assert len(result.tasks_output) == 2 + + +def test_task_ordering_callback_invalid_task_object(): + """Test handling of invalid Task object from callback.""" + + agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")] + tasks = [Task(description="Task", expected_output="Output", agent=agents[0])] + + invalid_task = Task(description="Invalid", expected_output="Invalid", agent=agents[0]) + + def invalid_task_callback(all_tasks, completed_outputs, current_index): + return invalid_task + + crew = Crew( + agents=agents, + tasks=tasks, + process=Process.sequential, + task_ordering_callback=invalid_task_callback, + verbose=False + ) + + result = crew.kickoff() + assert len(result.tasks_output) == 1 + + +def test_task_ordering_callback_invalid_return_type(): + """Test handling of invalid return type from callback.""" + + agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")] + tasks = [Task(description="Task", expected_output="Output", agent=agents[0])] + + def invalid_type_callback(all_tasks, completed_outputs, current_index): + return "invalid" + + crew = Crew( + agents=agents, + tasks=tasks, + process=Process.sequential, + task_ordering_callback=invalid_type_callback, + verbose=False + ) + + result = crew.kickoff() + assert len(result.tasks_output) == 1 + + +def test_task_ordering_prevents_infinite_loops(): + """Test that task ordering prevents infinite loops by tracking executed tasks.""" + + agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")] + tasks = [ + Task(description="Task 1", expected_output="Output 1", agent=agents[0]), + Task(description="Task 2", expected_output="Output 2", agent=agents[0]), + ] + + call_count = 0 + + def loop_callback(all_tasks, completed_outputs, current_index): + nonlocal call_count + call_count += 1 + if call_count > 10: + pytest.fail("Callback called too many times, possible infinite loop") + return 0 + + crew = Crew( + agents=agents, + tasks=tasks, + process=Process.sequential, + task_ordering_callback=loop_callback, + verbose=False + ) + + result = crew.kickoff() + assert len(result.tasks_output) == 2 + assert call_count <= 4