diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py index 145cab313..c409a9633 100644 --- a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py +++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py @@ -1,7 +1,31 @@ +import inspect +import json + from langchain_core.callbacks import BaseCallbackHandler from langchain_core.outputs import LLMResult +def requires_no_arguments(func): + sig = inspect.signature(func) + for param in sig.parameters.values(): + if param.default is param.empty and param.kind in ( + param.POSITIONAL_ONLY, + param.POSITIONAL_OR_KEYWORD, + param.KEYWORD_ONLY, + ): + return False + return True + + +def convert_to_serializable(obj): + """Convert an object to a JSON serializable format""" + if hasattr(obj, "dict") and callable(obj.dict) and requires_no_arguments(obj.dict): + return obj.dict() + if hasattr(obj, "__dict__"): + return obj.__dict__ + return str(obj) + + class MetadataCallbackHandler(BaseCallbackHandler): """ When passed as a callback handler, this stores the LLMResult's @@ -23,4 +47,9 @@ def on_llm_end(self, response: LLMResult, **kwargs) -> None: if not (len(response.generations) and len(response.generations[0])): return - self.jai_metadata = response.generations[0][0].generation_info or {} + metadata = response.generations[0][0].generation_info or {} + + # Convert any non-serializable objects in metadata + self.jai_metadata = json.loads( + json.dumps(metadata, default=convert_to_serializable) + ) diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 48dbe6193..6bd7d4e06 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -1,3 +1,4 @@ +import json from typing import Any, Dict, List, Literal, Optional, Union from jupyter_ai_magics import Persona @@ -128,6 +129,15 @@ class AgentStreamChunkMessage(BaseModel): on `BaseAgentMessage.metadata` for information. """ + @validator("metadata") + def validate_metadata(cls, v): + """Ensure metadata values are JSON serializable""" + try: + json.dumps(v) + return v + except TypeError as e: + raise ValueError(f"Metadata must be JSON serializable: {str(e)}") + class HumanChatMessage(BaseModel): type: Literal["human"] = "human"