diff --git a/libs/langgraph/langgraph/prebuilt/__init__.py b/libs/langgraph/langgraph/prebuilt/__init__.py index 671804258..50387838a 100644 --- a/libs/langgraph/langgraph/prebuilt/__init__.py +++ b/libs/langgraph/langgraph/prebuilt/__init__.py @@ -1,6 +1,7 @@ """langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools.""" from langgraph.prebuilt.chat_agent_executor import create_react_agent +from langgraph.prebuilt.pipeline import create_pipeline from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation from langgraph.prebuilt.tool_node import ( InjectedState, @@ -12,6 +13,7 @@ __all__ = [ "create_react_agent", + "create_pipeline", "ToolExecutor", "ToolInvocation", "ToolNode", diff --git a/libs/langgraph/langgraph/prebuilt/pipeline.py b/libs/langgraph/langgraph/prebuilt/pipeline.py new file mode 100644 index 000000000..f5a28dec3 --- /dev/null +++ b/libs/langgraph/langgraph/prebuilt/pipeline.py @@ -0,0 +1,65 @@ +from typing import Any, Optional, Sequence, Type, Union, cast + +from langchain_core.runnables.base import Runnable, RunnableLike + +from langgraph.graph.state import END, START, StateGraph + + +def _get_step_name(step: RunnableLike) -> str: + if isinstance(step, Runnable): + return step.get_name() + elif callable(step): + return getattr(step, "__name__", step.__class__.__name__) + else: + raise TypeError(f"Unsupported step type: {step}") + + +def create_pipeline( + steps: Sequence[Union[RunnableLike, tuple[str, RunnableLike]]], + *, + state_schema: Type[Any], + input_schema: Optional[Type[Any]] = None, + output_schema: Optional[Type[Any]] = None, +) -> StateGraph: + """Create a pipeline graph that runs a series of provided steps in order. + + Args: + steps: A sequence of RunnableLike objects (e.g. a LangChain Runnable or a callable) or (name, RunnableLike) tuples. + If no names are provided, the name will be inferred from the step object (e.g. a runnable or a callable name). + Each step will be executed in the order provided. + state_schema: The state schema for the graph. + input_schema: The input schema for the graph. + output_schema: The output schema for the graph. Will only be used when calling `graph.invoke()`. + + Returns: + A StateGraph object. + """ + if len(steps) < 1: + raise ValueError("Pipeline requires at least one step.") + + node_names = set() + builder = StateGraph(state_schema, input=input_schema, output=output_schema) + previous_name: Optional[str] = None + for step in steps: + if isinstance(step, tuple) and len(step) == 2: + name, step = step + else: + name = _get_step_name(step) + + if name in node_names: + raise ValueError( + f"Step names must be unique: step with the name '{name}' already exists. " + "If you need to use two different runnables/callables with the same name (for example, using `lambda`), please provide them as tuples (name, runnable/callable)." + ) + + node_names.add(name) + builder.add_node(name, step) + if previous_name is None: + builder.add_edge(START, name) + else: + builder.add_edge(previous_name, name) + + previous_name = name + + builder.add_edge(cast(str, previous_name), END) + return builder diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index a6655a451..4ef9fe895 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -1,5 +1,6 @@ import dataclasses import json +import operator from functools import partial from typing import ( Annotated, @@ -43,10 +44,12 @@ from langgraph.prebuilt import ( ToolNode, ValidationNode, + create_pipeline, create_react_agent, tools_condition, ) from langgraph.prebuilt.chat_agent_executor import _validate_chat_history +from langgraph.prebuilt.pipeline import _get_step_name from langgraph.prebuilt.tool_node import ( TOOL_CALL_ERROR_TEMPLATE, InjectedState, @@ -1416,3 +1419,108 @@ def foo(a: str, b: int) -> float: return 0.0 assert _get_state_args(foo) == {"a": None, "b": "bar"} + + +def test__get_step_name() -> None: + # default runnable name + assert _get_step_name(RunnableLambda(func=lambda x: x)) == "RunnableLambda" + # custom runnable name + assert ( + _get_step_name(RunnableLambda(name="my_runnable", func=lambda x: x)) + == "my_runnable" + ) + + # lambda + assert _get_step_name(lambda x: x) == "" + + # regular function + def func(state): + return + + assert _get_step_name(func) == "func" + + class MyClass: + def __call__(self, state): + return + + def class_method(self, state): + return + + # callable class + assert _get_step_name(MyClass()) == "MyClass" + + # class method + assert _get_step_name(MyClass().class_method) == "class_method" + + +def test_pipeline(): + class State(TypedDict): + foo: Annotated[list[str], operator.add] + bar: str + + def step1(state: State): + return {"foo": ["step1"], "bar": "baz"} + + def step2(state: State): + return {"foo": ["step2"]} + + # test raising if less than 1 steps + with pytest.raises(ValueError): + create_pipeline([], state_schema=State) + + # test raising if duplicate step names + with pytest.raises(ValueError): + create_pipeline([step1, step1], state_schema=State) + + with pytest.raises(ValueError): + create_pipeline([("foo", step1), ("foo", step1)], state_schema=State) + + # test unnamed steps + builder = create_pipeline([step1, step2], state_schema=State) + executor = builder.compile() + result = executor.invoke({"foo": []}) + assert result == {"foo": ["step1", "step2"], "bar": "baz"} + stream_chunks = list(executor.stream({"foo": []})) + assert stream_chunks == [ + {"step1": {"foo": ["step1"], "bar": "baz"}}, + {"step2": {"foo": ["step2"]}}, + ] + + # test named steps + builder_named_steps = create_pipeline( + [("meow1", step1), ("meow2", step2)], state_schema=State + ) + executor_named_steps = builder_named_steps.compile() + result = executor_named_steps.invoke({"foo": []}) + stream_chunks = list(executor_named_steps.stream({"foo": []})) + assert result == {"foo": ["step1", "step2"], "bar": "baz"} + assert stream_chunks == [ + {"meow1": {"foo": ["step1"], "bar": "baz"}}, + {"meow2": {"foo": ["step2"]}}, + ] + + # test input/output schema & functions w/ duplicate names + class Input(TypedDict): + foo: Annotated[list[str], operator.add] + + class Output(TypedDict): + bar: str + + builder_named_steps = create_pipeline( + [ + ("meow1", lambda state: {"foo": ["foo"]}), + ("meow2", lambda state: {"bar": state["foo"][0] + "bar"}), + ], + state_schema=State, + input_schema=Input, + output_schema=Output, + ) + executor_named_steps = builder_named_steps.compile() + result = executor_named_steps.invoke({"foo": []}) + stream_chunks = list(executor_named_steps.stream({"foo": []})) + # filtered by output schema + assert result == {"bar": "foobar"} + assert stream_chunks == [ + {"meow1": {"foo": ["foo"]}}, + {"meow2": {"bar": "foobar"}}, + ]