Skip to content

Commit

Permalink
remove base64 images from model api logging (#313)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjallaire authored Sep 3, 2024
1 parent 7100c42 commit 8630ca9
Show file tree
Hide file tree
Showing 13 changed files with 131 additions and 41 deletions.
3 changes: 2 additions & 1 deletion src/inspect_ai/model/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
ChatMessageUser,
)
from ._generate_config import GenerateConfig
from ._model_output import ModelCall, ModelOutput, ModelUsage
from ._model_call import ModelCall
from ._model_output import ModelOutput, ModelUsage

logger = logging.getLogger(__name__)

Expand Down
58 changes: 58 additions & 0 deletions src/inspect_ai/model/_model_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Any, Callable

from pydantic import BaseModel, JsonValue

from inspect_ai._util.json import jsonable_python

ModelCallFilter = Callable[[JsonValue | None, JsonValue], JsonValue]
"""Filter for transforming or removing some values (e.g. images).
The first parameter is the key if the value is a dictionary item.
The second parameter is the value. Return a modified value if appropriate.
"""


class ModelCall(BaseModel):
    """Model call (raw request/response data)."""

    request: dict[str, JsonValue]
    """Raw data posted to model."""

    response: dict[str, JsonValue]
    """Raw response data from model."""

    @staticmethod
    def create(
        request: Any, response: Any, filter: ModelCallFilter | None = None
    ) -> "ModelCall":
        """Create a ModelCall object.

        Create a ModelCall from arbitrary request and response objects (they might
        be dataclasses, Pydantic objects, dicts, etc.). Converts all values to
        JSON serializable (excluding those that can't be).

        Args:
            request (Any): Request object (dict, dataclass, BaseModel, etc.)
            response (Any): Response object (dict, dataclass, BaseModel, etc.)
            filter (ModelCallFilter): Optional function for filtering model call
                data (applied recursively to every key/value in both payloads).
        """
        request_dict = jsonable_python(request)
        if filter:
            request_dict = _walk_json_value(None, request_dict, filter)
        response_dict = jsonable_python(response)
        if filter:
            response_dict = _walk_json_value(None, response_dict, filter)
        return ModelCall(request=request_dict, response=response_dict)


def _walk_json_value(
    key: JsonValue | None, value: JsonValue, filter: ModelCallFilter
) -> JsonValue:
    """Recursively apply `filter` to a JSON value and all of its children.

    The filter runs before recursion, so it may replace an entire container
    (the replacement is then walked in turn via its list items / dict values).
    """
    filtered = filter(key, value)
    if isinstance(filtered, list):
        # list items have no key of their own
        return [_walk_json_value(None, item, filter) for item in filtered]
    if isinstance(filtered, dict):
        return {
            item_key: _walk_json_value(item_key, item, filter)
            for item_key, item in filtered.items()
        }
    return filtered
29 changes: 1 addition & 28 deletions src/inspect_ai/model/_model_output.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import uuid
from typing import Any, Literal

from pydantic import BaseModel, Field, JsonValue
from pydantic import BaseModel, Field

from inspect_ai._util.json import jsonable_python
from inspect_ai.tool._tool_call import ToolCall

from ._chat_message import ChatMessageAssistant
Expand Down Expand Up @@ -173,29 +172,3 @@ def for_tool_call(
)
],
)


class ModelCall(BaseModel):
    """Model call (raw request/response data)."""

    request: dict[str, JsonValue]
    """Raw data posted to model."""

    response: dict[str, JsonValue]
    """Raw response data from model."""

    @staticmethod
    def create(request: Any, response: Any) -> "ModelCall":
        """Create a ModelCall object.

        Create a ModelCall from arbitrary request and response objects (they might
        be dataclasses, Pydantic objects, dicts, etc.). Converts all values to
        JSON serializable (excluding those that can't be).

        Args:
            request (Any): Request object (dict, dataclass, BaseModel, etc.)
            response (Any): Response object (dict, dataclass, BaseModel, etc.)
        """
        return ModelCall(
            request=jsonable_python(request), response=jsonable_python(response)
        )
22 changes: 20 additions & 2 deletions src/inspect_ai/model/_providers/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import os
from copy import copy
from logging import getLogger
from typing import Any, Literal, Tuple, cast

Expand All @@ -23,6 +24,7 @@
ToolUseBlockParam,
message_create_params,
)
from pydantic import JsonValue
from typing_extensions import override

from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
Expand All @@ -40,9 +42,9 @@
)
from .._generate_config import GenerateConfig
from .._model import ModelAPI
from .._model_call import ModelCall
from .._model_output import (
ChatCompletionChoice,
ModelCall,
ModelOutput,
ModelUsage,
StopReason,
Expand Down Expand Up @@ -150,7 +152,11 @@ async def generate(
output = model_output_from_message(message, tools)

# return output and call
call = ModelCall.create(request=request, response=message.model_dump())
call = ModelCall.create(
request=request,
response=message.model_dump(),
filter=model_call_filter,
)

return output, call

Expand Down Expand Up @@ -561,3 +567,15 @@ async def message_param_content(
type="image",
source=dict(type="base64", media_type=cast(Any, media_type), data=image),
)


def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
    """Redact base64-encoded image payloads from logged anthropic api calls."""
    is_base64_source = (
        key == "source"
        and isinstance(value, dict)
        and value.get("type", None) == "base64"
    )
    if not is_base64_source:
        return value
    # shallow-copy so the original request/response structure is untouched
    redacted = copy(value)
    redacted.update(data="")
    return redacted
2 changes: 1 addition & 1 deletion src/inspect_ai/model/_providers/azureai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from .._chat_message import ChatMessage
from .._generate_config import GenerateConfig
from .._model import ModelAPI
from .._model_call import ModelCall
from .._model_output import (
ChatCompletionChoice,
ModelCall,
ModelOutput,
ModelUsage,
StopReason,
Expand Down
3 changes: 2 additions & 1 deletion src/inspect_ai/model/_providers/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
)
from .._generate_config import GenerateConfig
from .._model import ModelAPI
from .._model_output import ChatCompletionChoice, ModelCall, ModelOutput, ModelUsage
from .._model_call import ModelCall
from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
from .util import (
ChatAPIHandler,
ChatAPIMessage,
Expand Down
3 changes: 2 additions & 1 deletion src/inspect_ai/model/_providers/cloudflare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from inspect_ai.tool import ToolChoice, ToolInfo

from ...model import ChatMessage, GenerateConfig, ModelAPI, ModelOutput
from .._model_output import ChatCompletionChoice, ModelCall
from .._model_call import ModelCall
from .._model_output import ChatCompletionChoice
from .util import (
ChatAPIHandler,
Llama31Handler,
Expand Down
12 changes: 11 additions & 1 deletion src/inspect_ai/model/_providers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
from google.protobuf.json_format import MessageToDict, ParseDict
from google.protobuf.struct_pb2 import Struct
from pydantic import JsonValue
from typing_extensions import override

from inspect_ai._util.content import Content, ContentImage, ContentText
Expand All @@ -49,9 +50,9 @@
)
from .._generate_config import GenerateConfig
from .._model import ModelAPI
from .._model_call import ModelCall
from .._model_output import (
ChatCompletionChoice,
ModelCall,
ModelOutput,
ModelUsage,
StopReason,
Expand Down Expand Up @@ -196,9 +197,18 @@ def model_call(
else None,
),
response=response.to_dict(),
filter=model_call_filter,
)


def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
    """Redact base64 image data from logged google api calls.

    Args:
        key (JsonValue | None): Dictionary key for `value` (None for list items).
        value (JsonValue): Value to (possibly) transform.
    """
    # remove images from raw api call
    if key == "inline_data" and isinstance(value, dict) and "data" in value:
        # shallow-copy with dict() rather than copy.copy(): this module does
        # not import `copy` (unlike the anthropic/openai providers), so
        # copy(value) would raise NameError the first time an image is logged
        value = dict(value)
        value.update(data="")
    return value


def model_call_content(content: ContentDict) -> ContentDict:
return ContentDict(
role=content["role"], parts=[model_call_part(part) for part in content["parts"]]
Expand Down
3 changes: 2 additions & 1 deletion src/inspect_ai/model/_providers/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
)
from .._generate_config import GenerateConfig
from .._model import ModelAPI
from .._model_output import ChatCompletionChoice, ModelCall, ModelOutput, ModelUsage
from .._model_call import ModelCall
from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
from .util import as_stop_reason, model_base_url, parse_tool_call

GROQ_API_KEY = "GROQ_API_KEY"
Expand Down
2 changes: 1 addition & 1 deletion src/inspect_ai/model/_providers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
from .._chat_message import ChatMessage, ChatMessageAssistant
from .._generate_config import GenerateConfig
from .._model import ModelAPI
from .._model_call import ModelCall
from .._model_output import (
ChatCompletionChoice,
ModelCall,
ModelOutput,
ModelUsage,
StopReason,
Expand Down
20 changes: 18 additions & 2 deletions src/inspect_ai/model/_providers/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
from copy import copy
from typing import Any, cast

from openai import (
Expand Down Expand Up @@ -29,6 +30,7 @@
ChatCompletionUserMessageParam,
)
from openai.types.shared_params.function_definition import FunctionDefinition
from pydantic import JsonValue
from typing_extensions import override

from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
Expand All @@ -40,10 +42,10 @@
from .._chat_message import ChatMessage, ChatMessageAssistant
from .._generate_config import GenerateConfig
from .._model import ModelAPI
from .._model_call import ModelCall
from .._model_output import (
ChatCompletionChoice,
Logprobs,
ModelCall,
ModelOutput,
ModelUsage,
)
Expand Down Expand Up @@ -175,7 +177,11 @@ async def generate(
if response.usage
else None
),
), ModelCall.create(request=request, response=response.model_dump())
), ModelCall.create(
request=request,
response=response.model_dump(),
filter=model_call_filter,
)
except APIStatusError as e:
completion, error = handle_content_filter_error(e)
return ModelOutput.from_content(
Expand Down Expand Up @@ -413,3 +419,13 @@ def handle_content_filter_error(e: APIStatusError) -> tuple[str, object | None]:
return CANT_ASSIST, e.body
else:
raise e


def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
    """Redact base64 data-URI images from logged openai api calls."""
    # remove images from raw api call
    if key == "image_url" and isinstance(value, dict) and "url" in value:
        if str(value.get("url")).startswith("data:"):
            # shallow-copy so the caller's structure is untouched
            redacted = copy(value)
            redacted.update(url="")
            return redacted
    return value
12 changes: 11 additions & 1 deletion src/inspect_ai/model/_providers/vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import vertexai # type: ignore
from google.api_core.exceptions import TooManyRequests
from google.protobuf.json_format import MessageToDict
from pydantic import JsonValue
from typing_extensions import override
from vertexai.generative_models import ( # type: ignore
Candidate,
Expand Down Expand Up @@ -34,9 +35,9 @@
)
from .._generate_config import GenerateConfig
from .._model import ModelAPI
from .._model_call import ModelCall
from .._model_output import (
ChatCompletionChoice,
ModelCall,
ModelOutput,
ModelUsage,
StopReason,
Expand Down Expand Up @@ -179,9 +180,18 @@ def model_call(
tools=[tool.to_dict() for tool in tools] if tools is not None else None,
),
response=response.to_dict(),
filter=model_call_filter,
)


def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
    """Redact base64 image data from logged vertex api calls.

    Args:
        key (JsonValue | None): Dictionary key for `value` (None for list items).
        value (JsonValue): Value to (possibly) transform.
    """
    # remove images from raw api call
    if key == "inline_data" and isinstance(value, dict) and "data" in value:
        # shallow-copy with dict() rather than copy.copy(): this module does
        # not import `copy` (unlike the anthropic/openai providers), so
        # copy(value) would raise NameError the first time an image is logged
        value = dict(value)
        value.update(data="")
    return value


def model_call_content(content: VertexContent) -> dict[str, Any]:
return cast(dict[str, Any], content.to_dict())

Expand Down
3 changes: 2 additions & 1 deletion src/inspect_ai/solver/_subtask/transcript.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from inspect_ai.log._message import LoggingMessage
from inspect_ai.model._chat_message import ChatMessage
from inspect_ai.model._generate_config import GenerateConfig
from inspect_ai.model._model_output import ModelCall, ModelOutput
from inspect_ai.model._model_call import ModelCall
from inspect_ai.model._model_output import ModelOutput
from inspect_ai.scorer._metric import Score
from inspect_ai.tool._tool import ToolResult
from inspect_ai.tool._tool_call import ToolCallError
Expand Down

0 comments on commit 8630ca9

Please sign in to comment.