diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py index 8b147c3c0bcaa..9dba8fd4cf5f1 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -4,8 +4,9 @@ import warnings from abc import ABC +from collections.abc import Callable, Sequence from string import Formatter -from typing import Any, Callable, Literal +from typing import Any, Literal from pydantic import BaseModel, create_model @@ -148,9 +149,7 @@ def mustache_template_vars( Defs = dict[str, "Defs"] -def mustache_schema( - template: str, -) -> type[BaseModel]: +def mustache_schema(template: str) -> type[BaseModel]: """Get the variables from a mustache template. Args: @@ -174,6 +173,11 @@ def mustache_schema( fields[prefix] = False elif type_ in {"variable", "no escape"}: fields[prefix + tuple(key.split("."))] = True + + for fkey, fval in fields.items(): + fields[fkey] = fval and not any( + is_subsequence(fkey, k) for k in fields if k != fkey + ) defs: Defs = {} # None means leaf node while fields: field, is_leaf = fields.popitem() @@ -326,3 +330,12 @@ def pretty_repr( def pretty_print(self) -> None: """Print a pretty representation of the prompt.""" print(self.pretty_repr(html=is_interactive_env())) # noqa: T201 + + +def is_subsequence(child: Sequence, parent: Sequence) -> bool: + """Return True if child is subsequence of parent.""" + if len(child) == 0 or len(parent) == 0: + return False + if len(parent) < len(child): + return False + return all(child[i] == parent[i] for i in range(len(child))) diff --git a/libs/core/tests/unit_tests/prompts/test_string.py b/libs/core/tests/unit_tests/prompts/test_string.py new file mode 100644 index 0000000000000..96c573c72f222 --- /dev/null +++ b/libs/core/tests/unit_tests/prompts/test_string.py @@ -0,0 +1,32 @@ +import pytest +from packaging import version + +from langchain_core.prompts.string import mustache_schema +from langchain_core.utils.pydantic import PYDANTIC_VERSION + +PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION + + +@pytest.mark.skipif( + not PYDANTIC_VERSION_AT_LEAST_29, + reason=( + "Only test with most recent version of pydantic. " + "Pydantic introduced small fixes to generated JSONSchema on minor versions." + ), +) +def test_mustache_schema_parent_child() -> None: + template = "{{x.y}} {{x}}" + expected = { + "$defs": { + "x": { + "properties": {"y": {"default": None, "title": "Y", "type": "string"}}, + "title": "x", + "type": "object", + } + }, + "properties": {"x": {"$ref": "#/$defs/x", "default": None}}, + "title": "PromptInput", + "type": "object", + } + actual = mustache_schema(template).model_json_schema() + assert expected == actual