From a8c465b1ab35c22d4c3832edd7d82bb55ff5a804 Mon Sep 17 00:00:00 2001 From: vbarda Date: Wed, 23 Oct 2024 10:55:02 -0400 Subject: [PATCH 1/7] [rfc] langgraph: add prebuilt sequential executor graph --- libs/langgraph/langgraph/prebuilt/__init__.py | 2 + .../langgraph/prebuilt/sequential_executor.py | 77 +++++++++++++++++++ libs/langgraph/tests/test_prebuilt.py | 43 +++++++++++ 3 files changed, 122 insertions(+) create mode 100644 libs/langgraph/langgraph/prebuilt/sequential_executor.py diff --git a/libs/langgraph/langgraph/prebuilt/__init__.py b/libs/langgraph/langgraph/prebuilt/__init__.py index 671804258..cb6bf10cd 100644 --- a/libs/langgraph/langgraph/prebuilt/__init__.py +++ b/libs/langgraph/langgraph/prebuilt/__init__.py @@ -1,6 +1,7 @@ """langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools.""" from langgraph.prebuilt.chat_agent_executor import create_react_agent +from langgraph.prebuilt.sequential_executor import create_sequential_executor from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation from langgraph.prebuilt.tool_node import ( InjectedState, @@ -12,6 +13,7 @@ __all__ = [ "create_react_agent", + "create_sequential_executor", "ToolExecutor", "ToolInvocation", "ToolNode", diff --git a/libs/langgraph/langgraph/prebuilt/sequential_executor.py b/libs/langgraph/langgraph/prebuilt/sequential_executor.py new file mode 100644 index 000000000..a0e54d777 --- /dev/null +++ b/libs/langgraph/langgraph/prebuilt/sequential_executor.py @@ -0,0 +1,77 @@ +from typing import Any, Optional, Type, Union + +from langchain_core.runnables.base import Runnable, RunnableLike + +from langgraph.checkpoint.base import BaseCheckpointSaver +from langgraph.graph.state import END, START, CompiledStateGraph, StateGraph +from langgraph.store.base import BaseStore + + +def _get_name(runnable_like: RunnableLike) -> str: + if isinstance(runnable_like, Runnable): + if runnable_like.name is None: + raise ValueError( + f"Runnable ({runnable_like}) needs to have a name attribute. " + "Consider setting the name or passing it as a tuple (name, runnable)." + ) + return runnable_like.name + elif callable(runnable_like): + return runnable_like.__name__ + else: + raise ValueError(f"Unsupported runnable_like: {runnable_like}") + + +def create_sequential_executor( + *steps: Union[RunnableLike, tuple[str, RunnableLike]], + state_schema: Type[Any], + checkpointer: Optional[BaseCheckpointSaver] = None, + store: Optional[BaseStore] = None, + interrupt_before: Optional[list[str]] = None, + interrupt_after: Optional[list[str]] = None, + debug: bool = False, +) -> CompiledStateGraph: + """Creates a sequential executor graph that runs a series of provided steps in order. + + Args: + *steps: A sequence of RunnableLike objects or (name, RunnableLike) tuples. + If no names are provided, the name will be inferred from the step object (e.g. a runnable or a callable name). + Each step will be executed in the order provided. + state_schema: The state schema for the graph. + checkpointer: An optional checkpoint saver object. This is used for persisting + the state of the graph (e.g., as chat memory) for a single thread (e.g., a single conversation). + store: An optional store object. This is used for persisting data + across multiple threads (e.g., multiple conversations / users). + interrupt_before: An optional list of step names to interrupt before execution. + interrupt_after: An optional list of step names to interrupt after execution. + debug: A flag to enable debug mode. + + Returns: + A CompiledStateGraph object. + """ + if len(steps) < 2: + raise ValueError("Sequential executor requires at least two steps.") + + builder = StateGraph(state_schema) + previous_name = None + for step_idx, step in enumerate(steps): + if isinstance(step, tuple) and len(step) == 2: + name, step = step + else: + name = _get_name(step) + + builder.add_node(name, step) + if step_idx == 0: + builder.add_edge(START, name) + else: + builder.add_edge(previous_name, name) + + previous_name = name + + builder.add_edge(previous_name, END) + return builder.compile( + checkpointer=checkpointer, + store=store, + interrupt_before=interrupt_before, + interrupt_after=interrupt_after, + debug=debug, + ) diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index 2ebc2f6cb..dacab4fa7 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -1,5 +1,6 @@ import dataclasses import json +import operator from functools import partial from typing import ( Annotated, @@ -41,6 +42,7 @@ ToolNode, ValidationNode, create_react_agent, + create_sequential_executor, tools_condition, ) from langgraph.prebuilt.tool_node import InjectedState, InjectedStore @@ -968,3 +970,44 @@ def tool_normal(input: str) -> str: id=result["messages"][3].id, ), ] + + +def test_sequential_executor(): + class State(TypedDict): + foo: Annotated[list[str], operator.add] + bar: str + + def step1(state: State): + return {"foo": ["step1"], "bar": "baz"} + + def step2(state: State): + return {"foo": ["step2"]} + + # test raising if less than 2 steps + with pytest.raises(ValueError): + create_sequential_executor(state_schema=State) + + with pytest.raises(ValueError): + create_sequential_executor(step1, state_schema=State) + + # test unnamed steps + executor = create_sequential_executor(step1, step2, state_schema=State) + result = executor.invoke({"foo": []}) + assert result == {"foo": ["step1", "step2"], "bar": "baz"} + stream_chunks = list(executor.stream({"foo": []})) + assert stream_chunks == [ + {"step1": {"foo": ["step1"], "bar": "baz"}}, + {"step2": {"foo": ["step2"]}}, + ] + + # test named steps + executor_named_steps = create_sequential_executor( + ("meow1", step1), ("meow2", step2), state_schema=State + ) + result = executor_named_steps.invoke({"foo": []}) + stream_chunks = list(executor_named_steps.stream({"foo": []})) + assert result == {"foo": ["step1", "step2"], "bar": "baz"} + assert stream_chunks == [ + {"meow1": {"foo": ["step1"], "bar": "baz"}}, + {"meow2": {"foo": ["step2"]}}, + ] From 4dfba663d34300f7a900d342dda0483cb3a7e267 Mon Sep 17 00:00:00 2001 From: vbarda Date: Wed, 23 Oct 2024 15:01:33 -0400 Subject: [PATCH 2/7] rename --- .../langgraph/prebuilt/sequential_executor.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/libs/langgraph/langgraph/prebuilt/sequential_executor.py b/libs/langgraph/langgraph/prebuilt/sequential_executor.py index a0e54d777..8c775c0c0 100644 --- a/libs/langgraph/langgraph/prebuilt/sequential_executor.py +++ b/libs/langgraph/langgraph/prebuilt/sequential_executor.py @@ -7,18 +7,18 @@ from langgraph.store.base import BaseStore -def _get_name(runnable_like: RunnableLike) -> str: - if isinstance(runnable_like, Runnable): - if runnable_like.name is None: +def _get_name(step: RunnableLike) -> str: + if isinstance(step, Runnable): + if step.name is None: raise ValueError( - f"Runnable ({runnable_like}) needs to have a name attribute. " + f"Runnable ({step}) needs to have a name attribute. " "Consider setting the name or passing it as a tuple (name, runnable)." ) - return runnable_like.name - elif callable(runnable_like): - return runnable_like.__name__ + return step.name + elif callable(step): + return step.__name__ else: - raise ValueError(f"Unsupported runnable_like: {runnable_like}") + raise TypeError(f"Unsupported step type: {step}") def create_sequential_executor( From 6ec8616fdde763a236bce80e3f82db679f7e6ab4 Mon Sep 17 00:00:00 2001 From: vbarda Date: Wed, 23 Oct 2024 15:12:07 -0400 Subject: [PATCH 3/7] lint --- .../langgraph/prebuilt/sequential_executor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/libs/langgraph/langgraph/prebuilt/sequential_executor.py b/libs/langgraph/langgraph/prebuilt/sequential_executor.py index 8c775c0c0..a92177114 100644 --- a/libs/langgraph/langgraph/prebuilt/sequential_executor.py +++ b/libs/langgraph/langgraph/prebuilt/sequential_executor.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Type, Union, cast from langchain_core.runnables.base import Runnable, RunnableLike @@ -16,7 +16,7 @@ def _get_name(step: RunnableLike) -> str: ) return step.name elif callable(step): - return step.__name__ + return getattr(step, "__name__", step.__class__.__name__) else: raise TypeError(f"Unsupported step type: {step}") @@ -52,22 +52,22 @@ def create_sequential_executor( raise ValueError("Sequential executor requires at least two steps.") builder = StateGraph(state_schema) - previous_name = None - for step_idx, step in enumerate(steps): + previous_name: Optional[str] = None + for step in steps: if isinstance(step, tuple) and len(step) == 2: name, step = step else: name = _get_name(step) builder.add_node(name, step) - if step_idx == 0: + if previous_name is None: builder.add_edge(START, name) else: builder.add_edge(previous_name, name) previous_name = name - builder.add_edge(previous_name, END) + builder.add_edge(cast(str, previous_name), END) return builder.compile( checkpointer=checkpointer, store=store, From 586df0bc3f76fe154f5e84ec043e39983fb1ec0e Mon Sep 17 00:00:00 2001 From: vbarda Date: Fri, 1 Nov 2024 09:31:12 -0400 Subject: [PATCH 4/7] code review --- libs/langgraph/langgraph/prebuilt/__init__.py | 4 ++-- ...quential_executor.py => chain_executor.py} | 24 ++++++++++++++----- libs/langgraph/tests/test_prebuilt.py | 18 ++++++++------ 3 files changed, 31 insertions(+), 15 deletions(-) rename libs/langgraph/langgraph/prebuilt/{sequential_executor.py => chain_executor.py} (78%) diff --git a/libs/langgraph/langgraph/prebuilt/__init__.py b/libs/langgraph/langgraph/prebuilt/__init__.py index cb6bf10cd..2ea96829b 100644 --- a/libs/langgraph/langgraph/prebuilt/__init__.py +++ b/libs/langgraph/langgraph/prebuilt/__init__.py @@ -1,7 +1,7 @@ """langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools.""" +from langgraph.prebuilt.chain_executor import create_chain_executor from langgraph.prebuilt.chat_agent_executor import create_react_agent -from langgraph.prebuilt.sequential_executor import create_sequential_executor from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation from langgraph.prebuilt.tool_node import ( InjectedState, @@ -13,7 +13,7 @@ __all__ = [ "create_react_agent", - "create_sequential_executor", + "create_chain_executor", "ToolExecutor", "ToolInvocation", "ToolNode", diff --git a/libs/langgraph/langgraph/prebuilt/sequential_executor.py b/libs/langgraph/langgraph/prebuilt/chain_executor.py similarity index 78% rename from libs/langgraph/langgraph/prebuilt/sequential_executor.py rename to libs/langgraph/langgraph/prebuilt/chain_executor.py index a92177114..f76968d7c 100644 --- a/libs/langgraph/langgraph/prebuilt/sequential_executor.py +++ b/libs/langgraph/langgraph/prebuilt/chain_executor.py @@ -21,22 +21,26 @@ def _get_name(step: RunnableLike) -> str: raise TypeError(f"Unsupported step type: {step}") -def create_sequential_executor( +def create_chain_executor( *steps: Union[RunnableLike, tuple[str, RunnableLike]], state_schema: Type[Any], + return_compiled: bool = True, checkpointer: Optional[BaseCheckpointSaver] = None, store: Optional[BaseStore] = None, interrupt_before: Optional[list[str]] = None, interrupt_after: Optional[list[str]] = None, debug: bool = False, -) -> CompiledStateGraph: - """Creates a sequential executor graph that runs a series of provided steps in order. +) -> Union[CompiledStateGraph, StateGraph]: + """Creates a chain executor graph that runs a series of provided steps in order. Args: *steps: A sequence of RunnableLike objects or (name, RunnableLike) tuples. If no names are provided, the name will be inferred from the step object (e.g. a runnable or a callable name). Each step will be executed in the order provided. state_schema: The state schema for the graph. + return_compiled: Whether to return the compiled graph or the builder object. + If False, all of the arguments except `steps` and `state_schema` will be ignored. + Defaults to True (return compiled graph). checkpointer: An optional checkpoint saver object. This is used for persisting the state of the graph (e.g., as chat memory) for a single thread (e.g., a single conversation). store: An optional store object. This is used for persisting data @@ -46,11 +50,12 @@ def create_sequential_executor( debug: A flag to enable debug mode. Returns: - A CompiledStateGraph object. + A CompiledStateGraph object if `return_compiled` is True, otherwise a StateGraph object. """ - if len(steps) < 2: - raise ValueError("Sequential executor requires at least two steps.") + if len(steps) < 1: + raise ValueError("Sequential executor requires at least one step.") + node_names = set() builder = StateGraph(state_schema) previous_name: Optional[str] = None for step in steps: @@ -59,6 +64,10 @@ def create_sequential_executor( else: name = _get_name(step) + if name in node_names: + raise ValueError(f"Node name {name} already exists.") + + node_names.add(name) builder.add_node(name, step) if previous_name is None: builder.add_edge(START, name) @@ -68,6 +77,9 @@ def create_sequential_executor( previous_name = name builder.add_edge(cast(str, previous_name), END) + if not return_compiled: + return builder + return builder.compile( checkpointer=checkpointer, store=store, diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index a009373bb..a7b7de87e 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -44,8 +44,8 @@ from langgraph.prebuilt import ( ToolNode, ValidationNode, + create_chain_executor, create_react_agent, - create_sequential_executor, tools_condition, ) from langgraph.prebuilt.tool_node import ( @@ -1337,7 +1337,7 @@ def tool_normal(input: str) -> str: ] -def test_sequential_executor(): +def test_chain_executor(): class State(TypedDict): foo: Annotated[list[str], operator.add] bar: str @@ -1348,15 +1348,19 @@ def step1(state: State): def step2(state: State): return {"foo": ["step2"]} - # test raising if less than 2 steps + # test raising if less than 1 steps with pytest.raises(ValueError): - create_sequential_executor(state_schema=State) + create_chain_executor(state_schema=State) + # test raising if duplicate step names with pytest.raises(ValueError): - create_sequential_executor(step1, state_schema=State) + create_chain_executor(step1, step1, state_schema=State) + + with pytest.raises(ValueError): + create_chain_executor(("foo", step1), ("foo", step1), state_schema=State) # test unnamed steps - executor = create_sequential_executor(step1, step2, state_schema=State) + executor = create_chain_executor(step1, step2, state_schema=State) result = executor.invoke({"foo": []}) assert result == {"foo": ["step1", "step2"], "bar": "baz"} stream_chunks = list(executor.stream({"foo": []})) @@ -1366,7 +1370,7 @@ def step2(state: State): ] # test named steps - executor_named_steps = create_sequential_executor( + executor_named_steps = create_chain_executor( ("meow1", step1), ("meow2", step2), state_schema=State ) result = executor_named_steps.invoke({"foo": []}) From 0ad50e9ce201e4c92384d4e271cbd6edeb3334ef Mon Sep 17 00:00:00 2001 From: vbarda Date: Fri, 1 Nov 2024 09:39:22 -0400 Subject: [PATCH 5/7] fix merge --- libs/langgraph/tests/test_prebuilt.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index a7b7de87e..cd29c13b7 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -1337,6 +1337,21 @@ def tool_normal(input: str) -> str: ] +def test__get_state_args() -> None: + class Schema1(BaseModel): + a: Annotated[str, InjectedState] + + class Schema2(Schema1): + b: Annotated[int, InjectedState("bar")] + + @dec_tool(args_schema=Schema2) + def foo(a: str, b: int) -> float: + """return""" + return 0.0 + + assert _get_state_args(foo) == {"a": None, "b": "bar"} + + def test_chain_executor(): class State(TypedDict): foo: Annotated[list[str], operator.add] From 20297658bf5c0709b271956d3b178f958ee9d5c3 Mon Sep 17 00:00:00 2001 From: vbarda Date: Fri, 1 Nov 2024 10:25:42 -0400 Subject: [PATCH 6/7] more code review --- libs/langgraph/langgraph/prebuilt/__init__.py | 4 +- libs/langgraph/langgraph/prebuilt/chain.py | 64 +++++++++++++ .../langgraph/prebuilt/chain_executor.py | 89 ------------------- libs/langgraph/tests/test_prebuilt.py | 73 +++++++++++++-- 4 files changed, 132 insertions(+), 98 deletions(-) create mode 100644 libs/langgraph/langgraph/prebuilt/chain.py delete mode 100644 libs/langgraph/langgraph/prebuilt/chain_executor.py diff --git a/libs/langgraph/langgraph/prebuilt/__init__.py b/libs/langgraph/langgraph/prebuilt/__init__.py index 2ea96829b..05dfa0823 100644 --- a/libs/langgraph/langgraph/prebuilt/__init__.py +++ b/libs/langgraph/langgraph/prebuilt/__init__.py @@ -1,6 +1,6 @@ """langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools.""" -from langgraph.prebuilt.chain_executor import create_chain_executor +from langgraph.prebuilt.chain import create_chain from langgraph.prebuilt.chat_agent_executor import create_react_agent from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation from langgraph.prebuilt.tool_node import ( @@ -13,7 +13,7 @@ __all__ = [ "create_react_agent", - "create_chain_executor", + "create_chain", "ToolExecutor", "ToolInvocation", "ToolNode", diff --git a/libs/langgraph/langgraph/prebuilt/chain.py b/libs/langgraph/langgraph/prebuilt/chain.py new file mode 100644 index 000000000..d4dd813d7 --- /dev/null +++ b/libs/langgraph/langgraph/prebuilt/chain.py @@ -0,0 +1,64 @@ +from typing import Any, Optional, Type, Union, cast + +from langchain_core.runnables.base import Runnable, RunnableLike + +from langgraph.graph.state import END, START, StateGraph + + +def _get_step_name(step: RunnableLike) -> str: + if isinstance(step, Runnable): + return step.get_name() + elif callable(step): + return getattr(step, "__name__", step.__class__.__name__) + else: + raise TypeError(f"Unsupported step type: {step}") + + +def create_chain( + *steps: Union[RunnableLike, tuple[str, RunnableLike]], + state_schema: Type[Any], + input_schema: Optional[Type[Any]] = None, + output_schema: Optional[Type[Any]] = None, +) -> StateGraph: + """Creates a chain graph that runs a series of provided steps in order. + + Args: + *steps: A sequence of RunnableLike objects (e.g. a LangChain Runnable or a callable) or (name, RunnableLike) tuples. + If no names are provided, the name will be inferred from the step object (e.g. a runnable or a callable name). + Each step will be executed in the order provided. + state_schema: The state schema for the graph. + input_schema: The input schema for the graph. + output_schema: The output schema for the graph. Will only be used when calling `graph.invoke()`. + + Returns: + A StateGraph object. + """ + if len(steps) < 1: + raise ValueError("Chain requires at least one step.") + + node_names = set() + builder = StateGraph(state_schema, input=input_schema, output=output_schema) + previous_name: Optional[str] = None + for step in steps: + if isinstance(step, tuple) and len(step) == 2: + name, step = step + else: + name = _get_step_name(step) + + if name in node_names: + raise ValueError( + f"Step names must be unique: step with the name '{name}' already exists. " + "If you need to use two different runnables/callables with the same name (for example, using `lambda`), please provide them as tuples (name, runnable/callable)." + ) + + node_names.add(name) + builder.add_node(name, step) + if previous_name is None: + builder.add_edge(START, name) + else: + builder.add_edge(previous_name, name) + + previous_name = name + + builder.add_edge(cast(str, previous_name), END) + return builder diff --git a/libs/langgraph/langgraph/prebuilt/chain_executor.py b/libs/langgraph/langgraph/prebuilt/chain_executor.py deleted file mode 100644 index f76968d7c..000000000 --- a/libs/langgraph/langgraph/prebuilt/chain_executor.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import Any, Optional, Type, Union, cast - -from langchain_core.runnables.base import Runnable, RunnableLike - -from langgraph.checkpoint.base import BaseCheckpointSaver -from langgraph.graph.state import END, START, CompiledStateGraph, StateGraph -from langgraph.store.base import BaseStore - - -def _get_name(step: RunnableLike) -> str: - if isinstance(step, Runnable): - if step.name is None: - raise ValueError( - f"Runnable ({step}) needs to have a name attribute. " - "Consider setting the name or passing it as a tuple (name, runnable)." - ) - return step.name - elif callable(step): - return getattr(step, "__name__", step.__class__.__name__) - else: - raise TypeError(f"Unsupported step type: {step}") - - -def create_chain_executor( - *steps: Union[RunnableLike, tuple[str, RunnableLike]], - state_schema: Type[Any], - return_compiled: bool = True, - checkpointer: Optional[BaseCheckpointSaver] = None, - store: Optional[BaseStore] = None, - interrupt_before: Optional[list[str]] = None, - interrupt_after: Optional[list[str]] = None, - debug: bool = False, -) -> Union[CompiledStateGraph, StateGraph]: - """Creates a chain executor graph that runs a series of provided steps in order. - - Args: - *steps: A sequence of RunnableLike objects or (name, RunnableLike) tuples. - If no names are provided, the name will be inferred from the step object (e.g. a runnable or a callable name). - Each step will be executed in the order provided. - state_schema: The state schema for the graph. - return_compiled: Whether to return the compiled graph or the builder object. - If False, all of the arguments except `steps` and `state_schema` will be ignored. - Defaults to True (return compiled graph). - checkpointer: An optional checkpoint saver object. This is used for persisting - the state of the graph (e.g., as chat memory) for a single thread (e.g., a single conversation). - store: An optional store object. This is used for persisting data - across multiple threads (e.g., multiple conversations / users). - interrupt_before: An optional list of step names to interrupt before execution. - interrupt_after: An optional list of step names to interrupt after execution. - debug: A flag to enable debug mode. - - Returns: - A CompiledStateGraph object if `return_compiled` is True, otherwise a StateGraph object. - """ - if len(steps) < 1: - raise ValueError("Sequential executor requires at least one step.") - - node_names = set() - builder = StateGraph(state_schema) - previous_name: Optional[str] = None - for step in steps: - if isinstance(step, tuple) and len(step) == 2: - name, step = step - else: - name = _get_name(step) - - if name in node_names: - raise ValueError(f"Node name {name} already exists.") - - node_names.add(name) - builder.add_node(name, step) - if previous_name is None: - builder.add_edge(START, name) - else: - builder.add_edge(previous_name, name) - - previous_name = name - - builder.add_edge(cast(str, previous_name), END) - if not return_compiled: - return builder - - return builder.compile( - checkpointer=checkpointer, - store=store, - interrupt_before=interrupt_before, - interrupt_after=interrupt_after, - debug=debug, - ) diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index cd29c13b7..863851cce 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -44,10 +44,11 @@ from langgraph.prebuilt import ( ToolNode, ValidationNode, - create_chain_executor, + create_chain, create_react_agent, tools_condition, ) +from langgraph.prebuilt.chain import _get_step_name from langgraph.prebuilt.tool_node import ( TOOL_CALL_ERROR_TEMPLATE, InjectedState, @@ -1352,7 +1353,39 @@ def foo(a: str, b: int) -> float: assert _get_state_args(foo) == {"a": None, "b": "bar"} -def test_chain_executor(): +def test__get_step_name() -> None: + # default runnable name + assert _get_step_name(RunnableLambda(func=lambda x: x)) == "RunnableLambda" + # custom runnable name + assert ( + _get_step_name(RunnableLambda(name="my_runnable", func=lambda x: x)) + == "my_runnable" + ) + + # lambda + assert _get_step_name(lambda x: x) == "" + + # regular function + def func(state): + return + + assert _get_step_name(func) == "func" + + class MyClass: + def __call__(self, state): + return + + def class_method(self, state): + return + + # callable class + assert _get_step_name(MyClass()) == "MyClass" + + # class method + assert _get_step_name(MyClass().class_method) == "class_method" + + +def test_chain(): class State(TypedDict): foo: Annotated[list[str], operator.add] bar: str @@ -1365,17 +1398,18 @@ def step2(state: State): # test raising if less than 1 steps with pytest.raises(ValueError): - create_chain_executor(state_schema=State) + create_chain(state_schema=State) # test raising if duplicate step names with pytest.raises(ValueError): - create_chain_executor(step1, step1, state_schema=State) + create_chain(step1, step1, state_schema=State) with pytest.raises(ValueError): - create_chain_executor(("foo", step1), ("foo", step1), state_schema=State) + create_chain(("foo", step1), ("foo", step1), state_schema=State) # test unnamed steps - executor = create_chain_executor(step1, step2, state_schema=State) + builder = create_chain(step1, step2, state_schema=State) + executor = builder.compile() result = executor.invoke({"foo": []}) assert result == {"foo": ["step1", "step2"], "bar": "baz"} stream_chunks = list(executor.stream({"foo": []})) @@ -1385,9 +1419,10 @@ def step2(state: State): ] # test named steps - executor_named_steps = create_chain_executor( + builder_named_steps = create_chain( ("meow1", step1), ("meow2", step2), state_schema=State ) + executor_named_steps = builder_named_steps.compile() result = executor_named_steps.invoke({"foo": []}) stream_chunks = list(executor_named_steps.stream({"foo": []})) assert result == {"foo": ["step1", "step2"], "bar": "baz"} @@ -1395,3 +1430,27 @@ def step2(state: State): {"meow1": {"foo": ["step1"], "bar": "baz"}}, {"meow2": {"foo": ["step2"]}}, ] + + # test input/output schema & functions w/ duplicate names + class Input(TypedDict): + foo: Annotated[list[str], operator.add] + + class Output(TypedDict): + bar: str + + builder_named_steps = create_chain( + ("meow1", lambda state: {"foo": ["foo"]}), + ("meow2", lambda state: {"bar": state["foo"][0] + "bar"}), + state_schema=State, + input_schema=Input, + output_schema=Output, + ) + executor_named_steps = builder_named_steps.compile() + result = executor_named_steps.invoke({"foo": []}) + stream_chunks = list(executor_named_steps.stream({"foo": []})) + # filtered by output schema + assert result == {"bar": "foobar"} + assert stream_chunks == [ + {"meow1": {"foo": ["foo"]}}, + {"meow2": {"bar": "foobar"}}, + ] From 467dc68738fa8ed785cbee568e1eb5095dc971f1 Mon Sep 17 00:00:00 2001 From: vbarda Date: Mon, 4 Nov 2024 14:50:11 -0500 Subject: [PATCH 7/7] pipeline --- libs/langgraph/langgraph/prebuilt/__init__.py | 4 +-- .../prebuilt/{chain.py => pipeline.py} | 13 +++++----- libs/langgraph/tests/test_prebuilt.py | 26 ++++++++++--------- 3 files changed, 23 insertions(+), 20 deletions(-) rename libs/langgraph/langgraph/prebuilt/{chain.py => pipeline.py} (82%) diff --git a/libs/langgraph/langgraph/prebuilt/__init__.py b/libs/langgraph/langgraph/prebuilt/__init__.py index 05dfa0823..50387838a 100644 --- a/libs/langgraph/langgraph/prebuilt/__init__.py +++ b/libs/langgraph/langgraph/prebuilt/__init__.py @@ -1,7 +1,7 @@ """langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools.""" -from langgraph.prebuilt.chain import create_chain from langgraph.prebuilt.chat_agent_executor import create_react_agent +from langgraph.prebuilt.pipeline import create_pipeline from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation from langgraph.prebuilt.tool_node import ( InjectedState, @@ -13,7 +13,7 @@ __all__ = [ "create_react_agent", - "create_chain", + "create_pipeline", "ToolExecutor", "ToolInvocation", "ToolNode", diff --git a/libs/langgraph/langgraph/prebuilt/chain.py b/libs/langgraph/langgraph/prebuilt/pipeline.py similarity index 82% rename from libs/langgraph/langgraph/prebuilt/chain.py rename to libs/langgraph/langgraph/prebuilt/pipeline.py index d4dd813d7..f5a28dec3 100644 --- a/libs/langgraph/langgraph/prebuilt/chain.py +++ b/libs/langgraph/langgraph/prebuilt/pipeline.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Type, Union, cast +from typing import Any, Optional, Sequence, Type, Union, cast from langchain_core.runnables.base import Runnable, RunnableLike @@ -14,16 +14,17 @@ def _get_step_name(step: RunnableLike) -> str: raise TypeError(f"Unsupported step type: {step}") -def create_chain( - *steps: Union[RunnableLike, tuple[str, RunnableLike]], +def create_pipeline( + steps: Sequence[Union[RunnableLike, tuple[str, RunnableLike]]], + *, state_schema: Type[Any], input_schema: Optional[Type[Any]] = None, output_schema: Optional[Type[Any]] = None, ) -> StateGraph: - """Creates a chain graph that runs a series of provided steps in order. + """Create a pipeline graph that runs a series of provided steps in order. Args: - *steps: A sequence of RunnableLike objects (e.g. a LangChain Runnable or a callable) or (name, RunnableLike) tuples. + steps: A sequence of RunnableLike objects (e.g. a LangChain Runnable or a callable) or (name, RunnableLike) tuples. If no names are provided, the name will be inferred from the step object (e.g. a runnable or a callable name). Each step will be executed in the order provided. state_schema: The state schema for the graph. @@ -34,7 +35,7 @@ def create_chain( A StateGraph object. """ if len(steps) < 1: - raise ValueError("Chain requires at least one step.") + raise ValueError("Pipeline requires at least one step.") node_names = set() builder = StateGraph(state_schema, input=input_schema, output=output_schema) diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index 15179092f..4ef9fe895 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -44,12 +44,12 @@ from langgraph.prebuilt import ( ToolNode, ValidationNode, - create_chain, + create_pipeline, create_react_agent, tools_condition, ) -from langgraph.prebuilt.chain import _get_step_name from langgraph.prebuilt.chat_agent_executor import _validate_chat_history +from langgraph.prebuilt.pipeline import _get_step_name from langgraph.prebuilt.tool_node import ( TOOL_CALL_ERROR_TEMPLATE, InjectedState, @@ -1453,7 +1453,7 @@ def class_method(self, state): assert _get_step_name(MyClass().class_method) == "class_method" -def test_chain(): +def test_pipeline(): class State(TypedDict): foo: Annotated[list[str], operator.add] bar: str @@ -1466,17 +1466,17 @@ def step2(state: State): # test raising if less than 1 steps with pytest.raises(ValueError): - create_chain(state_schema=State) + create_pipeline([], state_schema=State) # test raising if duplicate step names with pytest.raises(ValueError): - create_chain(step1, step1, state_schema=State) + create_pipeline([step1, step1], state_schema=State) with pytest.raises(ValueError): - create_chain(("foo", step1), ("foo", step1), state_schema=State) + create_pipeline([("foo", step1), ("foo", step1)], state_schema=State) # test unnamed steps - builder = create_chain(step1, step2, state_schema=State) + builder = create_pipeline([step1, step2], state_schema=State) executor = builder.compile() result = executor.invoke({"foo": []}) assert result == {"foo": ["step1", "step2"], "bar": "baz"} @@ -1487,8 +1487,8 @@ def step2(state: State): ] # test named steps - builder_named_steps = create_chain( - ("meow1", step1), ("meow2", step2), state_schema=State + builder_named_steps = create_pipeline( + [("meow1", step1), ("meow2", step2)], state_schema=State ) executor_named_steps = builder_named_steps.compile() result = executor_named_steps.invoke({"foo": []}) @@ -1506,9 +1506,11 @@ class Input(TypedDict): class Output(TypedDict): bar: str - builder_named_steps = create_chain( - ("meow1", lambda state: {"foo": ["foo"]}), - ("meow2", lambda state: {"bar": state["foo"][0] + "bar"}), + builder_named_steps = create_pipeline( + [ + ("meow1", lambda state: {"foo": ["foo"]}), + ("meow2", lambda state: {"bar": state["foo"][0] + "bar"}), + ], state_schema=State, input_schema=Input, output_schema=Output,