Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
from ._anthropic_client import (
AnthropicBedrockChatCompletionClient,
AnthropicChatCompletionClient,
AnthropicVertexChatCompletionClient,
BaseAnthropicChatCompletionClient,
)
from .config import (
AnthropicBedrockClientConfiguration,
AnthropicBedrockClientConfigurationConfigModel,
AnthropicClientConfiguration,
AnthropicClientConfigurationConfigModel,
AnthropicVertexClientConfiguration,
AnthropicVertexClientConfigurationConfigModel,
BedrockInfo,
CreateArgumentsConfigModel,
VertexInfo,
)

__all__ = [
"AnthropicChatCompletionClient",
"AnthropicBedrockChatCompletionClient",
"AnthropicVertexChatCompletionClient",
"BaseAnthropicChatCompletionClient",
"AnthropicClientConfiguration",
"AnthropicBedrockClientConfiguration",
"AnthropicVertexClientConfiguration",
"AnthropicClientConfigurationConfigModel",
"AnthropicBedrockClientConfigurationConfigModel",
"AnthropicVertexClientConfigurationConfigModel",
"CreateArgumentsConfigModel",
"BedrockInfo",
"VertexInfo",
]
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)

import tiktoken
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock, AsyncStream
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock, AsyncAnthropicVertex, AsyncStream
from anthropic.types import (
Base64ImageSourceParam,
ContentBlock,
Expand Down Expand Up @@ -71,7 +71,10 @@
AnthropicBedrockClientConfigurationConfigModel,
AnthropicClientConfiguration,
AnthropicClientConfigurationConfigModel,
AnthropicVertexClientConfiguration,
AnthropicVertexClientConfigurationConfigModel,
BedrockInfo,
VertexInfo,
)

logger = logging.getLogger(EVENT_LOGGER_NAME)
Expand Down Expand Up @@ -1414,3 +1417,130 @@ def _from_config(cls, config: AnthropicBedrockClientConfigurationConfigModel) ->
}

return cls(**copied_config)


class AnthropicVertexChatCompletionClient(
BaseAnthropicChatCompletionClient, Component[AnthropicVertexClientConfigurationConfigModel]
):
"""
Chat completion client for Anthropic's Claude models via Google Vertex AI.

Args:
model (str): The Claude model to use (e.g., "claude-3-sonnet-20240229", "claude-3-opus-20240229")
vertex_info (VertexInfo): Configuration for Vertex AI including project_id and region
max_tokens (int, optional): Maximum tokens in the response. Default is 4096.
temperature (float, optional): Controls randomness. Lower is more deterministic. Default is 1.0.
top_p (float, optional): Controls diversity via nucleus sampling. Default is 1.0.
top_k (int, optional): Controls diversity via top-k sampling. Default is -1 (disabled).
model_info (ModelInfo, optional): The capabilities of the model. Required if using a custom model.

To use this client, you must install the Anthropic extension:

.. code-block:: bash

pip install "autogen-ext[anthropic]"

Example:

.. code-block:: python

import asyncio
from autogen_ext.models.anthropic import AnthropicVertexChatCompletionClient, VertexInfo
from autogen_core.models import UserMessage


async def main():
vertex_client = AnthropicVertexChatCompletionClient(
model="claude-3-sonnet-20240229",
vertex_info=VertexInfo(project_id="your-gcp-project-id", region="us-east5"),
)

result = await vertex_client.create([UserMessage(content="What is the capital of France?", source="user")]) # type: ignore
print(result)


if __name__ == "__main__":
asyncio.run(main())

To load the client from a configuration:

.. code-block:: python

from autogen_core.models import ChatCompletionClient

config = {
"provider": "AnthropicVertexChatCompletionClient",
"config": {
"model": "claude-3-sonnet-20240229",
"vertex_info": {"project_id": "your-gcp-project-id", "region": "us-east5"},
},
}

client = ChatCompletionClient.load_component(config)
"""

component_type = "model"
component_config_schema = AnthropicVertexClientConfigurationConfigModel
component_provider_override = "autogen_ext.models.anthropic.AnthropicVertexChatCompletionClient"

def __init__(self, **kwargs: Unpack[AnthropicVertexClientConfiguration]):
if "model" not in kwargs:
raise ValueError("model is required for AnthropicVertexChatCompletionClient")

self._raw_config: Dict[str, Any] = dict(kwargs).copy()
copied_args = dict(kwargs).copy()

model_info: Optional[ModelInfo] = None
if "model_info" in kwargs:
model_info = kwargs["model_info"]
del copied_args["model_info"]

vertex_info: Optional[VertexInfo] = None
if "vertex_info" in kwargs:
vertex_info = kwargs["vertex_info"]

if vertex_info is None:
raise ValueError("vertex_info is required for AnthropicVertexChatCompletionClient")

# Handle vertex_info
project_id = vertex_info["project_id"]
region = vertex_info["region"]

client = AsyncAnthropicVertex(
project_id=project_id,
region=region,
)
create_args = _create_args_from_config(copied_args)

super().__init__(
client=client,
create_args=create_args,
model_info=model_info,
)

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state["_client"] = None
return state

def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
# Recreate the client from raw config
vertex_info = state["_raw_config"]["vertex_info"]
self._client = AsyncAnthropicVertex(
project_id=vertex_info["project_id"],
region=vertex_info["region"],
)

def _to_config(self) -> AnthropicVertexClientConfigurationConfigModel:
copied_config = self._raw_config.copy()
return AnthropicVertexClientConfigurationConfigModel(**copied_config)

@classmethod
def _from_config(cls, config: AnthropicVertexClientConfigurationConfigModel) -> Self:
copied_config = config.model_copy().model_dump(exclude_none=True)

# Handle vertex_info properly - no secret values to extract like bedrock
# vertex_info contains project_id and region which are not secret

return cls(**copied_config)
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ class BedrockInfo(TypedDict):
"""aws region for the aws account to gain bedrock model access"""


class VertexInfo(TypedDict):
"""VertexInfo is a dictionary that contains information about Vertex AI configuration.
It is expected to be used in the vertex_info property of a model client.

"""

project_id: Required[str]
"""GCP project ID for Vertex AI access"""
region: Required[str]
"""GCP region for Vertex AI access"""


class BaseAnthropicClientConfiguration(CreateArguments, total=False):
api_key: str
base_url: Optional[str]
Expand All @@ -64,6 +76,10 @@ class AnthropicBedrockClientConfiguration(AnthropicClientConfiguration, total=Fa
bedrock_info: BedrockInfo


class AnthropicVertexClientConfiguration(AnthropicClientConfiguration, total=False):
vertex_info: VertexInfo


# Pydantic equivalents of the above TypedDicts
class ThinkingConfigModel(BaseModel):
"""Configuration for thinking mode."""
Expand Down Expand Up @@ -111,3 +127,14 @@ class BedrockInfoConfigModel(TypedDict):

class AnthropicBedrockClientConfigurationConfigModel(AnthropicClientConfigurationConfigModel):
bedrock_info: BedrockInfoConfigModel | None = None


class VertexInfoConfigModel(TypedDict):
project_id: Required[str]
"""GCP project ID for Vertex AI access"""
region: Required[str]
"""GCP region for Vertex AI access"""


class AnthropicVertexClientConfigurationConfigModel(AnthropicClientConfigurationConfigModel):
vertex_info: VertexInfoConfigModel | None = None
Loading
Loading