72 changes: 43 additions & 29 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -119,6 +119,7 @@
Tool as GapicTool,
ToolConfig as GapicToolConfig,
VideoMetadata,
Schema,
)
from langchain_google_vertexai._base import _VertexAICommon
from langchain_google_vertexai._image_utils import (
@@ -1570,9 +1571,11 @@ def _prepare_params(
f"`response_mime_type` is set to one of {allowed_mime_types}"
)
raise ValueError(error_message)

gapic_response_schema = _convert_schema_dict_to_gapic(response_schema)
params["response_schema"] = gapic_response_schema
if isinstance(response_schema, Schema):
params["response_schema"] = response_schema
else:
gapic_response_schema = _convert_schema_dict_to_gapic(response_schema)
params["response_schema"] = gapic_response_schema

audio_timestamp = kwargs.get("audio_timestamp", self.audio_timestamp)
if audio_timestamp is not None:
@@ -2088,7 +2091,7 @@ async def _astream(

def with_structured_output(
self,
schema: Union[Dict, Type[BaseModel], Type],
schema: Union[Dict, Type[BaseModel], Type, Schema],
*,
include_raw: bool = False,
method: Optional[Literal["json_mode"]] = None,
@@ -2223,36 +2226,47 @@ class Explanation(BaseModel):
parser: OutputParserLike

if method == "json_mode":
if isinstance(schema, type) and is_basemodel_subclass(schema):
if issubclass(schema, BaseModelV1):
schema_json = schema.schema()
else:
schema_json = schema.model_json_schema()
parser = PydanticOutputParser(pydantic_object=schema)
if isinstance(schema, Schema):
llm = self.bind(
response_mime_type="application/json",
response_schema=schema,
ls_structured_output_format={
"kwargs": {"method": method},
"schema": schema,
},
)
parser = JsonOutputParser()
else:
if is_typeddict(schema):
schema_json = convert_to_json_schema(schema)
elif isinstance(schema, dict):
schema_json = schema
if isinstance(schema, type) and is_basemodel_subclass(schema):
if issubclass(schema, BaseModelV1):
schema_json = schema.schema()
else:
schema_json = schema.model_json_schema()
parser = PydanticOutputParser(pydantic_object=schema)
else:
raise ValueError(f"Unsupported schema type {type(schema)}")
parser = JsonOutputParser()
if is_typeddict(schema):
schema_json = convert_to_json_schema(schema)
elif isinstance(schema, dict):
schema_json = schema
else:
raise ValueError(f"Unsupported schema type {type(schema)}")
parser = JsonOutputParser()

# Resolve refs in schema because they are not supported
# by the Gemini API.
schema_json = replace_defs_in_schema(schema_json)
# Resolve refs in schema because they are not supported
# by the Gemini API.
schema_json = replace_defs_in_schema(schema_json)

# API does not support anyOf.
schema_json = _strip_nullable_anyof(schema_json)
# API does not support anyOf.
schema_json = _strip_nullable_anyof(schema_json)

llm = self.bind(
response_mime_type="application/json",
response_schema=schema_json,
ls_structured_output_format={
"kwargs": {"method": method},
"schema": schema_json,
},
)
llm = self.bind(
response_mime_type="application/json",
response_schema=schema_json,
ls_structured_output_format={
"kwargs": {"method": method},
"schema": schema_json,
},
)
else:
tool_name = _get_tool_name(schema)
if isinstance(schema, type) and is_basemodel_subclass(schema):
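With the `Schema` branch added above, a GAPIC `Schema` object can be handed directly to `with_structured_output` under `method="json_mode"`: it is bound as `response_schema` unchanged and the raw JSON output is parsed with `JsonOutputParser`. A minimal usage sketch under that assumption; the model name and prompt are illustrative, not taken from this diff:

```python
from google.cloud.aiplatform_v1beta1.types import Schema, Type
from langchain_google_vertexai import ChatVertexAI

# GAPIC Schema describing a single recipe object.
recipe_schema = Schema(
    type_=Type.OBJECT,
    properties={
        "recipe_name": Schema(type_=Type.STRING),
        "ingredients": Schema(type_=Type.ARRAY, items=Schema(type_=Type.STRING)),
    },
    required=["recipe_name", "ingredients"],
)

llm = ChatVertexAI(model_name="gemini-2.0-flash-001")
# The Schema path shown in the diff sits under method="json_mode"; the default
# tool-calling path does not take a GAPIC Schema.
structured_llm = llm.with_structured_output(recipe_schema, method="json_mode")
result = structured_llm.invoke("Suggest one cookie recipe")
# JsonOutputParser returns plain Python objects, so `result` should be a dict.
```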
5 changes: 3 additions & 2 deletions libs/vertexai/langchain_google_vertexai/llms.py
@@ -2,8 +2,9 @@

import logging
from difflib import get_close_matches
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

from google.cloud.aiplatform_v1beta1.types import Schema
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -38,7 +39,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
The model also needs to be prompted to output the appropriate response
type, otherwise the behavior is undefined. This is a preview feature.
"""
response_schema: Optional[Dict[str, Any]] = None
response_schema: Optional[Union[Dict[str, Any], Schema]] = None
""" Optional. Enforce an schema to the output.
The format of the dictionary should follow Open API schema.
"""
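For comparison, `response_schema` already accepted a dict; the widened Union only adds the GAPIC `Schema` alternative exercised by the new integration test below. A rough dict equivalent of that test's schema, assuming the internal converter accepts GAPIC-style uppercase type names (the exact dict shape is an assumption, not taken from this diff):

```python
from langchain_google_vertexai import VertexAI

# Dict-based response_schema, converted internally to a GAPIC Schema;
# the type-name casing used here is assumed.
model = VertexAI(
    model_name="gemini-2.0-flash-001",
    response_mime_type="application/json",
    response_schema={
        "type": "ARRAY",
        "items": {
            "type": "OBJECT",
            "properties": {
                "recipe_name": {"type": "STRING"},
                "level": {"type": "STRING", "enum": ["easy", "medium", "hard"]},
            },
            "required": ["recipe_name", "level"],
        },
    },
)

print(model.invoke("List a few popular cookie recipes"))
```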
39 changes: 39 additions & 0 deletions libs/vertexai/tests/integration_tests/test_llms.py
@@ -7,6 +7,7 @@
import json

import pytest
from google.cloud.aiplatform_v1beta1.types import Schema, Type
from langchain_core.outputs import LLMResult
from langchain_core.rate_limiters import InMemoryRateLimiter

@@ -119,3 +120,41 @@ def test_structured_output_schema_json():
assert isinstance(parsed_response, list)
assert len(parsed_response) > 0
assert "recipe_name" in parsed_response[0]


@pytest.mark.extended
def test_structured_output_schema_json_with_openapi_schema_object():
model = VertexAI(
rate_limiter=rate_limiter,
model_name="gemini-2.0-flash-001",
response_mime_type="application/json",
response_schema=Schema(
type_=Type.ARRAY,
items=Schema(
type_=Type.OBJECT,
properties={
"recipe_name": Schema(type_=Type.STRING),
"level": Schema(type_=Type.ENUM, values=["easy", "medium", "hard"]),
},
required=["recipe_name", "level"],
property_ordering=["level", "recipe_name"],
),
min_items=3,
max_items=4,
),
)

response = model.invoke("List a few popular cookie recipes")

assert isinstance(response, str)
parsed_response = json.loads(response)
assert isinstance(parsed_response, list)
assert len(parsed_response) >= 3 and len(parsed_response) <= 4
for recipe in parsed_response:
assert isinstance(recipe, dict)
assert "recipe_name" in recipe
assert "level" in recipe
assert isinstance(recipe["recipe_name"], str)
assert recipe["level"] in ["easy", "medium", "hard"]
keys = list(recipe.keys())
assert keys.index("level") < keys.index("recipe_name")