diff --git a/backend/chainlit/data/dynamodb.py b/backend/chainlit/data/dynamodb.py index b3524b0734..f43ee3178e 100644 --- a/backend/chainlit/data/dynamodb.py +++ b/backend/chainlit/data/dynamodb.py @@ -3,6 +3,7 @@ import logging import os import random +from decimal import Decimal from dataclasses import asdict from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -60,17 +61,49 @@ def __init__( def _get_current_timestamp(self) -> str: return datetime.now().isoformat() + "Z" + def _convert_floats_to_decimal(self, obj): + if isinstance(obj, list): + return [self._convert_floats_to_decimal(i) for i in obj] + + for key, value in obj.items(): + if isinstance(value, float): + obj[key] = Decimal(str(value)) + elif isinstance(value, dict): + self._convert_floats_to_decimal(value) + elif isinstance(value, list): + obj[key] = [self._convert_floats_to_decimal(i) for i in value] + + return obj + + def _convert_decimal_to_floats(self, obj): + if isinstance(obj, list): + return [self._convert_decimal_to_floats(i) for i in obj] + + for key, value in obj.items(): + if isinstance(value, Decimal): + obj[key] = float(value) + elif isinstance(value, dict): + self._convert_decimal_to_floats(value) + elif isinstance(value, list): + obj[key] = [self._convert_decimal_to_floats(i) for i in value] + + return obj + def _serialize_item(self, item: Dict[str, Any]) -> Dict[str, Any]: + item = self._convert_floats_to_decimal(item) return { - key: self._type_serializer.serialize(value) for key, value in item.items() + key: self._type_serializer.serialize(value) + for key, value in item.items() } def _deserialize_item(self, item: Dict[str, Any]) -> Dict[str, Any]: + item = self._convert_decimal_to_floats(item) return { key: self._type_deserializer.deserialize(value) for key, value in item.items() } + def _update_item(self, key: Dict[str, Any], updates: Dict[str, Any]): update_expr: List[str] = [] expression_attribute_names = {}