Skip to content

Commit

Permalink
Python: Updates to doc gen demo after introducing KernelArgs to Chat …
Browse files Browse the repository at this point in the history
…Completion Agent. Typing updates. (#10540)

### Motivation and Context

The latest changes to ChatCompletionAgent, where one now provides
settings via KernelArguments, didn't make it to this demo. Additionally,
there are some typing updates we can include to satisfy Pylance.

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description

Updates the demo to correctly pass settings via KernelArguments with the
ChatCompletionAgent. Updates typing and some code handling to resolve
Pylance errors.

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
moonbox3 authored Feb 14, 2025
1 parent cd84e87 commit d350b83
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import sys
from collections.abc import AsyncIterable
from typing import TYPE_CHECKING, Any

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
Expand All @@ -13,6 +14,10 @@
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.functions.kernel_arguments import KernelArguments

if TYPE_CHECKING:
from semantic_kernel.kernel import Kernel

INSTRUCTION = """
You are a code validation agent in a collaborative document creation chat.
Expand All @@ -39,19 +44,26 @@ def __init__(self):
settings.function_choice_behavior = FunctionChoiceBehavior.Auto(maximum_auto_invoke_attempts=1)

super().__init__(
service_id=CustomAgentBase.SERVICE_ID,
kernel=kernel,
execution_settings=settings,
arguments=KernelArguments(settings=settings),
name="CodeValidationAgent",
instructions=INSTRUCTION.strip(),
description=DESCRIPTION.strip(),
)

@override
async def invoke(self, history: ChatHistory) -> AsyncIterable[ChatMessageContent]:
async def invoke(
self,
history: ChatHistory,
arguments: KernelArguments | None = None,
kernel: "Kernel | None" = None,
**kwargs: Any,
) -> AsyncIterable[ChatMessageContent]:
cloned_history = history.model_copy(deep=True)
cloned_history.add_user_message(
"Now validate the Python code in the latest document draft and summarize any errors."
)

async for response_message in super().invoke(cloned_history):
async for response_message in super().invoke(cloned_history, arguments=arguments, kernel=kernel, **kwargs):
yield response_message
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import sys
from collections.abc import AsyncIterable
from typing import TYPE_CHECKING, Any

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
Expand All @@ -13,6 +14,10 @@
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.functions.kernel_arguments import KernelArguments

if TYPE_CHECKING:
from semantic_kernel.kernel import Kernel

INSTRUCTION = """
You are part of a chat with multiple agents focused on creating technical content.
Expand All @@ -36,17 +41,24 @@ def __init__(self):
settings.function_choice_behavior = FunctionChoiceBehavior.Auto()

super().__init__(
service_id=CustomAgentBase.SERVICE_ID,
kernel=kernel,
execution_settings=settings,
arguments=KernelArguments(settings=settings),
name="ContentCreationAgent",
instructions=INSTRUCTION.strip(),
description=DESCRIPTION.strip(),
)

@override
async def invoke(self, history: ChatHistory) -> AsyncIterable[ChatMessageContent]:
async def invoke(
self,
history: ChatHistory,
arguments: KernelArguments | None = None,
kernel: "Kernel | None" = None,
**kwargs: Any,
) -> AsyncIterable[ChatMessageContent]:
cloned_history = history.model_copy(deep=True)
cloned_history.add_user_message("Now generate new content or revise existing content to incorporate feedback.")

async for response_message in super().invoke(cloned_history):
async for response_message in super().invoke(cloned_history, arguments=arguments, kernel=kernel, **kwargs):
yield response_message
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from abc import ABC
from collections.abc import AsyncIterable
from typing import ClassVar
from typing import Any, ClassVar

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
Expand All @@ -14,6 +14,7 @@
from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletion
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.kernel import Kernel


Expand All @@ -27,7 +28,13 @@ def _create_kernel(self) -> Kernel:
return kernel

@override
async def invoke(self, history: ChatHistory) -> AsyncIterable[ChatMessageContent]:
async def invoke(
self,
history: ChatHistory,
arguments: KernelArguments | None = None,
kernel: "Kernel | None" = None,
**kwargs: Any,
) -> AsyncIterable[ChatMessageContent]:
# Since the history contains internal messages from other agents,
# we will do our best to filter out those. Unfortunately, there will
# be a side effect of losing the context of the conversation internal
Expand All @@ -41,5 +48,5 @@ async def invoke(self, history: ChatHistory) -> AsyncIterable[ChatMessageContent
if content:
filtered_chat_history.add_message(message)

async for response in super().invoke(filtered_chat_history):
async for response in super().invoke(filtered_chat_history, arguments=arguments, kernel=kernel, **kwargs):
yield response
18 changes: 15 additions & 3 deletions python/samples/demos/document_generator/agents/user_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import sys
from collections.abc import AsyncIterable
from typing import TYPE_CHECKING, Any

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
Expand All @@ -13,6 +14,10 @@
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.functions.kernel_arguments import KernelArguments

if TYPE_CHECKING:
from semantic_kernel.kernel import Kernel

INSTRUCTION = """
You are part of a chat with multiple agents working on a document.
Expand All @@ -37,19 +42,26 @@ def __init__(self):
settings.function_choice_behavior = FunctionChoiceBehavior.Auto(maximum_auto_invoke_attempts=1)

super().__init__(
service_id=CustomAgentBase.SERVICE_ID,
kernel=kernel,
execution_settings=settings,
arguments=KernelArguments(settings=settings),
name="UserAgent",
instructions=INSTRUCTION.strip(),
description=DESCRIPTION.strip(),
)

@override
async def invoke(self, history: ChatHistory) -> AsyncIterable[ChatMessageContent]:
async def invoke(
self,
history: ChatHistory,
arguments: KernelArguments | None = None,
kernel: "Kernel | None" = None,
**kwargs: Any,
) -> AsyncIterable[ChatMessageContent]:
cloned_history = history.model_copy(deep=True)
cloned_history.add_user_message(
"Now present the latest draft to the user for feedback and summarize their feedback."
)

async for response_message in super().invoke(cloned_history):
async for response_message in super().invoke(cloned_history, arguments=arguments, kernel=kernel, **kwargs):
yield response_message
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, ClassVar

from opentelemetry import trace
from pydantic import Field

from semantic_kernel.agents.strategies.selection.selection_strategy import SelectionStrategy
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
Expand All @@ -26,12 +27,7 @@ class CustomSelectionStrategy(SelectionStrategy):

NUM_OF_RETRIES: ClassVar[int] = 3

chat_completion_service: ChatCompletionClientBase

def __init__(self, **kwargs):
chat_completion_service = OpenAIChatCompletion()

super().__init__(chat_completion_service=chat_completion_service, **kwargs)
chat_completion_service: ChatCompletionClientBase = Field(default_factory=lambda: OpenAIChatCompletion())

async def next(self, agents: list["Agent"], history: list["ChatMessageContent"]) -> "Agent":
"""Select the next agent to interact with.
Expand Down Expand Up @@ -65,6 +61,9 @@ async def next(self, agents: list["Agent"], history: list["ChatMessageContent"])
AzureChatPromptExecutionSettings(),
)

if completion is None:
continue

try:
return agents[int(completion.content)]
except ValueError as ex:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, ClassVar

from opentelemetry import trace
from pydantic import Field

from semantic_kernel.agents.strategies.termination.termination_strategy import TerminationStrategy
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
Expand All @@ -27,11 +28,7 @@ class CustomTerminationStrategy(TerminationStrategy):
NUM_OF_RETRIES: ClassVar[int] = 3

maximum_iterations: int = 20
chat_completion_service: ChatCompletionClientBase

def __init__(self, **kwargs):
chat_completion_service = OpenAIChatCompletion()
super().__init__(chat_completion_service=chat_completion_service, **kwargs)
chat_completion_service: ChatCompletionClientBase = Field(default_factory=lambda: OpenAIChatCompletion())

async def should_agent_terminate(self, agent: "Agent", history: list["ChatMessageContent"]) -> bool:
"""Check if the agent should terminate.
Expand Down Expand Up @@ -62,6 +59,9 @@ async def should_agent_terminate(self, agent: "Agent", history: list["ChatMessag
AzureChatPromptExecutionSettings(),
)

if not completion:
continue

if TERMINATE_FALSE_KEYWORD in completion.content.lower():
return False
if TERMINATE_TRUE_KEYWORD in completion.content.lower():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ def read_file_by_name(
@kernel_function(description="List all files or subdirectories in a directory.")
def list_directory(
self, path: Annotated[str, "Path of a directory relative to the root of the repository."]
) -> Annotated[str, "Returns a list of files and subdirectories."]:
) -> Annotated[str, "Returns a list of files and subdirectories as a string."]:
path = os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", path)
try:
return os.listdir(path)
files = os.listdir(path)
# Join the list of files into a single string
return "\n".join(files)
except FileNotFoundError:
raise FileNotFoundError(f"Directory {path} not found in repository.")

0 comments on commit d350b83

Please sign in to comment.