Skip to content

Commit 012ef17

Browse files
committed
Feature: Allow for nested templating in instructions
1 parent 905c20d commit 012ef17

File tree

2 files changed

+124
-29
lines changed

2 files changed

+124
-29
lines changed

src/google/adk/flows/llm_flows/instructions.py

Lines changed: 75 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import re
2020
from typing import AsyncGenerator
21-
from typing import Generator
2221
from typing import TYPE_CHECKING
2322

2423
from typing_extensions import override
@@ -77,7 +76,52 @@ async def _populate_values(
7776
instruction_template: str,
7877
context: InvocationContext,
7978
) -> str:
80-
"""Populates values in the instruction template, e.g. state, artifact, etc."""
79+
"""Populates values in the instruction template, e.g. state, artifact, etc.
80+
81+
Supports nested dot-separated references like:
82+
- state.user.name
83+
- artifact.config.settings
84+
- user.profile.email
85+
- user?.profile?.name? (optional markers at any level)
86+
"""
87+
88+
def _get_nested_value(
89+
obj, paths: list[str], is_optional: bool = False
90+
) -> str:
91+
"""Gets a nested value from an object using a list of path segments.
92+
93+
Args:
94+
obj: The object to get the value from
95+
paths: List of path segments to traverse
96+
is_optional: Whether the entire path is optional
97+
98+
Returns:
99+
The value as a string
100+
101+
Raises:
102+
KeyError: If the path doesn't exist and the reference is not optional
103+
"""
104+
if not paths:
105+
return str(obj)
106+
107+
# Get current part and remaining paths
108+
current_part = paths[0]
109+
110+
# Handle optional markers
111+
is_current_optional = current_part.endswith('?') or is_optional
112+
clean_part = current_part.removesuffix('?')
113+
114+
# Get value for current part
115+
if isinstance(obj, dict) and clean_part in obj:
116+
return _get_nested_value(obj[clean_part], paths[1:], is_current_optional)
117+
elif hasattr(obj, clean_part):
118+
return _get_nested_value(
119+
getattr(obj, clean_part), paths[1:], is_current_optional
120+
)
121+
elif is_current_optional:
122+
return ''
123+
else:
124+
raise KeyError(f'Key not found: {clean_part}')
81125

82126
async def _async_sub(pattern, repl_async_fn, string) -> str:
83127
result = []
@@ -96,29 +140,37 @@ async def _replace_match(match) -> str:
96140
if var_name.endswith('?'):
97141
optional = True
98142
var_name = var_name.removesuffix('?')
99-
if var_name.startswith('artifact.'):
100-
var_name = var_name.removeprefix('artifact.')
101-
if context.artifact_service is None:
102-
raise ValueError('Artifact service is not initialized.')
103-
artifact = await context.artifact_service.load_artifact(
104-
app_name=context.session.app_name,
105-
user_id=context.session.user_id,
106-
session_id=context.session.id,
107-
filename=var_name,
108-
)
109-
if not var_name:
110-
raise KeyError(f'Artifact {var_name} not found.')
111-
return str(artifact)
112-
else:
113-
if not _is_valid_state_name(var_name):
114-
return match.group()
115-
if var_name in context.session.state:
116-
return str(context.session.state[var_name])
143+
144+
try:
145+
if var_name.startswith('artifact.'):
146+
var_name = var_name.removeprefix('artifact.')
147+
if context.artifact_service is None:
148+
raise ValueError('Artifact service is not initialized.')
149+
artifact = await context.artifact_service.load_artifact(
150+
app_name=context.session.app_name,
151+
user_id=context.session.user_id,
152+
session_id=context.session.id,
153+
filename=var_name,
154+
)
155+
if not var_name:
156+
raise KeyError(f'Artifact {var_name} not found.')
157+
return str(artifact)
117158
else:
118-
if optional:
119-
return ''
120-
else:
159+
if not _is_valid_state_name(var_name.split('.')[0].removesuffix('?')):
160+
return match.group()
161+
# Try to resolve nested path
162+
try:
163+
return _get_nested_value(
164+
context.session.state, var_name.split('.'), optional
165+
)
166+
except KeyError:
167+
if not _is_valid_state_name(var_name):
168+
return match.group()
121169
raise KeyError(f'Context variable not found: `{var_name}`.')
170+
except Exception as e:
171+
if optional:
172+
return ''
173+
raise e
122174

123175
return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template)
124176

tests/unittests/flows/llm_flows/test_instructions.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from google.genai import types
16+
import pytest
17+
1518
from google.adk.agents import Agent
1619
from google.adk.agents.readonly_context import ReadonlyContext
1720
from google.adk.flows.llm_flows import instructions
1821
from google.adk.models import LlmRequest
1922
from google.adk.sessions import Session
20-
from google.genai import types
21-
import pytest
2223

2324
from ... import utils
2425

@@ -33,15 +34,21 @@ async def test_build_system_instruction():
3334
model="gemini-1.5-flash",
3435
name="agent",
3536
instruction=("""Use the echo_info tool to echo { customerId }, \
36-
{{customer_int }, { non-identifier-float}}, \
37-
{'key1': 'value1'} and {{'key2': 'value2'}}."""),
37+
{{customer_int }, {customer.profile.name}, {customer?.preferences.alias}, \
38+
{ non-identifier-float}}, {'key1': 'value1'} and {{'key2': 'value2'}}."""),
3839
)
3940
invocation_context = utils.create_invocation_context(agent=agent)
4041
invocation_context.session = Session(
4142
app_name="test_app",
4243
user_id="test_user",
4344
id="test_id",
44-
state={"customerId": "1234567890", "customer_int": 30},
45+
state={
46+
"customerId": "1234567890",
47+
"customer_int": 30,
48+
"customer": {
49+
"profile": {"name": "Test User", "email": "[email protected]"}
50+
},
51+
},
4552
)
4653

4754
async for _ in instructions.request_processor.run_async(
@@ -52,7 +59,7 @@ async def test_build_system_instruction():
5259

5360
assert request.config.system_instruction == (
5461
"""Use the echo_info tool to echo 1234567890, 30, \
55-
{ non-identifier-float}}, {'key1': 'value1'} and {{'key2': 'value2'}}."""
62+
Test User, , { non-identifier-float}}, {'key1': 'value1'} and {{'key2': 'value2'}}."""
5663
)
5764

5865

@@ -162,3 +169,39 @@ async def test_build_system_instruction_with_namespace():
162169
assert request.config.system_instruction == (
163170
"""Use the echo_info tool to echo 1234567890, app_value, user_value, {a:key}."""
164171
)
172+
173+
174+
@pytest.mark.asyncio
175+
async def test_nested_templating():
176+
request = LlmRequest(
177+
model="gemini-1.5-flash",
178+
config=types.GenerateContentConfig(system_instruction=""),
179+
)
180+
agent = Agent(
181+
model="gemini-1.5-flash",
182+
name="agent",
183+
instruction=(
184+
"""Echo the following: {user.profile.name}, {user.profile.email}, {user.settings?.preferences.theme}, {user.preferences.value}"""
185+
),
186+
)
187+
invocation_context = utils.create_invocation_context(agent=agent)
188+
invocation_context.session = Session(
189+
app_name="test_app",
190+
user_id="test_user",
191+
id="test_id",
192+
state={
193+
"user": {
194+
"profile": {"name": "Test User", "email": "[email protected]"}
195+
}
196+
},
197+
)
198+
199+
async for _ in instructions.request_processor.run_async(
200+
invocation_context,
201+
request,
202+
):
203+
pass
204+
205+
assert request.config.system_instruction == (
206+
"""Echo the following: Test User, [email protected], , {user.preferences.value}"""
207+
)

0 commit comments

Comments
 (0)