diff --git a/.gitignore b/.gitignore index 20720fea6..f0cf3902e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ dist/ venv/ build/ +dbs/ docs/source/apiref docs/source/_misc docs/source/release_notes.rst diff --git a/chatsky/__rebuild_pydantic_models__.py b/chatsky/__rebuild_pydantic_models__.py index 1da4126a9..86661ad68 100644 --- a/chatsky/__rebuild_pydantic_models__.py +++ b/chatsky/__rebuild_pydantic_models__.py @@ -5,9 +5,13 @@ from chatsky.core.script import Node from chatsky.core.pipeline import Pipeline from chatsky.slots.slots import SlotManager +from chatsky.context_storages import DBContextStorage, ContextInfo +from chatsky.core.ctx_dict import ContextDict from chatsky.core.context import FrameworkData, ServiceState from chatsky.core.service import PipelineComponent +ContextInfo.model_rebuild() +ContextDict.model_rebuild() PipelineComponent.model_rebuild() Pipeline.model_rebuild() Script.model_rebuild() diff --git a/chatsky/conditions/standard.py b/chatsky/conditions/standard.py index 1acf4de40..5d782bbf1 100644 --- a/chatsky/conditions/standard.py +++ b/chatsky/conditions/standard.py @@ -202,7 +202,7 @@ def __init__( super().__init__(flow_labels=flow_labels, labels=labels, last_n_indices=last_n_indices) async def call(self, ctx: Context) -> bool: - labels = list(ctx.labels.values())[-self.last_n_indices :] # noqa: E203 + labels = await ctx.labels[-self.last_n_indices :] # noqa: E203 for label in labels: if label.flow_name in self.flow_labels or label in self.labels: return True diff --git a/chatsky/context_storages/__init__.py b/chatsky/context_storages/__init__.py index e41618440..18d95afaa 100644 --- a/chatsky/context_storages/__init__.py +++ b/chatsky/context_storages/__init__.py @@ -1,11 +1,10 @@ # -*- coding: utf-8 -*- -from .database import DBContextStorage, threadsafe_method, context_storage_factory -from .json import JSONContextStorage, json_available -from .pickle import PickleContextStorage, pickle_available +from .database import DBContextStorage, ContextInfo, context_storage_factory +from .file import JSONContextStorage, PickleContextStorage, ShelveContextStorage, json_available, pickle_available from .sql import SQLContextStorage, postgres_available, mysql_available, sqlite_available, sqlalchemy_available from .ydb import YDBContextStorage, ydb_available from .redis import RedisContextStorage, redis_available +from .memory import MemoryContextStorage from .mongo import MongoContextStorage, mongo_available -from .shelve import ShelveContextStorage from .protocol import PROTOCOLS, get_protocol_install_suggestion diff --git a/chatsky/context_storages/database.py b/chatsky/context_storages/database.py index faa70caf4..813371aa8 100644 --- a/chatsky/context_storages/database.py +++ b/chatsky/context_storages/database.py @@ -8,181 +8,383 @@ This class implements the basic functionality and can be extended to add additional features as needed. 
""" -import asyncio -import importlib -import threading -from functools import wraps +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Callable, Hashable, Optional +from asyncio import Lock +from json import loads +from functools import wraps +from importlib import import_module +from logging import getLogger +from pathlib import Path +from time import time_ns +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Union + +from pydantic import BaseModel, Field, PrivateAttr, TypeAdapter, field_serializer, field_validator + +from chatsky.utils.logging import collapse_num_list from .protocol import PROTOCOLS -from chatsky.core import Context + +if TYPE_CHECKING: + from chatsky.core.context import FrameworkData + +_SUBSCRIPT_TYPE = Union[Literal["__all__"], int] +_SUBSCRIPT_DICT = Dict[Literal["labels", "requests", "responses"], Union[_SUBSCRIPT_TYPE]] + +logger = getLogger(__name__) + + +class NameConfig: + """ + Configuration of names of different database parts, + including table names, column names, field names, etc. + """ + + _main_table: Literal["main"] = "main" + _turns_table: Literal["turns"] = "turns" + _key_column: Literal["key"] = "key" + _id_column: Literal["id"] = "id" + _current_turn_id_column: Literal["current_turn_id"] = "current_turn_id" + _created_at_column: Literal["created_at"] = "created_at" + _updated_at_column: Literal["updated_at"] = "updated_at" + _misc_column: Literal["misc"] = "misc" + _framework_data_column: Literal["framework_data"] = "framework_data" + _labels_field: Literal["labels"] = "labels" + _requests_field: Literal["requests"] = "requests" + _responses_field: Literal["responses"] = "responses" + + +class ContextInfo(BaseModel): + """ + Main context fields, that are stored in `MAIN` table. + For most of the database backends, it will be serialized to json. + For SQL database backends, it will be written to different table columns. + For memory context storage, it won't be serialized at all. 
+ """ + + turn_id: int + created_at: int = Field(default_factory=time_ns) + updated_at: int = Field(default_factory=time_ns) + misc: Dict[str, Any] = Field(default_factory=dict) + framework_data: FrameworkData = Field(default_factory=dict, validate_default=True) + + _misc_adaptor: TypeAdapter[Dict[str, Any]] = PrivateAttr(default=TypeAdapter(Dict[str, Any])) + + @field_validator("framework_data", "misc", mode="before") + @classmethod + def _validate_framework_data(cls, value: Any) -> Dict: + if isinstance(value, bytes) or isinstance(value, str): + value = loads(value) + return value + + @field_serializer("misc", when_used="always") + def _serialize_misc(self, misc: Dict[str, Any]) -> bytes: + return self._misc_adaptor.dump_json(misc) + + @field_serializer("framework_data", when_used="always") + def _serialize_framework_data(self, framework_data: FrameworkData) -> bytes: + return framework_data.model_dump_json().encode() + + def __eq__(self, other: Any) -> bool: + if isinstance(other, BaseModel): + return self.model_dump() == other.model_dump() + return super().__eq__(other) + + +def _lock(function: Callable[..., Awaitable[Any]]): + @wraps(function) + async def wrapped(self: DBContextStorage, *args, **kwargs): + if not self.is_concurrent or not self.connected: + async with self._sync_lock: + return await function(self, *args, **kwargs) + else: + return await function(self, *args, **kwargs) + + return wrapped class DBContextStorage(ABC): - r""" - An abstract interface for `chatsky` DB context storages. - It includes the most essential methods of the python `dict` class. - Can not be instantiated. - - :param path: Parameter `path` should be set with the URI of the database. - It includes a prefix and the required connection credentials. - Example: postgresql+asyncpg://user:password@host:port/database - In the case of classes that save data to hard drive instead of external databases - you need to specify the location of the file, like you do in sqlite. - Keep in mind that in Windows you will have to use double backslashes '\\' - instead of forward slashes '/' when defining the file path. + """ + Base context storage class. + Includes a set of methods for storing and reading different context parts. + :param path: Path to the storage instance. + :param rewrite_existing: Whether `TURNS` modified locally should be updated in database or not. + :param partial_read_config: Dictionary of subscripts for all possible turn items. """ - def __init__(self, path: str): + _default_subscript_value: int = 3 + + def __init__( + self, + path: str, + rewrite_existing: bool = False, + partial_read_config: Optional[_SUBSCRIPT_DICT] = None, + ): _, _, file_path = path.partition("://") - self.full_path = path - """Full path to access the context storage, as it was provided by user.""" - self.path = file_path - """`full_path` without a prefix defining db used""" - self._lock = threading.Lock() - """Threading for methods that require single thread access.""" + configuration = partial_read_config if partial_read_config is not None else dict() - def __getitem__(self, key: Hashable) -> Context: + self.full_path = path + """ + Full path to access the context storage, as it was provided by user. """ - Synchronous method for accessing stored Context. - :param key: Hashable key used to store Context instance. - :return: The stored context, associated with the given key. + self.path = Path(file_path) + """ + `full_path` without a prefix defining db used. 
""" - return asyncio.run(self.get_item_async(key)) - @abstractmethod - async def get_item_async(self, key: Hashable) -> Context: + self.rewrite_existing = rewrite_existing + """ + Whether to rewrite existing data in the storage. """ - Asynchronous method for accessing stored Context. - :param key: Hashable key used to store Context instance. - :return: The stored context, associated with the given key. + self._subscripts = dict() + """ + Subscripts control how many elements will be loaded from the database. + Can be an integer, meaning the number of *last* elements to load. + A special value for loading all the elements at once: "__all__". + Can also be a set of keys that should be loaded. """ - raise NotImplementedError - def __setitem__(self, key: Hashable, value: Context): + self._sync_lock = Lock() + """ + Synchronization lock for the databases that don't support + asynchronous atomic reads and writes. """ - Synchronous method for storing Context. - :param key: Hashable key used to store Context instance. - :param value: Context to store. + self.connected = False """ - return asyncio.run(self.set_item_async(key, value)) + Flag that marks if the storage is connected to the backend. + Should be set in `pipeline.run` or later (lazily). + """ + + for field in (NameConfig._labels_field, NameConfig._requests_field, NameConfig._responses_field): + value = configuration.get(field, self._default_subscript_value) + if (not isinstance(value, int)) or value >= 1: + self._subscripts[field] = value + else: + raise ValueError(f"Invalid subscript value ({value}) for field {field}") + @property @abstractmethod - async def set_item_async(self, key: Hashable, value: Context): + def is_concurrent(self) -> bool: """ - Asynchronous method for storing Context. - - :param key: Hashable key used to store Context instance. - :param value: Context to store. + If the database backend support asynchronous IO. """ + raise NotImplementedError - def __delitem__(self, key: Hashable): - """ - Synchronous method for removing stored Context. + @classmethod + def _validate_field_name(cls, field_name: str) -> str: + if field_name not in (NameConfig._labels_field, NameConfig._requests_field, NameConfig._responses_field): + raise ValueError(f"Invalid value '{field_name}' for argument 'field_name'!") + else: + return field_name - :param key: Hashable key used to identify Context instance for deletion. + @abstractmethod + async def _connect(self) -> None: + raise NotImplementedError + + async def connect(self) -> None: + """ + Connect to the backend context storage. """ - return asyncio.run(self.del_item_async(key)) + + logger.info(f"Connecting to context storage {type(self).__name__} ...") + await self._connect() + self.connected = True @abstractmethod - async def del_item_async(self, key: Hashable): + async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]: + raise NotImplementedError + + @_lock + async def load_main_info(self, ctx_id: str) -> Optional[ContextInfo]: """ - Asynchronous method for removing stored Context. + Load main information about the context. - :param key: Hashable key used to identify Context instance for deletion. + :param ctx_id: Context identifier. + :return: Context main information (from `MAIN` table). 
""" + + if not self.connected: + await self.connect() + logger.debug(f"Loading main info for {ctx_id}...") + result = await self._load_main_info(ctx_id) + logger.debug(f"Main info loaded for {ctx_id}") + return result + + @abstractmethod + async def _update_main_info(self, ctx_id: str, ctx_info: ContextInfo) -> None: raise NotImplementedError - def __contains__(self, key: Hashable) -> bool: + @_lock + async def update_main_info(self, ctx_id: str, ctx_info: ContextInfo) -> None: """ - Synchronous method for finding whether any Context is stored with given key. + Update main information about the context. - :param key: Hashable key used to check if Context instance is stored. - :return: True if there is Context accessible by given key, False otherwise. + :param ctx_id: Context identifier. + :param ctx_info: New context information (will be written to `MAIN` table). """ - return asyncio.run(self.contains_async(key)) + + if not self.connected: + await self.connect() + logger.debug(f"Updating main info for {ctx_id}...") + await self._update_main_info(ctx_id, ctx_info) + logger.debug(f"Main info updated for {ctx_id}") @abstractmethod - async def contains_async(self, key: Hashable) -> bool: + async def _delete_context(self, ctx_id: str) -> None: + raise NotImplementedError + + @_lock + async def delete_context(self, ctx_id: str) -> None: """ - Asynchronous method for finding whether any Context is stored with given key. + Delete context from context storage. - :param key: Hashable key used to check if Context instance is stored. - :return: True if there is Context accessible by given key, False otherwise. + :param ctx_id: Context identifier. """ + + if not self.connected: + await self.connect() + logger.debug(f"Deleting context {ctx_id}...") + await self._delete_context(ctx_id) + logger.debug(f"Context {ctx_id} deleted") + + @abstractmethod + async def _load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]: raise NotImplementedError - def __len__(self) -> int: + @_lock + async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]: """ - Synchronous method for retrieving number of stored Contexts. + Load the latest field data (specified by `subscript` value). - :return: The number of stored Contexts. + :param ctx_id: Context identifier. + :param field_name: Field name to load from `TURNS` table. + :return: List of tuples (step number, serialized value). """ - return asyncio.run(self.len_async()) + + if not self.connected: + await self.connect() + logger.debug(f"Loading latest items for {ctx_id}, {field_name}...") + result = await self._load_field_latest(ctx_id, self._validate_field_name(field_name)) + logger.debug(f"Latest field loaded for {ctx_id}, {field_name}: {collapse_num_list(list(k for k, _ in result))}") + return result @abstractmethod - async def len_async(self) -> int: + async def _load_field_keys(self, ctx_id: str, field_name: str) -> List[int]: + raise NotImplementedError + + @_lock + async def load_field_keys(self, ctx_id: str, field_name: str) -> List[int]: """ - Asynchronous method for retrieving number of stored Contexts. + Load all field keys. - :return: The number of stored Contexts. + :param ctx_id: Context identifier. + :param field_name: Field name to load from `TURNS` table. + :return: List of all the step numbers. 
""" + + if not self.connected: + await self.connect() + logger.debug(f"Loading field keys for {ctx_id}, {field_name}...") + result = await self._load_field_keys(ctx_id, self._validate_field_name(field_name)) + logger.debug(f"Field keys loaded for {ctx_id}, {field_name}: {collapse_num_list(result)}") + return result + + @abstractmethod + async def _load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[Tuple[int, bytes]]: raise NotImplementedError - def get(self, key: Hashable, default: Optional[Context] = None) -> Context: + @_lock + async def load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[Tuple[int, bytes]]: """ - Synchronous method for accessing stored Context, returning default if no Context is stored with the given key. + Load field items (specified by key list). + The items that are equal to `None` will be ignored. - :param key: Hashable key used to store Context instance. - :param default: Optional default value to be returned if no Context is found. - :return: The stored context, associated with the given key or default value. + :param ctx_id: Context identifier. + :param field_name: Field name to load from `TURNS` table. + :param keys: List of keys to load. + :return: List of tuples (step number, serialized value). """ - return asyncio.run(self.get_async(key, default)) - async def get_async(self, key: Hashable, default: Optional[Context] = None) -> Context: + if not self.connected: + await self.connect() + logger.debug(f"Loading field items for {ctx_id}, {field_name} ({collapse_num_list(keys)})...") + result = await self._load_field_items(ctx_id, self._validate_field_name(field_name), keys) + logger.debug(f"Field items loaded for {ctx_id}, {field_name}: {collapse_num_list([k for k, _ in result])}") + return result + + @abstractmethod + async def _update_field_items(self, ctx_id: str, field_name: str, items: List[Tuple[int, Optional[bytes]]]) -> None: + raise NotImplementedError + + @_lock + async def update_field_items(self, ctx_id: str, field_name: str, items: List[Tuple[int, Optional[bytes]]]) -> None: """ - Asynchronous method for accessing stored Context, returning default if no Context is stored with the given key. + Update field items. - :param key: Hashable key used to store Context instance. - :param default: Optional default value to be returned if no Context is found. - :return: The stored context, associated with the given key or default value. + :param ctx_id: Context identifier. + :param field_name: Field name to load from `TURNS` table. + :param items: List of tuples that will be written (step number, serialized value or `None`). """ - try: - return await self.get_item_async(str(key)) - except KeyError: - return default - def clear(self): + if len(items) == 0: + logger.debug(f"No fields to update in {ctx_id}, {field_name}!") + return + elif not self.connected: + await self.connect() + logger.debug(f"Updating fields for {ctx_id}, {field_name}: {collapse_num_list(list(k for k, _ in items))}...") + await self._update_field_items(ctx_id, self._validate_field_name(field_name), items) + logger.debug(f"Fields updated for {ctx_id}, {field_name}") + + async def _delete_field_keys(self, ctx_id: str, field_name: str, keys: List[int]) -> None: + await self._update_field_items(ctx_id, field_name, [(k, None) for k in keys]) + + @_lock + async def delete_field_keys(self, ctx_id: str, field_name: str, keys: List[int]) -> None: """ - Synchronous method for clearing context storage, removing all the stored Contexts. 
+ Delete field keys. + + :param ctx_id: Context identifier. + :param field_name: Field name to load from `TURNS` table. + :param keys: List of keys to delete (will be just overwritten with `None`). """ - return asyncio.run(self.clear_async()) + + if len(keys) == 0: + logger.debug(f"No fields to delete in {ctx_id}, {field_name}!") + return + elif not self.connected: + await self.connect() + logger.debug(f"Deleting fields for {ctx_id}, {field_name}: {collapse_num_list(keys)}...") + await self._delete_field_keys(ctx_id, self._validate_field_name(field_name), keys) + logger.debug(f"Fields deleted for {ctx_id}, {field_name}") @abstractmethod - async def clear_async(self): - """ - Asynchronous method for clearing context storage, removing all the stored Contexts. - """ + async def _clear_all(self) -> None: raise NotImplementedError + @_lock + async def clear_all(self) -> None: + """ + Clear all the chatsky tables and records. + """ -def threadsafe_method(func: Callable): - """ - A decorator that makes sure methods of an object instance are threadsafe. - """ - - @wraps(func) - def _synchronized(self, *args, **kwargs): - with self._lock: - return func(self, *args, **kwargs) + if not self.connected: + await self.connect() + logger.debug("Clearing all") + await self._clear_all() - return _synchronized + def __eq__(self, other: Any) -> bool: + if not isinstance(other, DBContextStorage): + return False + return ( + self.full_path == other.full_path + and self.path == other.path + and self.rewrite_existing == other.rewrite_existing + ) def context_storage_factory(path: str, **kwargs) -> DBContextStorage: @@ -209,20 +411,29 @@ def context_storage_factory(path: str, **kwargs) -> DBContextStorage: json://file.json When using sqlite backend your prefix should contain three slashes if you use Windows, or four in other cases: sqlite:////file.db + + For MemoryContextStorage pass an empty string as ``path``. + If you want to use additional parameters in class constructors, you can pass them to this function as kwargs. :param path: Path to the file. 
""" - prefix, _, _ = path.partition("://") - if "sql" in prefix: - prefix = prefix.split("+")[0] # this takes care of alternative sql drivers - assert ( - prefix in PROTOCOLS - ), f""" - URI path should be prefixed with one of the following:\n - {", ".join(PROTOCOLS.keys())}.\n - For more information, see the function doc:\n{context_storage_factory.__doc__} - """ - _class, module = PROTOCOLS[prefix]["class"], PROTOCOLS[prefix]["module"] - target_class = getattr(importlib.import_module(f".{module}", package="chatsky.context_storages"), _class) + + if path == "": + module = "memory" + _class = "MemoryContextStorage" + else: + prefix, _, _ = path.partition("://") + if any(prefix.startswith(sql_prefix) for sql_prefix in ("sqlite", "mysql", "postgresql")): + prefix = prefix.split("+")[0] # this takes care of alternative sql drivers + if prefix not in PROTOCOLS: + raise ValueError( + f""" + URI path should be prefixed with one of the following:\n + {", ".join(PROTOCOLS.keys())}.\n + For more information, see the function doc:\n{context_storage_factory.__doc__} + """ + ) + _class, module = PROTOCOLS[prefix]["class"], PROTOCOLS[prefix]["module"] + target_class = getattr(import_module(f".{module}", package="chatsky.context_storages"), _class) return target_class(path, **kwargs) diff --git a/chatsky/context_storages/file.py b/chatsky/context_storages/file.py new file mode 100644 index 000000000..3463686ef --- /dev/null +++ b/chatsky/context_storages/file.py @@ -0,0 +1,203 @@ +""" +JSON +---- +The JSON module provides a json-based version of the :py:class:`.DBContextStorage` class. +This class is used to store and retrieve context data in a JSON. It allows Chatsky to easily +store and retrieve context data. +""" + +from abc import ABC, abstractmethod +from pickle import loads, dumps +from shelve import DbfilenameShelf +from typing import List, Set, Tuple, Dict, Optional + +from pydantic import BaseModel, Field + +from .database import ContextInfo, DBContextStorage, _SUBSCRIPT_DICT + +try: + from aiofiles import open + from aiofiles.os import stat, makedirs + from aiofiles.ospath import isfile + + json_available = True + pickle_available = True +except ImportError: + json_available = False + pickle_available = False + + +class SerializableStorage(BaseModel): + """ + A special serializable database implementation. + One element of this class will be used to store all the contexts, read and written to file on every turn. + """ + + main: Dict[str, ContextInfo] = Field(default_factory=dict) + turns: List[Tuple[str, str, int, Optional[bytes]]] = Field(default_factory=list) + + +class FileContextStorage(DBContextStorage, ABC): + """ + Implements :py:class:`.DBContextStorage` with any file-based storage format. + + :param path: Target file URI. + :param rewrite_existing: Whether `TURNS` modified locally should be updated in database or not. + :param partial_read_config: Dictionary of subscripts for all possible turn items. 
+ """ + + is_concurrent: bool = False + + def __init__( + self, + path: str = "", + rewrite_existing: bool = False, + partial_read_config: Optional[_SUBSCRIPT_DICT] = None, + ): + DBContextStorage.__init__(self, path, rewrite_existing, partial_read_config) + + @abstractmethod + async def _save(self, data: SerializableStorage) -> None: + raise NotImplementedError + + @abstractmethod + async def _load(self) -> SerializableStorage: + raise NotImplementedError + + async def _connect(self): + await self._load() + + async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]: + return (await self._load()).main.get(ctx_id, None) + + async def _update_main_info(self, ctx_id: str, ctx_info: ContextInfo) -> None: + storage = await self._load() + storage.main[ctx_id] = ctx_info + await self._save(storage) + + async def _delete_context(self, ctx_id: str) -> None: + storage = await self._load() + storage.main.pop(ctx_id, None) + storage.turns = [(c, f, k, v) for c, f, k, v in storage.turns if c != ctx_id] + await self._save(storage) + + async def _load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]: + storage = await self._load() + select = sorted( + [(k, v) for c, f, k, v in storage.turns if c == ctx_id and f == field_name and v is not None], + key=lambda e: e[0], + reverse=True, + ) + if isinstance(self._subscripts[field_name], int): + select = select[: self._subscripts[field_name]] + elif isinstance(self._subscripts[field_name], Set): + select = [(k, v) for k, v in select if k in self._subscripts[field_name]] + return select + + async def _load_field_keys(self, ctx_id: str, field_name: str) -> List[int]: + return [k for c, f, k, v in (await self._load()).turns if c == ctx_id and f == field_name and v is not None] + + async def _load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[Tuple[int, bytes]]: + return [ + (k, v) + for c, f, k, v in (await self._load()).turns + if c == ctx_id and f == field_name and k in keys and v is not None + ] + + async def _update_field_items(self, ctx_id: str, field_name: str, items: List[Tuple[int, Optional[bytes]]]) -> None: + storage = await self._load() + for k, v in items: + upd = (ctx_id, field_name, k, v) + for i in range(len(storage.turns)): + if storage.turns[i][:-1] == upd[:-1]: + storage.turns[i] = upd + break + else: + storage.turns += [upd] + await self._save(storage) + + async def _clear_all(self) -> None: + await self._save(SerializableStorage()) + + +class JSONContextStorage(FileContextStorage): + """ + Implements :py:class:`.DBContextStorage` with `json` as the storage format. + + :param path: Target file URI. Example: `json://file.json`. + :param rewrite_existing: Whether `TURNS` modified locally should be updated in database or not. + :param partial_read_config: Dictionary of subscripts for all possible turn items. 
+ """ + + async def _save(self, data: SerializableStorage) -> None: + if not await isfile(self.path) or (await stat(self.path)).st_size == 0: + await makedirs(self.path.parent, exist_ok=True) + async with open(self.path, "w", encoding="utf-8") as file_stream: + await file_stream.write(data.model_dump_json()) + + async def _load(self) -> SerializableStorage: + if not await isfile(self.path) or (await stat(self.path)).st_size == 0: + storage = SerializableStorage() + await self._save(storage) + else: + async with open(self.path, "r", encoding="utf-8") as file_stream: + storage = SerializableStorage.model_validate_json(await file_stream.read()) + return storage + + +class PickleContextStorage(FileContextStorage): + """ + Implements :py:class:`.DBContextStorage` with `pickle` as the storage format. + + :param path: Target file URI. Example: `pickle://file.pkl`. + :param rewrite_existing: Whether `TURNS` modified locally should be updated in database or not. + :param partial_read_config: Dictionary of subscripts for all possible turn items. + """ + + async def _save(self, data: SerializableStorage) -> None: + if not await isfile(self.path) or (await stat(self.path)).st_size == 0: + await makedirs(self.path.parent, exist_ok=True) + async with open(self.path, "wb") as file_stream: + await file_stream.write(dumps(data.model_dump())) + + async def _load(self) -> SerializableStorage: + if not await isfile(self.path) or (await stat(self.path)).st_size == 0: + storage = SerializableStorage() + await self._save(storage) + else: + async with open(self.path, "rb") as file_stream: + storage = SerializableStorage.model_validate(loads(await file_stream.read())) + return storage + + +class ShelveContextStorage(FileContextStorage): + """ + Implements :py:class:`.DBContextStorage` with `shelve` as the storage format. + + :param path: Target file URI. Example: `shelve://file.shlv`. + :param rewrite_existing: Whether `TURNS` modified locally should be updated in database or not. + :param partial_read_config: Dictionary of subscripts for all possible turn items. + """ + + _SHELVE_ROOT = "root" + + def __init__( + self, + path: str = "", + rewrite_existing: bool = False, + partial_read_config: Optional[_SUBSCRIPT_DICT] = None, + ): + self._storage = None + FileContextStorage.__init__(self, path, rewrite_existing, partial_read_config) + + async def _save(self, data: SerializableStorage) -> None: + self._storage[self._SHELVE_ROOT] = data.model_dump() + + async def _load(self) -> SerializableStorage: + if self._storage is None: + content = SerializableStorage() + self._storage = DbfilenameShelf(str(self.path.absolute()), writeback=True) + await self._save(content) + else: + content = SerializableStorage.model_validate(self._storage[self._SHELVE_ROOT]) + return content diff --git a/chatsky/context_storages/json.py b/chatsky/context_storages/json.py deleted file mode 100644 index 21b84e36f..000000000 --- a/chatsky/context_storages/json.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -JSON ----- -The JSON module provides a json-based version of the :py:class:`.DBContextStorage` class. -This class is used to store and retrieve context data in a JSON. It allows Chatsky to easily -store and retrieve context data. 
-""" - -import asyncio -from typing import Hashable, Dict - -try: - import aiofiles - import aiofiles.os - - json_available = True -except ImportError: - json_available = False - -from pydantic import BaseModel - -from .database import DBContextStorage, threadsafe_method -from chatsky.core import Context - - -class SerializableStorage(BaseModel, extra="allow"): - __pydantic_extra__: Dict[str, Context] - - -class JSONContextStorage(DBContextStorage): - """ - Implements :py:class:`.DBContextStorage` with `json` as the storage format. - - :param path: Target file URI. Example: `json://file.json`. - """ - - def __init__(self, path: str): - DBContextStorage.__init__(self, path) - asyncio.run(self._load()) - - @threadsafe_method - async def len_async(self) -> int: - return len(self.storage.model_extra) - - @threadsafe_method - async def set_item_async(self, key: Hashable, value: Context): - self.storage.model_extra.__setitem__(str(key), value) - await self._save() - - @threadsafe_method - async def get_item_async(self, key: Hashable) -> Context: - await self._load() - return Context.model_validate(self.storage.model_extra.__getitem__(str(key))) - - @threadsafe_method - async def del_item_async(self, key: Hashable): - self.storage.model_extra.__delitem__(str(key)) - await self._save() - - @threadsafe_method - async def contains_async(self, key: Hashable) -> bool: - await self._load() - return self.storage.model_extra.__contains__(str(key)) - - @threadsafe_method - async def clear_async(self): - self.storage.model_extra.clear() - await self._save() - - async def _save(self): - async with aiofiles.open(self.path, "w+", encoding="utf-8") as file_stream: - await file_stream.write(self.storage.model_dump_json()) - - async def _load(self): - if not await aiofiles.os.path.isfile(self.path) or (await aiofiles.os.stat(self.path)).st_size == 0: - self.storage = SerializableStorage() - await self._save() - else: - async with aiofiles.open(self.path, "r", encoding="utf-8") as file_stream: - self.storage = SerializableStorage.model_validate_json(await file_stream.read()) diff --git a/chatsky/context_storages/memory.py b/chatsky/context_storages/memory.py new file mode 100644 index 000000000..e0e16c7ab --- /dev/null +++ b/chatsky/context_storages/memory.py @@ -0,0 +1,75 @@ +from typing import List, Optional, Set, Tuple + +from .database import ContextInfo, DBContextStorage, _SUBSCRIPT_DICT, NameConfig + + +class MemoryContextStorage(DBContextStorage): + """ + Implements :py:class:`.DBContextStorage` storing contexts in memory, wthout file backend. + Does not serialize any data. By default it sets path to an empty string. + + Keeps data in a dictionary and two dictionaries: + + - `main`: {context_id: context_info} + - `turns`: {context_id: {labels, requests, responses}} + + :param path: Any string, won't be used. + :param rewrite_existing: Whether `TURNS` modified locally should be updated in database or not. + :param partial_read_config: Dictionary of subscripts for all possible turn items. 
+ """ + + is_concurrent: bool = True + + def __init__( + self, + path: str = "", + rewrite_existing: bool = False, + partial_read_config: Optional[_SUBSCRIPT_DICT] = None, + ): + DBContextStorage.__init__(self, path, rewrite_existing, partial_read_config) + self._main_storage = dict() + self._aux_storage = { + NameConfig._labels_field: dict(), + NameConfig._requests_field: dict(), + NameConfig._responses_field: dict(), + } + + async def _connect(self): + pass + + async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]: + return self._main_storage.get(ctx_id, None) + + async def _update_main_info(self, ctx_id: str, ctx_info: ContextInfo) -> None: + self._main_storage[ctx_id] = ctx_info + + async def _delete_context(self, ctx_id: str) -> None: + self._main_storage.pop(ctx_id, None) + for storage in self._aux_storage.values(): + storage.pop(ctx_id, None) + + async def _load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]: + select = sorted( + [k for k, v in self._aux_storage[field_name].get(ctx_id, dict()).items() if v is not None], reverse=True + ) + if isinstance(self._subscripts[field_name], int): + select = select[: self._subscripts[field_name]] + elif isinstance(self._subscripts[field_name], Set): + select = [k for k in select if k in self._subscripts[field_name]] + return [(k, self._aux_storage[field_name][ctx_id][k]) for k in select] + + async def _load_field_keys(self, ctx_id: str, field_name: str) -> List[int]: + return [k for k, v in self._aux_storage[field_name].get(ctx_id, dict()).items() if v is not None] + + async def _load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[Tuple[int, bytes]]: + return [ + (k, v) for k, v in self._aux_storage[field_name].get(ctx_id, dict()).items() if k in keys and v is not None + ] + + async def _update_field_items(self, ctx_id: str, field_name: str, items: List[Tuple[int, Optional[bytes]]]) -> None: + self._aux_storage[field_name].setdefault(ctx_id, dict()).update(items) + + async def _clear_all(self) -> None: + self._main_storage = dict() + for key in self._aux_storage.keys(): + self._aux_storage[key] = dict() diff --git a/chatsky/context_storages/mongo.py b/chatsky/context_storages/mongo.py index 3bc6d1956..0de5a83d1 100644 --- a/chatsky/context_storages/mongo.py +++ b/chatsky/context_storages/mongo.py @@ -12,23 +12,18 @@ and high levels of read and write traffic. """ -from typing import Hashable, Dict, Any +from asyncio import gather +from typing import Set, Tuple, Optional, List try: + from pymongo import UpdateOne from motor.motor_asyncio import AsyncIOMotorClient - from bson.objectid import ObjectId mongo_available = True except ImportError: mongo_available = False - AsyncIOMotorClient = None - ObjectId = Any -import json - -from chatsky.core import Context - -from .database import DBContextStorage, threadsafe_method +from .database import ContextInfo, DBContextStorage, _SUBSCRIPT_DICT, NameConfig from .protocol import get_protocol_install_suggestion @@ -36,60 +31,142 @@ class MongoContextStorage(DBContextStorage): """ Implements :py:class:`.DBContextStorage` with `mongodb` as the database backend. + CONTEXTS table is stored as `COLLECTION_PREFIX_contexts` collection. + LOGS table is stored as `COLLECTION_PREFIX_logs` collection. + :param path: Database URI. Example: `mongodb://user:password@host:port/dbname`. - :param collection: Name of the collection to store the data in. + :param rewrite_existing: Whether `TURNS` modified locally should be updated in database or not. 
+ :param partial_read_config: Dictionary of subscripts for all possible turn items. + :param collection_prefix: "namespace" prefix for the two collections created for context storing. """ - def __init__(self, path: str, collection: str = "context_collection"): - DBContextStorage.__init__(self, path) + _UNIQUE_KEYS = "unique_keys" + _ID_FIELD = "_id" + + is_concurrent: bool = True + + def __init__( + self, + path: str, + rewrite_existing: bool = False, + partial_read_config: Optional[_SUBSCRIPT_DICT] = None, + collection_prefix: str = "chatsky_collection", + ): + DBContextStorage.__init__(self, path, rewrite_existing, partial_read_config) + if not mongo_available: install_suggestion = get_protocol_install_suggestion("mongodb") raise ImportError("`mongodb` package is missing.\n" + install_suggestion) - self._mongo = AsyncIOMotorClient(self.full_path) + self._mongo = AsyncIOMotorClient(self.full_path, uuidRepresentation="standard") db = self._mongo.get_default_database() - self.collection = db[collection] - - @staticmethod - def _adjust_key(key: Hashable) -> Dict[str, ObjectId]: - """Convert a n-digit context id to a 24-digit mongo id""" - new_key = hex(int.from_bytes(str.encode(str(key)), "big", signed=False))[3:] - new_key = (new_key * (24 // len(new_key) + 1))[:24] - assert len(new_key) == 24 - return {"_id": ObjectId(new_key)} - - @threadsafe_method - async def set_item_async(self, key: Hashable, value: Context): - new_key = self._adjust_key(key) - value = Context.model_validate(value) - document = json.loads(value.model_dump_json()) - - document.update(new_key) - await self.collection.replace_one(new_key, document, upsert=True) - - @threadsafe_method - async def get_item_async(self, key: Hashable) -> Context: - adjust_key = self._adjust_key(key) - document = await self.collection.find_one(adjust_key) - if document: - document.pop("_id") - ctx = Context.model_validate(document) - return ctx - raise KeyError - - @threadsafe_method - async def del_item_async(self, key: Hashable): - adjust_key = self._adjust_key(key) - await self.collection.delete_one(adjust_key) - - @threadsafe_method - async def contains_async(self, key: Hashable) -> bool: - adjust_key = self._adjust_key(key) - return bool(await self.collection.find_one(adjust_key)) - - @threadsafe_method - async def len_async(self) -> int: - return await self.collection.estimated_document_count() - - @threadsafe_method - async def clear_async(self): - await self.collection.delete_many(dict()) + + self.main_table = db[f"{collection_prefix}_{NameConfig._main_table}"] + self.turns_table = db[f"{collection_prefix}_{NameConfig._turns_table}"] + + async def _connect(self): + await gather( + self.main_table.create_index(NameConfig._id_column, background=True, unique=True), + self.turns_table.create_index( + [NameConfig._id_column, NameConfig._key_column], background=True, unique=True + ), + ) + + async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]: + result = await self.main_table.find_one( + {NameConfig._id_column: ctx_id}, + [ + NameConfig._current_turn_id_column, + NameConfig._created_at_column, + NameConfig._updated_at_column, + NameConfig._misc_column, + NameConfig._framework_data_column, + ], + ) + return ( + ContextInfo.model_validate( + { + "turn_id": result[NameConfig._current_turn_id_column], + "created_at": result[NameConfig._created_at_column], + "updated_at": result[NameConfig._updated_at_column], + "misc": result[NameConfig._misc_column], + "framework_data": result[NameConfig._framework_data_column], + } + ) + if 
result is not None + else None + ) + + async def _update_main_info(self, ctx_id: str, ctx_info: ContextInfo) -> None: + ctx_info_dump = ctx_info.model_dump(mode="python") + await self.main_table.update_one( + {NameConfig._id_column: ctx_id}, + { + "$set": { + NameConfig._id_column: ctx_id, + NameConfig._current_turn_id_column: ctx_info_dump["turn_id"], + NameConfig._created_at_column: ctx_info_dump["created_at"], + NameConfig._updated_at_column: ctx_info_dump["updated_at"], + NameConfig._misc_column: ctx_info_dump["misc"], + NameConfig._framework_data_column: ctx_info_dump["framework_data"], + } + }, + upsert=True, + ) + + async def _delete_context(self, ctx_id: str) -> None: + await gather( + self.main_table.delete_one({NameConfig._id_column: ctx_id}), + self.turns_table.delete_one({NameConfig._id_column: ctx_id}), + ) + + async def _load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]: + limit, key = 0, dict() + if isinstance(self._subscripts[field_name], int): + limit = self._subscripts[field_name] + elif isinstance(self._subscripts[field_name], Set): + key = {NameConfig._key_column: {"$in": list(self._subscripts[field_name])}} + result = ( + await self.turns_table.find( + {NameConfig._id_column: ctx_id, field_name: {"$exists": True, "$ne": None}, **key}, + [NameConfig._key_column, field_name], + sort=[(NameConfig._key_column, -1)], + ) + .limit(limit) + .to_list(None) + ) + return [(item[NameConfig._key_column], item[field_name]) for item in result] + + async def _load_field_keys(self, ctx_id: str, field_name: str) -> List[int]: + result = await self.turns_table.aggregate( + [ + {"$match": {NameConfig._id_column: ctx_id, field_name: {"$ne": None}}}, + {"$group": {"_id": None, self._UNIQUE_KEYS: {"$addToSet": f"${NameConfig._key_column}"}}}, + ] + ).to_list(None) + return result[0][self._UNIQUE_KEYS] if len(result) == 1 else list() + + async def _load_field_items(self, ctx_id: str, field_name: str, keys: Set[int]) -> List[Tuple[int, bytes]]: + result = await self.turns_table.find( + { + NameConfig._id_column: ctx_id, + NameConfig._key_column: {"$in": list(keys)}, + field_name: {"$exists": True, "$ne": None}, + }, + [NameConfig._key_column, field_name], + ).to_list(None) + return [(item[NameConfig._key_column], item[field_name]) for item in result] + + async def _update_field_items(self, ctx_id: str, field_name: str, items: List[Tuple[int, Optional[bytes]]]) -> None: + await self.turns_table.bulk_write( + [ + UpdateOne( + {NameConfig._id_column: ctx_id, NameConfig._key_column: k}, + {"$set": {field_name: v}}, + upsert=True, + ) + for k, v in items + ] + ) + + async def _clear_all(self) -> None: + await gather(self.main_table.delete_many({}), self.turns_table.delete_many({})) diff --git a/chatsky/context_storages/pickle.py b/chatsky/context_storages/pickle.py deleted file mode 100644 index eb1ddeb0c..000000000 --- a/chatsky/context_storages/pickle.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -Pickle ------- -The Pickle module provides a pickle-based version of the :py:class:`.DBContextStorage` class. -This class is used to store and retrieve context data in a pickle format. -It allows Chatsky to easily store and retrieve context data in a format that is efficient -for serialization and deserialization and can be easily used in python. - -Pickle is a python library that allows to serialize and deserialize python objects. 
-It is efficient and fast, but it is not recommended to use it to transfer data across -different languages or platforms because it's not cross-language compatible. -""" - -import asyncio -import pickle -from typing import Hashable - -try: - import aiofiles - import aiofiles.os - - pickle_available = True -except ImportError: - pickle_available = False - -from .database import DBContextStorage, threadsafe_method -from chatsky.core import Context - - -class PickleContextStorage(DBContextStorage): - """ - Implements :py:class:`.DBContextStorage` with `pickle` as driver. - - :param path: Target file URI. Example: 'pickle://file.pkl'. - """ - - def __init__(self, path: str): - DBContextStorage.__init__(self, path) - asyncio.run(self._load()) - - @threadsafe_method - async def len_async(self) -> int: - return len(self.dict) - - @threadsafe_method - async def set_item_async(self, key: Hashable, value: Context): - self.dict.__setitem__(str(key), value) - await self._save() - - @threadsafe_method - async def get_item_async(self, key: Hashable) -> Context: - await self._load() - return Context.model_validate(self.dict.__getitem__(str(key))) - - @threadsafe_method - async def del_item_async(self, key: Hashable): - self.dict.__delitem__(str(key)) - await self._save() - - @threadsafe_method - async def contains_async(self, key: Hashable) -> bool: - await self._load() - return self.dict.__contains__(str(key)) - - @threadsafe_method - async def clear_async(self): - self.dict.clear() - await self._save() - - async def _save(self): - async with aiofiles.open(self.path, "wb+") as file: - await file.write(pickle.dumps(self.dict)) - - async def _load(self): - if not await aiofiles.os.path.isfile(self.path) or (await aiofiles.os.stat(self.path)).st_size == 0: - self.dict = dict() - await self._save() - else: - async with aiofiles.open(self.path, "rb") as file: - self.dict = pickle.loads(await file.read()) diff --git a/chatsky/context_storages/protocols.json b/chatsky/context_storages/protocols.json index ce1f808f3..3ac4220a2 100644 --- a/chatsky/context_storages/protocols.json +++ b/chatsky/context_storages/protocols.json @@ -1,18 +1,18 @@ { "shelve": { - "module": "shelve", + "module": "file", "class": "ShelveContextStorage", "slug": "shelve", "uri_example": "shelve://path_to_the_file/file_name" }, "json": { - "module": "json", + "module": "file", "class": "JSONContextStorage", "slug": "json", "uri_example": "json://path_to_the_file/file_name" }, "pickle": { - "module": "pickle", + "module": "file", "class": "PickleContextStorage", "slug": "pickle", "uri_example": "pickle://path_to_the_file/file_name" diff --git a/chatsky/context_storages/redis.py b/chatsky/context_storages/redis.py index e3165cacd..4c7d27d0d 100644 --- a/chatsky/context_storages/redis.py +++ b/chatsky/context_storages/redis.py @@ -13,8 +13,8 @@ and powerful choice for data storage and management. """ -import json -from typing import Hashable +from asyncio import gather +from typing import List, Set, Tuple, Optional try: from redis.asyncio import Redis @@ -23,9 +23,7 @@ except ImportError: redis_available = False -from chatsky.core import Context - -from .database import DBContextStorage, threadsafe_method +from .database import ContextInfo, DBContextStorage, _SUBSCRIPT_DICT, NameConfig from .protocol import get_protocol_install_suggestion @@ -33,41 +31,120 @@ class RedisContextStorage(DBContextStorage): """ Implements :py:class:`.DBContextStorage` with `redis` as the database backend. 
+ The main context info is stored in redis hashes, one for each context. + The `TURNS` table values are stored in redis hashes, one for each field. + + That's how MAIN table fields are stored: + `"KEY_PREFIX:main:ctx_id": "DATA"` + That's how TURNS table fields are stored: + `"KEY_PREFIX:turns:ctx_id:FIELD_NAME": "DATA"` + :param path: Database URI string. Example: `redis://user:password@host:port`. + :param rewrite_existing: Whether `TURNS` modified locally should be updated in database or not. + :param partial_read_config: Dictionary of subscripts for all possible turn items. + :param key_prefix: "namespace" prefix for all keys, should be set for efficient clearing of all data. """ - def __init__(self, path: str): - DBContextStorage.__init__(self, path) + is_concurrent: bool = True + + def __init__( + self, + path: str, + rewrite_existing: bool = False, + partial_read_config: Optional[_SUBSCRIPT_DICT] = None, + key_prefix: str = "chatsky_keys", + ): + DBContextStorage.__init__(self, path, rewrite_existing, partial_read_config) + if not redis_available: install_suggestion = get_protocol_install_suggestion("redis") raise ImportError("`redis` package is missing.\n" + install_suggestion) - self._redis = Redis.from_url(self.full_path) - - @threadsafe_method - async def contains_async(self, key: Hashable) -> bool: - return bool(await self._redis.exists(str(key))) - - @threadsafe_method - async def set_item_async(self, key: Hashable, value: Context): - value = Context.model_validate(value) - await self._redis.set(str(key), value.model_dump_json()) - - @threadsafe_method - async def get_item_async(self, key: Hashable) -> Context: - result = await self._redis.get(str(key)) - if result: - result_dict = json.loads(result.decode("utf-8")) - return Context.model_validate(result_dict) - raise KeyError(f"No entry for key {key}.") - - @threadsafe_method - async def del_item_async(self, key: Hashable): - await self._redis.delete(str(key)) - - @threadsafe_method - async def len_async(self) -> int: - return await self._redis.dbsize() - - @threadsafe_method - async def clear_async(self): - await self._redis.flushdb() + if not bool(key_prefix): + raise ValueError("`key_prefix` parameter shouldn't be empty") + self.database = Redis.from_url(self.full_path) + + self._prefix = key_prefix + self._main_key = f"{key_prefix}:{NameConfig._main_table}" + self._turns_key = f"{key_prefix}:{NameConfig._turns_table}" + + async def _connect(self): + pass + + @staticmethod + def _keys_to_bytes(keys: List[int]) -> List[bytes]: + return [str(f).encode("utf-8") for f in keys] + + @staticmethod + def _bytes_to_keys(keys: List[bytes]) -> List[int]: + return [int(f.decode("utf-8")) for f in keys] + + async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]: + if await self.database.exists(f"{self._main_key}:{ctx_id}"): + cti, ca, ua, msc, fd = await gather( + self.database.hget(f"{self._main_key}:{ctx_id}", NameConfig._current_turn_id_column), + self.database.hget(f"{self._main_key}:{ctx_id}", NameConfig._created_at_column), + self.database.hget(f"{self._main_key}:{ctx_id}", NameConfig._updated_at_column), + self.database.hget(f"{self._main_key}:{ctx_id}", NameConfig._misc_column), + self.database.hget(f"{self._main_key}:{ctx_id}", NameConfig._framework_data_column), + ) + return ContextInfo.model_validate( + {"turn_id": cti, "created_at": ca, "updated_at": ua, "misc": msc, "framework_data": fd} + ) + else: + return None + + async def _update_main_info(self, ctx_id: str, ctx_info: ContextInfo) -> None: + 
ctx_info_dump = ctx_info.model_dump(mode="python") + await gather( + self.database.hset( + f"{self._main_key}:{ctx_id}", NameConfig._current_turn_id_column, str(ctx_info_dump["turn_id"]) + ), + self.database.hset( + f"{self._main_key}:{ctx_id}", NameConfig._created_at_column, str(ctx_info_dump["created_at"]) + ), + self.database.hset( + f"{self._main_key}:{ctx_id}", NameConfig._updated_at_column, str(ctx_info_dump["updated_at"]) + ), + self.database.hset(f"{self._main_key}:{ctx_id}", NameConfig._misc_column, ctx_info_dump["misc"]), + self.database.hset( + f"{self._main_key}:{ctx_id}", NameConfig._framework_data_column, ctx_info_dump["framework_data"] + ), + ) + + async def _delete_context(self, ctx_id: str) -> None: + keys = await self.database.keys(f"{self._prefix}:*:{ctx_id}*") + if len(keys) > 0: + await self.database.delete(*keys) + + async def _load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]: + field_key = f"{self._turns_key}:{ctx_id}:{field_name}" + keys = sorted(await self.database.hkeys(field_key), key=lambda k: int(k), reverse=True) + if isinstance(self._subscripts[field_name], int): + keys = keys[: self._subscripts[field_name]] + elif isinstance(self._subscripts[field_name], Set): + keys = [k for k in keys if k in self._keys_to_bytes(self._subscripts[field_name])] + values = await gather(*[self.database.hget(field_key, k) for k in keys]) + return [(k, v) for k, v in zip(self._bytes_to_keys(keys), values)] + + async def _load_field_keys(self, ctx_id: str, field_name: str) -> List[int]: + return self._bytes_to_keys(await self.database.hkeys(f"{self._turns_key}:{ctx_id}:{field_name}")) + + async def _load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[Tuple[int, bytes]]: + field_key = f"{self._turns_key}:{ctx_id}:{field_name}" + load = [k for k in await self.database.hkeys(field_key) if k in self._keys_to_bytes(keys)] + values = await gather(*[self.database.hget(field_key, k) for k in load]) + return [(k, v) for k, v in zip(self._bytes_to_keys(load), values)] + + async def _update_field_items(self, ctx_id: str, field_name: str, items: List[Tuple[int, Optional[bytes]]]) -> None: + await gather(*[self.database.hset(f"{self._turns_key}:{ctx_id}:{field_name}", str(k), v) for k, v in items]) + + async def _delete_field_keys(self, ctx_id: str, field_name: str, keys: List[int]) -> None: + field_key = f"{self._turns_key}:{ctx_id}:{field_name}" + match = [k for k in await self.database.hkeys(field_key) if k in self._keys_to_bytes(keys)] + if len(match) > 0: + await self.database.hdel(field_key, *match) + + async def _clear_all(self) -> None: + keys = await self.database.keys(f"{self._prefix}:*") + if len(keys) > 0: + await self.database.delete(*keys) diff --git a/chatsky/context_storages/shelve.py b/chatsky/context_storages/shelve.py deleted file mode 100644 index 82fc5ca87..000000000 --- a/chatsky/context_storages/shelve.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -Shelve ------- -The Shelve module provides a shelve-based version of the :py:class:`.DBContextStorage` class. -This class is used to store and retrieve context data in a shelve format. -It allows Chatsky to easily store and retrieve context data in a format that is efficient -for serialization and deserialization and can be easily used in python. - -Shelve is a python library that allows to store and retrieve python objects. 
-It is efficient and fast, but it is not recommended to use it to transfer data across different languages -or platforms because it's not cross-language compatible. -It stores data in a dbm-style format in the file system, which is not as fast as the other serialization -libraries like pickle or JSON. -""" - -import pickle -from shelve import DbfilenameShelf -from typing import Hashable - -from chatsky.core import Context - -from .database import DBContextStorage - - -class ShelveContextStorage(DBContextStorage): - """ - Implements :py:class:`.DBContextStorage` with `shelve` as the driver. - - :param path: Target file URI. Example: `shelve://file.db`. - """ - - def __init__(self, path: str): - DBContextStorage.__init__(self, path) - self.shelve_db = DbfilenameShelf(filename=self.path, protocol=pickle.HIGHEST_PROTOCOL) - - async def get_item_async(self, key: Hashable) -> Context: - return self.shelve_db[str(key)] - - async def set_item_async(self, key: Hashable, value: Context): - self.shelve_db.__setitem__(str(key), value) - - async def del_item_async(self, key: Hashable): - self.shelve_db.__delitem__(str(key)) - - async def contains_async(self, key: Hashable) -> bool: - return self.shelve_db.__contains__(str(key)) - - async def len_async(self) -> int: - return self.shelve_db.__len__() - - async def clear_async(self): - self.shelve_db.clear() diff --git a/chatsky/context_storages/sql.py b/chatsky/context_storages/sql.py index 4fafa9dc5..35e12999d 100644 --- a/chatsky/context_storages/sql.py +++ b/chatsky/context_storages/sql.py @@ -13,18 +13,32 @@ public-domain, SQL database engine. """ +from __future__ import annotations import asyncio -import importlib -import json -from typing import Hashable +from importlib import import_module +from typing import Callable, Collection, List, Optional, Set, Tuple +import logging -from chatsky.core import Context - -from .database import DBContextStorage, threadsafe_method +from .database import ContextInfo, DBContextStorage, _SUBSCRIPT_DICT, NameConfig from .protocol import get_protocol_install_suggestion try: - from sqlalchemy import Table, MetaData, Column, JSON, String, inspect, select, delete, func + from sqlalchemy import ( + Table, + MetaData, + Column, + LargeBinary, + String, + BigInteger, + ForeignKey, + Integer, + Index, + Insert, + inspect, + select, + delete, + event, + ) from sqlalchemy.ext.asyncio import create_async_engine sqlalchemy_available = True @@ -64,129 +78,233 @@ postgres_available = sqlite_available = mysql_available = False -def import_insert_for_dialect(dialect: str): - """ - Imports the insert function into global scope depending on the chosen sqlalchemy dialect. +logger = logging.getLogger(__name__) - :param dialect: Chosen sqlalchemy dialect. 
- """ - global insert - insert = getattr( - importlib.import_module(f"sqlalchemy.dialects.{dialect}"), - "insert", - ) + +def _sqlite_enable_foreign_key(dbapi_con, con_record): + dbapi_con.execute("pragma foreign_keys=ON") + + +def _import_insert_for_dialect(dialect: str) -> Callable[[Table], "Insert"]: + return getattr(import_module(f"sqlalchemy.dialects.{dialect}"), "insert") + + +def _get_upsert_stmt(dialect: str, insert_stmt, columns: Collection[str], unique: Collection[str]): + if dialect == "postgresql" or dialect == "sqlite": + if len(columns) > 0: + update_stmt = insert_stmt.on_conflict_do_update( + index_elements=unique, set_={column: insert_stmt.excluded[column] for column in columns} + ) + else: + update_stmt = insert_stmt.on_conflict_do_nothing() + elif dialect == "mysql": + if len(columns) > 0: + update_stmt = insert_stmt.on_duplicate_key_update( + **{column: insert_stmt.inserted[column] for column in columns} + ) + else: + update_stmt = insert_stmt.prefix_with("IGNORE") + else: + update_stmt = insert_stmt + return update_stmt class SQLContextStorage(DBContextStorage): """ - | SQL-based version of the :py:class:`.DBContextStorage`. - | Compatible with MySQL, Postgresql, Sqlite. + SQL-based version of the :py:class:`.DBContextStorage`. + Compatible with MySQL, Postgresql, Sqlite. + When using Sqlite on a Windows system, keep in mind that you have to use double backslashes '\\' + instead of forward slashes '/' in the file path. + + `MAIN` table is represented by `main` table. + Columns of the table are: `id`, `current_turn_id`, `created_at` `updated_at`, `misc` and `framework_data`. + + `TURNS` table is represented by `turns` table. + Columns of the table are: `id`, `key`, `label`, `request` and `response`. :param path: Standard sqlalchemy URI string. - When using sqlite backend in Windows, keep in mind that you have to use double backslashes '\\' - instead of forward slashes '/' in the file path. - :param table_name: The name of the table to use. - :param custom_driver: If you intend to use some other database driver instead of the recommended ones, - set this parameter to `True` to bypass the import checks. + Examples: `sqlite+aiosqlite://path_to_the_file/file_name`, + `mysql+asyncmy://root:pass@localhost:3306/test`, + `postgresql+asyncpg://postgres:pass@localhost:5430/test`. + :param rewrite_existing: Whether `TURNS` modified locally should be updated in database or not. + :param partial_read_config: Dictionary of subscripts for all possible turn items. + :param table_name_prefix: "namespace" prefix for the two tables created for context storing. + :param database_id_length: Length of context ID column in the database. 
""" - def __init__(self, path: str, table_name: str = "contexts", custom_driver: bool = False): - DBContextStorage.__init__(self, path) - - self._check_availability(custom_driver) + def __init__( + self, + path: str, + rewrite_existing: bool = False, + partial_read_config: Optional[_SUBSCRIPT_DICT] = None, + table_name_prefix: str = "chatsky_table", + database_id_length: int = 255, + ): + DBContextStorage.__init__(self, path, rewrite_existing, partial_read_config) + + self._check_availability() self.engine = create_async_engine(self.full_path, pool_pre_ping=True) self.dialect: str = self.engine.dialect.name + self._INSERT_CALLABLE = _import_insert_for_dialect(self.dialect) - id_column_args = {"primary_key": True} if self.dialect == "sqlite": - id_column_args["sqlite_on_conflict_primary_key"] = "REPLACE" - - self.metadata = MetaData() - self.table = Table( - table_name, - self.metadata, - Column("id", String(36), **id_column_args), - Column("context", JSON), # column for storing serialized contexts + event.listen(self.engine.sync_engine, "connect", _sqlite_enable_foreign_key) + + metadata = MetaData() + self.main_table = Table( + f"{table_name_prefix}_{NameConfig._main_table}", + metadata, + Column(NameConfig._id_column, String(database_id_length), index=True, unique=True, nullable=False), + Column(NameConfig._current_turn_id_column, BigInteger(), nullable=False), + Column(NameConfig._created_at_column, BigInteger(), nullable=False), + Column(NameConfig._updated_at_column, BigInteger(), nullable=False), + Column(NameConfig._misc_column, LargeBinary(), nullable=False), + Column(NameConfig._framework_data_column, LargeBinary(), nullable=False), + ) + self.turns_table = Table( + f"{table_name_prefix}_{NameConfig._turns_table}", + metadata, + Column( + NameConfig._id_column, + String(database_id_length), + ForeignKey(self.main_table.name, NameConfig._id_column), + nullable=False, + ), + Column(NameConfig._key_column, Integer(), nullable=False), + Column(NameConfig._labels_field, LargeBinary(), nullable=True), + Column(NameConfig._requests_field, LargeBinary(), nullable=True), + Column(NameConfig._responses_field, LargeBinary(), nullable=True), + Index(f"{NameConfig._turns_table}_index", NameConfig._id_column, NameConfig._key_column, unique=True), ) - asyncio.run(self._create_self_table()) + @property + def is_concurrent(self) -> bool: + return self.dialect != "sqlite" + + async def _connect(self): + async with self.engine.begin() as conn: + for table in [self.main_table, self.turns_table]: + if not await conn.run_sync(lambda sync_conn: inspect(sync_conn).has_table(table.name)): + logger.debug(f"SQL table created: {table.name}") + await conn.run_sync(table.create, self.engine) + else: + logger.debug(f"SQL table already exists: {table.name}") + + def _check_availability(self): + """ + Chech availability of the specified backend, raise error if not available. + + :param custom_driver: custom driver is requested - no checks will be performed. 
+ """ + if self.full_path.startswith("postgresql") and not postgres_available: + install_suggestion = get_protocol_install_suggestion("postgresql") + raise ImportError("Packages `sqlalchemy` and/or `asyncpg` are missing.\n" + install_suggestion) + elif self.full_path.startswith("mysql") and not mysql_available: + install_suggestion = get_protocol_install_suggestion("mysql") + raise ImportError("Packages `sqlalchemy` and/or `asyncmy` are missing.\n" + install_suggestion) + elif self.full_path.startswith("sqlite") and not sqlite_available: + install_suggestion = get_protocol_install_suggestion("sqlite") + raise ImportError("Package `sqlalchemy` and/or `aiosqlite` is missing.\n" + install_suggestion) + + async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]: + stmt = select(self.main_table).where(self.main_table.c[NameConfig._id_column] == ctx_id) + async with self.engine.begin() as conn: + result = (await conn.execute(stmt)).fetchone() + return ( + None + if result is None + else ContextInfo.model_validate( + { + "turn_id": result[1], + "created_at": result[2], + "updated_at": result[3], + "misc": result[4], + "framework_data": result[5], + } + ) + ) + + async def _update_main_info(self, ctx_id: str, ctx_info: ContextInfo) -> None: + ctx_info_dump = ctx_info.model_dump(mode="python") + insert_stmt = self._INSERT_CALLABLE(self.main_table).values( + { + NameConfig._id_column: ctx_id, + NameConfig._current_turn_id_column: ctx_info_dump["turn_id"], + NameConfig._created_at_column: ctx_info_dump["created_at"], + NameConfig._updated_at_column: ctx_info_dump["updated_at"], + NameConfig._misc_column: ctx_info_dump["misc"], + NameConfig._framework_data_column: ctx_info_dump["framework_data"], + } + ) + update_stmt = _get_upsert_stmt( + self.dialect, + insert_stmt, + [ + NameConfig._updated_at_column, + NameConfig._current_turn_id_column, + NameConfig._misc_column, + NameConfig._framework_data_column, + ], + [NameConfig._id_column], + ) + async with self.engine.begin() as conn: + await conn.execute(update_stmt) - import_insert_for_dialect(self.dialect) + # TODO: use foreign keys instead maybe? 
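For reference, a minimal standalone sketch of the upsert that `_get_upsert_stmt` builds on the SQLite/PostgreSQL branch, using SQLAlchemy's public dialect API (table and column names below are illustrative):

```python
from sqlalchemy import Column, Integer, LargeBinary, MetaData, String, Table
from sqlalchemy.dialects import sqlite
from sqlalchemy.dialects.sqlite import insert

metadata = MetaData()
turns = Table(
    "chatsky_table_turns",
    metadata,
    Column("id", String(255), nullable=False),
    Column("key", Integer(), nullable=False),
    Column("labels", LargeBinary(), nullable=True),
)

insert_stmt = insert(turns).values(id="ctx-1", key=0, labels=b"\x00")
# Same shape as _get_upsert_stmt("sqlite", insert_stmt, ["labels"], ["id", "key"]):
# update the payload column when the (id, key) pair already exists.
upsert_stmt = insert_stmt.on_conflict_do_update(
    index_elements=["id", "key"],
    set_={"labels": insert_stmt.excluded["labels"]},
)
print(upsert_stmt.compile(dialect=sqlite.dialect()))
```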
+ async def _delete_context(self, ctx_id: str) -> None: + async with self.engine.begin() as conn: + await asyncio.gather( + conn.execute(delete(self.main_table).where(self.main_table.c[NameConfig._id_column] == ctx_id)), + conn.execute(delete(self.turns_table).where(self.turns_table.c[NameConfig._id_column] == ctx_id)), + ) - @threadsafe_method - async def set_item_async(self, key: Hashable, value: Context): - value = Context.model_validate(value) - value = json.loads(value.model_dump_json()) + async def _load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]: + stmt = select(self.turns_table.c[NameConfig._key_column], self.turns_table.c[field_name]) + stmt = stmt.where(self.turns_table.c[NameConfig._id_column] == ctx_id) + stmt = stmt.where(self.turns_table.c[field_name] != None) # noqa: E711 + stmt = stmt.order_by(self.turns_table.c[NameConfig._key_column].desc()) + if isinstance(self._subscripts[field_name], int): + stmt = stmt.limit(self._subscripts[field_name]) + elif isinstance(self._subscripts[field_name], Set): + stmt = stmt.where(self.turns_table.c[NameConfig._key_column].in_(self._subscripts[field_name])) + async with self.engine.begin() as conn: + return list((await conn.execute(stmt)).fetchall()) - insert_stmt = insert(self.table).values(id=str(key), context=value) - update_stmt = await self._get_update_stmt(insert_stmt) + async def _load_field_keys(self, ctx_id: str, field_name: str) -> List[int]: + stmt = select(self.turns_table.c[NameConfig._key_column]) + stmt = stmt.where(self.turns_table.c[NameConfig._id_column] == ctx_id) + stmt = stmt.where(self.turns_table.c[field_name] != None) # noqa: E711 + async with self.engine.begin() as conn: + return [k[0] for k in (await conn.execute(stmt)).fetchall()] - async with self.engine.connect() as conn: - await conn.execute(update_stmt) - await conn.commit() - - @threadsafe_method - async def get_item_async(self, key: Hashable) -> Context: - stmt = select(self.table.c.context).where(self.table.c.id == str(key)) - async with self.engine.connect() as conn: - result = await conn.execute(stmt) - row = result.fetchone() - if row: - return Context.model_validate(row[0]) - raise KeyError - - @threadsafe_method - async def del_item_async(self, key: Hashable): - stmt = delete(self.table).where(self.table.c.id == str(key)) - async with self.engine.connect() as conn: - await conn.execute(stmt) - await conn.commit() - - @threadsafe_method - async def contains_async(self, key: Hashable) -> bool: - stmt = select(self.table.c.context).where(self.table.c.id == str(key)) - async with self.engine.connect() as conn: - result = await conn.execute(stmt) - return bool(result.fetchone()) - - @threadsafe_method - async def len_async(self) -> int: - stmt = select(func.count()).select_from(self.table) - async with self.engine.connect() as conn: - result = await conn.execute(stmt) - return result.fetchone()[0] - - @threadsafe_method - async def clear_async(self): - stmt = delete(self.table) - async with self.engine.connect() as conn: - await conn.execute(stmt) - await conn.commit() - - async def _create_self_table(self): + async def _load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[Tuple[int, bytes]]: + stmt = select(self.turns_table.c[NameConfig._key_column], self.turns_table.c[field_name]) + stmt = stmt.where(self.turns_table.c[NameConfig._id_column] == ctx_id) + stmt = stmt.where(self.turns_table.c[NameConfig._key_column].in_(tuple(keys))) + stmt = stmt.where(self.turns_table.c[field_name] != None) # 
noqa: E711 async with self.engine.begin() as conn: - if not await conn.run_sync(lambda sync_conn: inspect(sync_conn).has_table(self.table.name)): - await conn.run_sync(self.table.create, self.engine) + return list((await conn.execute(stmt)).fetchall()) + + async def _update_field_items(self, ctx_id: str, field_name: str, items: List[Tuple[int, Optional[bytes]]]) -> None: + insert_stmt = self._INSERT_CALLABLE(self.turns_table).values( + [ + { + NameConfig._id_column: ctx_id, + NameConfig._key_column: k, + field_name: v, + } + for k, v in items + ] + ) + update_stmt = _get_upsert_stmt( + self.dialect, + insert_stmt, + [field_name], + [NameConfig._id_column, NameConfig._key_column], + ) + async with self.engine.begin() as conn: + await conn.execute(update_stmt) - async def _get_update_stmt(self, insert_stmt): - if self.dialect == "sqlite": - return insert_stmt - elif self.dialect == "mysql": - update_stmt = insert_stmt.on_duplicate_key_update(context=insert_stmt.inserted.context) - else: - update_stmt = insert_stmt.on_conflict_do_update( - index_elements=["id"], set_=dict(context=insert_stmt.excluded.context) - ) - return update_stmt - - def _check_availability(self, custom_driver: bool): - if not custom_driver: - if self.full_path.startswith("postgresql") and not postgres_available: - install_suggestion = get_protocol_install_suggestion("postgresql") - raise ImportError("Packages `sqlalchemy` and/or `asyncpg` are missing.\n" + install_suggestion) - elif self.full_path.startswith("mysql") and not mysql_available: - install_suggestion = get_protocol_install_suggestion("mysql") - raise ImportError("Packages `sqlalchemy` and/or `asyncmy` are missing.\n" + install_suggestion) - elif self.full_path.startswith("sqlite") and not sqlite_available: - install_suggestion = get_protocol_install_suggestion("sqlite") - raise ImportError("Package `sqlalchemy` and/or `aiosqlite` is missing.\n" + install_suggestion) + async def _clear_all(self) -> None: + async with self.engine.begin() as conn: + await asyncio.gather(conn.execute(delete(self.main_table)), conn.execute(delete(self.turns_table))) diff --git a/chatsky/context_storages/ydb.py b/chatsky/context_storages/ydb.py index 82a4465a6..68690b164 100644 --- a/chatsky/context_storages/ydb.py +++ b/chatsky/context_storages/ydb.py @@ -10,20 +10,25 @@ take advantage of the scalability and high-availability features provided by the service. """ -import asyncio -import os -from typing import Hashable +from asyncio import gather +from os.path import join +from typing import Awaitable, Callable, Set, Tuple, List, Optional from urllib.parse import urlsplit - -from chatsky.core import Context - -from .database import DBContextStorage +from .database import ContextInfo, DBContextStorage, _SUBSCRIPT_DICT, NameConfig from .protocol import get_protocol_install_suggestion try: - import ydb - import ydb.aio + from ydb import ( + SerializableReadWrite, + SchemeError, + TableDescription, + Column, + OptionalType, + PrimitiveType, + ) + from ydb.aio import Driver, SessionPool + from ydb.table import Session ydb_available = True except ImportError: @@ -34,207 +39,324 @@ class YDBContextStorage(DBContextStorage): """ Version of the :py:class:`.DBContextStorage` for YDB. - :param path: Standard sqlalchemy URI string. - When using sqlite backend in Windows, keep in mind that you have to use double backslashes '\\' - instead of forward slashes '/' in the file path. - :param table_name: The name of the table to use. + `CONTEXT` table is represented by `contexts` table. 
+ Columns of the table are: `id`, `current_turn_id`, `created_at` `updated_at`, `misc` and `framework_data`. + + `TURNS` table is represented by `turns` table. + olumns of the table are: `id`, `key`, `label`, `request` and `response`. + + :param path: Standard sqlalchemy URI string. One of `grpc` or `grpcs` can be chosen as a protocol. + Example: `grpc://localhost:2134/local`. + NB! Do not forget to provide credentials in environmental variables + or set `YDB_ANONYMOUS_CREDENTIALS` variable to `1`! + :param rewrite_existing: Whether `TURNS` modified locally should be updated in database or not. + :param partial_read_config: Dictionary of subscripts for all possible turn items. + :param table_name_prefix: "namespace" prefix for the two tables created for context storing. + :param timeout: Waiting timeout for the database driver. """ - def __init__(self, path: str, table_name: str = "contexts", timeout=5): - DBContextStorage.__init__(self, path) + _LIMIT_VAR = "limit" + _KEY_VAR = "key" + + is_concurrent: bool = True + + def __init__( + self, + path: str, + rewrite_existing: bool = False, + partial_read_config: Optional[_SUBSCRIPT_DICT] = None, + table_name_prefix: str = "chatsky_table", + timeout: int = 5, + ): + DBContextStorage.__init__(self, path, rewrite_existing, partial_read_config) + protocol, netloc, self.database, _, _ = urlsplit(path) - self.endpoint = "{}://{}".format(protocol, netloc) - self.table_name = table_name if not ydb_available: install_suggestion = get_protocol_install_suggestion("grpc") raise ImportError("`ydb` package is missing.\n" + install_suggestion) - self.driver, self.pool = asyncio.run(_init_drive(timeout, self.endpoint, self.database, self.table_name)) - - async def set_item_async(self, key: Hashable, value: Context): - value = Context.model_validate(value) - - async def callee(session): - query = """ - PRAGMA TablePathPrefix("{}"); - DECLARE $queryId AS Utf8; - DECLARE $queryContext AS Json; - UPSERT INTO {} - ( - id, - context - ) - VALUES - ( - $queryId, - $queryContext - ); - """.format( - self.database, self.table_name - ) - prepared_query = await session.prepare(query) - await session.transaction(ydb.SerializableReadWrite()).execute( - prepared_query, - {"$queryId": str(key), "$queryContext": value.model_dump_json()}, - commit_tx=True, + self.table_prefix = table_name_prefix + self._timeout = timeout + self._endpoint = f"{protocol}://{netloc}" + + async def _connect(self) -> None: + self._driver = Driver(endpoint=self._endpoint, database=self.database) + client_settings = self._driver.table_client._table_client_settings.with_allow_truncated_result(True) + self._driver.table_client._table_client_settings = client_settings + await self._driver.wait(fail_fast=True, timeout=self._timeout) + + self.pool = SessionPool(self._driver, size=10) + + self.main_table = f"{self.table_prefix}_{NameConfig._main_table}" + if not await self._does_table_exist(self.main_table): + await self._create_main_table(self.main_table) + + self.turns_table = f"{self.table_prefix}_{NameConfig._turns_table}" + if not await self._does_table_exist(self.turns_table): + await self._create_turns_table(self.turns_table) + + async def _does_table_exist(self, table_name: str) -> bool: + async def callee(session: Session) -> None: + await session.describe_table(join(self.database, table_name)) + + try: + await self.pool.retry_operation(callee) + return True + except SchemeError: + return False + + async def _create_main_table(self, table_name: str) -> None: + async def callee(session: Session) 
-> None: + await session.create_table( + "/".join([self.database, table_name]), + TableDescription() + .with_column(Column(NameConfig._id_column, PrimitiveType.Utf8)) + .with_column(Column(NameConfig._current_turn_id_column, PrimitiveType.Uint64)) + .with_column(Column(NameConfig._created_at_column, PrimitiveType.Uint64)) + .with_column(Column(NameConfig._updated_at_column, PrimitiveType.Uint64)) + .with_column(Column(NameConfig._misc_column, PrimitiveType.String)) + .with_column(Column(NameConfig._framework_data_column, PrimitiveType.String)) + .with_primary_key(NameConfig._id_column), ) - return await self.pool.retry_operation(callee) - - async def get_item_async(self, key: Hashable) -> Context: - async def callee(session): - query = """ - PRAGMA TablePathPrefix("{}"); - DECLARE $queryId AS Utf8; - SELECT - id, - context - FROM {} - WHERE id = $queryId; - """.format( - self.database, self.table_name + await self.pool.retry_operation(callee) + + async def _create_turns_table(self, table_name: str) -> None: + async def callee(session: Session) -> None: + await session.create_table( + "/".join([self.database, table_name]), + TableDescription() + .with_column(Column(NameConfig._id_column, PrimitiveType.Utf8)) + .with_column(Column(NameConfig._key_column, PrimitiveType.Uint32)) + .with_column(Column(NameConfig._labels_field, OptionalType(PrimitiveType.String))) + .with_column(Column(NameConfig._requests_field, OptionalType(PrimitiveType.String))) + .with_column(Column(NameConfig._responses_field, OptionalType(PrimitiveType.String))) + .with_primary_keys(NameConfig._id_column, NameConfig._key_column), ) - prepared_query = await session.prepare(query) - result_sets = await session.transaction(ydb.SerializableReadWrite()).execute( - prepared_query, + await self.pool.retry_operation(callee) + + async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]: + async def callee(session: Session) -> Optional[ContextInfo]: + query = f""" + PRAGMA TablePathPrefix("{self.database}"); + DECLARE ${NameConfig._id_column} AS Utf8; + SELECT {NameConfig._current_turn_id_column}, {NameConfig._created_at_column}, {NameConfig._updated_at_column}, {NameConfig._misc_column}, {NameConfig._framework_data_column} + FROM {self.main_table} + WHERE {NameConfig._id_column} = ${NameConfig._id_column}; + """ # noqa: E501 + result_sets = await session.transaction(SerializableReadWrite()).execute( + await session.prepare(query), { - "$queryId": str(key), + f"${NameConfig._id_column}": ctx_id, }, commit_tx=True, ) - if result_sets[0].rows: - return Context.model_validate_json(result_sets[0].rows[0].context) - else: - raise KeyError + return ( + ContextInfo.model_validate( + { + "turn_id": result_sets[0].rows[0][NameConfig._current_turn_id_column], + "created_at": result_sets[0].rows[0][NameConfig._created_at_column], + "updated_at": result_sets[0].rows[0][NameConfig._updated_at_column], + "misc": result_sets[0].rows[0][NameConfig._misc_column], + "framework_data": result_sets[0].rows[0][NameConfig._framework_data_column], + } + ) + if len(result_sets[0].rows) > 0 + else None + ) return await self.pool.retry_operation(callee) - async def del_item_async(self, key: Hashable): - async def callee(session): - query = """ - PRAGMA TablePathPrefix("{}"); - DECLARE $queryId AS Utf8; - DELETE - FROM {} - WHERE - id = $queryId - ; - """.format( - self.database, self.table_name - ) - prepared_query = await session.prepare(query) - - await session.transaction(ydb.SerializableReadWrite()).execute( - prepared_query, - {"$queryId": 
str(key)}, + async def _update_main_info(self, ctx_id: str, ctx_info: ContextInfo) -> None: + async def callee(session: Session) -> None: + ctx_info_dump = ctx_info.model_dump(mode="python") + query = f""" + PRAGMA TablePathPrefix("{self.database}"); + DECLARE ${NameConfig._id_column} AS Utf8; + DECLARE ${NameConfig._current_turn_id_column} AS Uint64; + DECLARE ${NameConfig._created_at_column} AS Uint64; + DECLARE ${NameConfig._updated_at_column} AS Uint64; + DECLARE ${NameConfig._misc_column} AS String; + DECLARE ${NameConfig._framework_data_column} AS String; + UPSERT INTO {self.main_table} ({NameConfig._id_column}, {NameConfig._current_turn_id_column}, {NameConfig._created_at_column}, {NameConfig._updated_at_column}, {NameConfig._misc_column}, {NameConfig._framework_data_column}) + VALUES (${NameConfig._id_column}, ${NameConfig._current_turn_id_column}, ${NameConfig._created_at_column}, ${NameConfig._updated_at_column}, ${NameConfig._misc_column}, ${NameConfig._framework_data_column}); + """ # noqa: E501 + await session.transaction(SerializableReadWrite()).execute( + await session.prepare(query), + { + f"${NameConfig._id_column}": ctx_id, + f"${NameConfig._current_turn_id_column}": ctx_info_dump["turn_id"], + f"${NameConfig._created_at_column}": ctx_info_dump["created_at"], + f"${NameConfig._updated_at_column}": ctx_info_dump["updated_at"], + f"${NameConfig._misc_column}": ctx_info_dump["misc"], + f"${NameConfig._framework_data_column}": ctx_info_dump["framework_data"], + }, commit_tx=True, ) - return await self.pool.retry_operation(callee) + await self.pool.retry_operation(callee) + + async def _delete_context(self, ctx_id: str) -> None: + def construct_callee(table_name: str) -> Callable[[Session], Awaitable[None]]: + async def callee(session: Session) -> None: + query = f""" + PRAGMA TablePathPrefix("{self.database}"); + DECLARE ${NameConfig._id_column} AS Utf8; + DELETE FROM {table_name} + WHERE {NameConfig._id_column} = ${NameConfig._id_column}; + """ # noqa: E501 + await session.transaction(SerializableReadWrite()).execute( + await session.prepare(query), + { + f"${NameConfig._id_column}": ctx_id, + }, + commit_tx=True, + ) - async def contains_async(self, key: Hashable) -> bool: - async def callee(session): - # new transaction in serializable read write mode - # if query successfully completed you will get result sets. 
- # otherwise exception will be raised - query = """ - PRAGMA TablePathPrefix("{}"); - DECLARE $queryId AS Utf8; - SELECT - id, - context - FROM {} - WHERE id = $queryId; - """.format( - self.database, self.table_name - ) - prepared_query = await session.prepare(query) + return callee + + await gather( + self.pool.retry_operation(construct_callee(self.main_table)), + self.pool.retry_operation(construct_callee(self.turns_table)), + ) - result_sets = await session.transaction(ydb.SerializableReadWrite()).execute( - prepared_query, + async def _load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]: + async def callee(session: Session) -> List[Tuple[int, bytes]]: + declare, prepare, limit, key = list(), dict(), "", "" + if isinstance(self._subscripts[field_name], int): + declare += [f"DECLARE ${self._LIMIT_VAR} AS Uint64;"] + prepare.update({f"${self._LIMIT_VAR}": self._subscripts[field_name]}) + limit = f"LIMIT ${self._LIMIT_VAR}" + elif isinstance(self._subscripts[field_name], Set): + values = list() + for i, k in enumerate(self._subscripts[field_name]): + declare += [f"DECLARE ${self._KEY_VAR}_{i} AS Utf8;"] + prepare.update({f"${self._KEY_VAR}_{i}": k}) + values += [f"${self._KEY_VAR}_{i}"] + key = f"AND {self._KEY_VAR} IN ({', '.join(values)})" + query = f""" + PRAGMA TablePathPrefix("{self.database}"); + DECLARE ${NameConfig._id_column} AS Utf8; + {" ".join(declare)} + SELECT {NameConfig._key_column}, {field_name} + FROM {self.turns_table} + WHERE {NameConfig._id_column} = ${NameConfig._id_column} AND {field_name} IS NOT NULL {key} + ORDER BY {NameConfig._key_column} DESC {limit}; + """ # noqa: E501 + result_sets = await session.transaction(SerializableReadWrite()).execute( + await session.prepare(query), { - "$queryId": str(key), + f"${NameConfig._id_column}": ctx_id, + **prepare, }, commit_tx=True, ) - return len(result_sets[0].rows) > 0 + return ( + [(e[NameConfig._key_column], e[field_name]) for e in result_sets[0].rows] + if len(result_sets[0].rows) > 0 + else list() + ) return await self.pool.retry_operation(callee) - async def len_async(self) -> int: - async def callee(session): - query = """ - PRAGMA TablePathPrefix("{}"); - SELECT - COUNT(*) as cnt - FROM {} - """.format( - self.database, self.table_name - ) - prepared_query = await session.prepare(query) - - result_sets = await session.transaction(ydb.SerializableReadWrite()).execute( - prepared_query, + async def _load_field_keys(self, ctx_id: str, field_name: str) -> List[int]: + async def callee(session: Session) -> List[int]: + query = f""" + PRAGMA TablePathPrefix("{self.database}"); + DECLARE ${NameConfig._id_column} AS Utf8; + SELECT {NameConfig._key_column} + FROM {self.turns_table} + WHERE {NameConfig._id_column} = ${NameConfig._id_column} AND {field_name} IS NOT NULL; + """ # noqa: E501 + result_sets = await session.transaction(SerializableReadWrite()).execute( + await session.prepare(query), + { + f"${NameConfig._id_column}": ctx_id, + }, commit_tx=True, ) - return result_sets[0].rows[0].cnt + return [e[NameConfig._key_column] for e in result_sets[0].rows] if len(result_sets[0].rows) > 0 else list() return await self.pool.retry_operation(callee) - async def clear_async(self): - async def callee(session): - query = """ - PRAGMA TablePathPrefix("{}"); - DECLARE $queryId AS Utf8; - DELETE - FROM {} - ; - """.format( - self.database, self.table_name - ) - prepared_query = await session.prepare(query) - - await session.transaction(ydb.SerializableReadWrite()).execute( - prepared_query, - {}, + 
async def _load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[Tuple[int, bytes]]: + async def callee(session: Session) -> List[Tuple[int, bytes]]: + declare, prepare = list(), dict() + for i, k in enumerate(keys): + declare += [f"DECLARE ${self._KEY_VAR}_{i} AS Uint32;"] + prepare.update({f"${self._KEY_VAR}_{i}": k}) + query = f""" + PRAGMA TablePathPrefix("{self.database}"); + DECLARE ${NameConfig._id_column} AS Utf8; + {" ".join(declare)} + SELECT {NameConfig._key_column}, {field_name} + FROM {self.turns_table} + WHERE {NameConfig._id_column} = ${NameConfig._id_column} AND {field_name} IS NOT NULL + AND {NameConfig._key_column} IN ({", ".join(prepare.keys())}); + """ # noqa: E501 + result_sets = await session.transaction(SerializableReadWrite()).execute( + await session.prepare(query), + { + f"${NameConfig._id_column}": ctx_id, + **prepare, + }, commit_tx=True, ) + return ( + [(e[NameConfig._key_column], e[field_name]) for e in result_sets[0].rows] + if len(result_sets[0].rows) > 0 + else list() + ) return await self.pool.retry_operation(callee) + async def _update_field_items(self, ctx_id: str, field_name: str, items: List[Tuple[int, Optional[bytes]]]) -> None: + async def callee(session: Session) -> None: + declare, prepare, values = list(), dict(), list() + for i, (k, v) in enumerate(items): + declare += [f"DECLARE ${self._KEY_VAR}_{i} AS Uint32;"] + prepare.update({f"${self._KEY_VAR}_{i}": k}) + if v is not None: + declare += [f"DECLARE ${field_name}_{i} AS String;"] + prepare.update({f"${field_name}_{i}": v}) + value_param = f"${field_name}_{i}" + else: + value_param = "NULL" + values += [f"(${NameConfig._id_column}, ${self._KEY_VAR}_{i}, {value_param})"] + query = f""" + PRAGMA TablePathPrefix("{self.database}"); + DECLARE ${NameConfig._id_column} AS Utf8; + {" ".join(declare)} + UPSERT INTO {self.turns_table} ({NameConfig._id_column}, {NameConfig._key_column}, {field_name}) + VALUES {", ".join(values)}; + """ # noqa: E501 + + await session.transaction(SerializableReadWrite()).execute( + await session.prepare(query), + { + f"${NameConfig._id_column}": ctx_id, + **prepare, + }, + commit_tx=True, + ) -async def _init_drive(timeout: int, endpoint: str, database: str, table_name: str): - driver = ydb.aio.Driver(endpoint=endpoint, database=database) - await driver.wait(fail_fast=True, timeout=timeout) - - pool = ydb.aio.SessionPool(driver, size=10) - - if not await _is_table_exists(pool, database, table_name): # create table if it does not exist - await _create_table(pool, database, table_name) - return driver, pool - - -async def _is_table_exists(pool, path, table_name) -> bool: - try: - - async def callee(session): - await session.describe_table(os.path.join(path, table_name)) - - await pool.retry_operation(callee) - return True - except ydb.SchemeError: - return False + await self.pool.retry_operation(callee) + + async def _clear_all(self) -> None: + def construct_callee(table_name: str) -> Callable[[Session], Awaitable[None]]: + async def callee(session: Session) -> None: + query = f""" + PRAGMA TablePathPrefix("{self.database}"); + DELETE FROM {table_name}; + """ # noqa: E501 + await session.transaction(SerializableReadWrite()).execute( + await session.prepare(query), dict(), commit_tx=True + ) + return callee -async def _create_table(pool, path, table_name): - async def callee(session): - await session.create_table( - "/".join([path, table_name]), - ydb.TableDescription() - .with_column(ydb.Column("id", ydb.OptionalType(ydb.PrimitiveType.Utf8))) - 
.with_column(ydb.Column("context", ydb.OptionalType(ydb.PrimitiveType.Json))) - .with_primary_key("id"), + await gather( + self.pool.retry_operation(construct_callee(self.main_table)), + self.pool.retry_operation(construct_callee(self.turns_table)), ) - - return await pool.retry_operation(callee) diff --git a/chatsky/core/context.py b/chatsky/core/context.py index 59a9899a2..8adbee610 100644 --- a/chatsky/core/context.py +++ b/chatsky/core/context.py @@ -17,16 +17,19 @@ """ from __future__ import annotations +from asyncio import Event, gather +from uuid import uuid4 +from time import time_ns +from typing import Any, Callable, Iterable, Optional, Dict, TYPE_CHECKING, Tuple, Union import logging -import asyncio -from uuid import UUID, uuid4 -from typing import Any, Optional, Union, Dict, TYPE_CHECKING -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr, TypeAdapter, model_validator -from chatsky.core.message import Message, MessageInitTypes +from chatsky.context_storages.database import DBContextStorage, ContextInfo, NameConfig +from chatsky.core.message import Message from chatsky.slots.slots import SlotManager -from chatsky.core.node_label import AbsoluteNodeLabel, AbsoluteNodeLabelInitTypes +from chatsky.core.node_label import AbsoluteNodeLabel +from chatsky.core.ctx_dict import LabelContextDict, MessageContextDict if TYPE_CHECKING: from chatsky.core.service import ComponentExecutionState @@ -36,20 +39,6 @@ logger = logging.getLogger(__name__) -def get_last_index(dictionary: dict) -> int: - """ - Obtain the last index from the `dictionary`. - - :param dictionary: Dictionary with unsorted keys. - :return: Last index from the `dictionary`. - :raises ValueError: If the dictionary is empty. - """ - if len(dictionary) == 0: - raise ValueError("Dictionary is empty.") - indices = list(dictionary) - return max(indices) - - class ContextError(Exception): """Raised when context methods are not used correctly.""" @@ -60,7 +49,7 @@ class ServiceState(BaseModel, arbitrary_types_allowed=True): :py:class:`.ComponentExecutionState` of this pipeline service. Cleared at the end of every turn. """ - finished_event: asyncio.Event = Field(default_factory=asyncio.Event) + finished_event: Event = Field(default_factory=Event) """ Asyncio `Event` which can be awaited until this service finishes. Cleared at the end of every turn. @@ -98,44 +87,55 @@ class Context(BaseModel): A structure that is used to store data about the context of a dialog. """ - id: Union[UUID, int, str] = Field(default_factory=uuid4) + id: str = Field(default_factory=lambda: str(uuid4()), exclude=True, frozen=True) """ - ``id`` is the unique context identifier. By default, randomly generated using ``uuid4``. - ``id`` can be used to trace the user behavior, e.g while collecting the statistical data. + `id` is the unique context identifier. By default, randomly generated using `uuid4` is used. """ - labels: Dict[int, AbsoluteNodeLabel] = Field(default_factory=dict) + _created_at: int = PrivateAttr(default_factory=time_ns) """ - ``labels`` stores the history of labels for all passed nodes. - - - key - ``id`` of the turn. - - value - ``label`` of this turn. - - Start label is stored at key ``0``. - IDs go up by ``1`` after that. + Timestamp when the context was **first time saved to database**. + It is set (and managed) by :py:class:`~chatsky.context_storages.DBContextStorage`. 
""" - requests: Dict[int, Message] = Field(default_factory=dict) + _updated_at: int = PrivateAttr(default_factory=time_ns) """ - ``requests`` stores the history of all requests received by the pipeline. - - - key - ``id`` of the turn. - - value - ``request`` of this turn. + Timestamp when the context was **last time saved to database**. + It is set (and managed) by :py:class:`~chatsky.context_storages.DBContextStorage`. + """ + current_turn_id: int = Field(default=0) + """ + Current turn number, specifies the last turn number, + that is also the last turn available in `labels`, `requests`, and `responses`. + """ + labels: LabelContextDict = Field(default_factory=LabelContextDict) + """ + `labels` stores dialog labels. + A new label is stored in the dictionary on every turn, the keys are consecutive integers. + The first ever (initial) has key `0`. - First request is stored at key ``1``. - IDs go up by ``1`` after that. + - key - Label identification numbers. + - value - Label data: `AbsoluteNodeLabel`. """ - responses: Dict[int, Message] = Field(default_factory=dict) + requests: MessageContextDict = Field(default_factory=MessageContextDict) """ - ``responses`` stores the history of all responses produced by the pipeline. + `requests` stores dialog requests. + A new request is stored in the dictionary on every turn, the keys are consecutive integers. + The first ever (initial) has key `1`. - - key - ``id`` of the turn. - - value - ``response`` of this turn. + - key - Request identification numbers. + - value - Request data: `Message`. + """ + responses: MessageContextDict = Field(default_factory=MessageContextDict) + """ + `responses` stores dialog responses. + A new response is stored in the dictionary on every turn, the keys are consecutive integers. + The first ever (initial) has key `1`. - First response is stored at key ``1``. - IDs go up by ``1`` after that. + - key - Response identification numbers. + - value - Response data: `Message`. """ misc: Dict[str, Any] = Field(default_factory=dict) """ - ``misc`` stores any custom data. The framework doesn't use this dictionary, + `misc` stores any custom data. The framework doesn't use this dictionary, so storage of any data won't reflect on the work of the internal Chatsky functions. - key - Arbitrary data name. @@ -146,6 +146,7 @@ class Context(BaseModel): This attribute is used for storing custom data required for pipeline execution. It is meant to be used by the framework only. Accessing it may result in pipeline breakage. """ + _storage: Optional[DBContextStorage] = PrivateAttr(None) origin_interface: Optional[str] = Field(default=None) """ @@ -153,92 +154,123 @@ class Context(BaseModel): """ @classmethod - def init(cls, start_label: AbsoluteNodeLabelInitTypes, id: Optional[Union[UUID, int, str]] = None): - """Initialize new context from ``start_label`` and, optionally, context ``id``.""" - init_kwargs = { - "labels": {0: AbsoluteNodeLabel.model_validate(start_label)}, - } - if id is None: - return cls(**init_kwargs) - else: - return cls(**init_kwargs, id=id) - - def add_request(self, request: MessageInitTypes): + async def connected( + cls, storage: DBContextStorage, start_label: Optional[AbsoluteNodeLabel] = None, id: Optional[str] = None + ) -> Context: """ - Add a new ``request`` to the context. + Create context **connected** to the given database storage. + If context ID is given, the corresponding context is loaded from the database. 
+ If the context does not exist in database or ID is `None`, a new context with new ID is created. + A connected context can be later stored in the database. + + :param storage: context storage to connect to. + :param start_label: new context start label (will be set only if the context is created). + :param id: context ID. + :return: context, connected to the database. """ - request_message = Message.model_validate(request) - if len(self.requests) == 0: - self.requests[1] = request_message - if request_message.origin is not None: - self.origin_interface = request_message.origin.interface - else: - last_index = get_last_index(self.requests) - self.requests[last_index + 1] = request_message - def add_response(self, response: MessageInitTypes): - """ - Add a new ``response`` to the context. - """ - response_message = Message.model_validate(response) - if len(self.responses) == 0: - self.responses[1] = response_message + if id is None: + uid = str(uuid4()) + logger.debug(f"Disconnected context created with uid: {uid}") + instance = cls(id=uid) + instance.requests = await MessageContextDict.new(storage, uid, NameConfig._requests_field) + instance.responses = await MessageContextDict.new(storage, uid, NameConfig._responses_field) + instance.labels = await LabelContextDict.new(storage, uid, NameConfig._labels_field) + await instance.labels.update({0: start_label}) + instance._storage = storage + return instance else: - last_index = get_last_index(self.responses) - self.responses[last_index + 1] = response_message - - def add_label(self, label: AbsoluteNodeLabelInitTypes): + if not isinstance(id, str): + logger.warning(f"Id is not a string: {id}. Converting to string.") + id = str(id) + logger.debug(f"Connected context created with uid: {id}") + main, labels, requests, responses = await gather( + storage.load_main_info(id), + LabelContextDict.connected(storage, id, NameConfig._labels_field), + MessageContextDict.connected(storage, id, NameConfig._requests_field), + MessageContextDict.connected(storage, id, NameConfig._responses_field), + ) + if main is None: + crt_at = upd_at = time_ns() + turn_id = 0 + misc = dict() + fw_data = FrameworkData() + labels[0] = start_label + else: + turn_id = main.turn_id + crt_at = main.created_at + upd_at = main.updated_at + misc = main.misc + fw_data = main.framework_data + logger.debug(f"Context loaded with turns number: {len(labels)}") + instance = cls( + id=id, + current_turn_id=turn_id, + labels=labels, + requests=requests, + responses=responses, + misc=misc, + framework_data=fw_data, + ) + instance._created_at, instance._updated_at, instance._storage = crt_at, upd_at, storage + return instance + + async def delete(self) -> None: """ - Add a new :py:class:`~.AbsoluteNodeLabel` to the context. - - :raises ContextError: If :py:attr:`labels` is empty. + Delete connected context from the context storage and disconnect it. + Throw an error if the context is not connected. + No local context fields will be affected. + If the context is not connected, throw a runtime error. """ - label = AbsoluteNodeLabel.model_validate(label) - if len(self.labels) == 0: - raise ContextError("Labels are empty. 
Use `Context.init` to initialize context with labels.") - last_index = get_last_index(self.labels) - self.labels[last_index + 1] = label + + if self._storage is not None: + await self._storage.delete_context(self.id) + self._storage = None + else: + raise RuntimeError(f"{type(self).__name__} is not attached to any context storage!") @property def last_label(self) -> AbsoluteNodeLabel: """ - Return the last :py:class:`~.AbsoluteNodeLabel` of - the :py:class:`~.Context`. - - :raises ContextError: If :py:attr:`labels` is empty. + Receive last turn label. + Throw an error if no labels are present or the last label is absent. + :return: The last turn label. """ + if len(self.labels) == 0: - raise ContextError("Labels are empty. Use `Context.init` to initialize context with labels.") - last_index = get_last_index(self.labels) - return self.labels[last_index] + raise ContextError("Labels are empty.") + return self.labels._items[self.labels.keys()[-1]] @property - def last_response(self) -> Optional[Message]: + def last_response(self) -> Message: """ - Return the last response of the current :py:class:`~.Context`. - Return ``None`` if no responses have been added yet. + Receive last turn response. + Throw an error if no responses are present or the last response is absent. + :return: The last turn response. """ + if len(self.responses) == 0: - return None - last_index = get_last_index(self.responses) - response = self.responses[last_index] - return response + raise ContextError("Responses are empty.") + return self.responses._items[self.responses.keys()[-1]] @property def last_request(self) -> Message: """ - Return the last request of the current :py:class:`~.Context`. - - :raises ContextError: If :py:attr:`responses` is empty. + Receive last turn request. + Throw an error if no requests are present or the last request is absent. + :return: The last turn request. """ + if len(self.requests) == 0: - raise ContextError("No requests have been added.") - last_index = get_last_index(self.requests) - return self.requests[last_index] + raise ContextError("Requests are empty.") + return self.requests._items[self.requests.keys()[-1]] @property def pipeline(self) -> Pipeline: - """Return :py:attr:`.FrameworkData.pipeline`.""" + """ + Return :py:attr:`.FrameworkData.pipeline`. + """ + pipeline = self.framework_data.pipeline if pipeline is None: raise ContextError("Pipeline is not set.") @@ -246,8 +278,95 @@ def pipeline(self) -> Pipeline: @property def current_node(self) -> Node: - """Return :py:attr:`.FrameworkData.current_node`.""" + """ + Return :py:attr:`.FrameworkData.current_node`. + """ + node = self.framework_data.current_node if node is None: raise ContextError("Current node is not set.") return node + + async def turns(self, key: Union[int, slice]) -> Iterable[Tuple[AbsoluteNodeLabel, Message, Message]]: + """ + Get one or more nodes, requests and responses sharing common keys simultaneously. + Acts just like context dict `get` method, but queries all three dicts at the same time asinchronously. + :param key: Context dict key that will be queried from `labels`, `requests` and `responses`. + :return: Tuples of (`label`, `request`, `response`), sharing a common key. 
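A hypothetical usage sketch, assuming `ctx` is a context with several stored turns:

```python
from chatsky.core import Context

async def print_recent_turns(ctx: Context) -> None:
    # The last three turns at once; each element is a (label, request, response)
    # tuple, with None for positions that were never written.
    for label, request, response in await ctx.turns(slice(-3, None)):
        print(label, request, response)

    # An integer key yields an iterable with exactly one tuple.
    (label, request, response), = await ctx.turns(-1)
```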
+ """ + + turn_ids = range(self.current_turn_id + 1)[key] + turn_ids = turn_ids if isinstance(key, slice) else [turn_ids] + context_dicts = (self.labels, self.requests, self.responses) + turns_lists = await gather(*[gather(*[ctd.get(ti, None) for ti in turn_ids]) for ctd in context_dicts]) + return zip(*turns_lists) + + def __eq__(self, value: object) -> bool: + if isinstance(value, Context): + return ( + self.id == value.id + and self.current_turn_id == value.current_turn_id + and self.labels == value.labels + and self.requests == value.requests + and self.responses == value.responses + and self.misc == value.misc + and self.framework_data == value.framework_data + and self._storage == value._storage + ) + else: + return False + + @model_validator(mode="wrap") + def _validate_model(value: Any, handler: Callable[[Any], "Context"], _) -> "Context": + if isinstance(value, Context): + return value + elif isinstance(value, Dict): + instance = handler(value) + labels_obj = value.get("labels", dict()) + if isinstance(labels_obj, Dict): + labels_obj = TypeAdapter(Dict[int, AbsoluteNodeLabel]).validate_python(labels_obj) + instance.labels = LabelContextDict.model_validate(labels_obj) + instance.labels._ctx_id = instance.id + requests_obj = value.get("requests", dict()) + if isinstance(requests_obj, Dict): + requests_obj = TypeAdapter(Dict[int, Message]).validate_python(requests_obj) + instance.requests = MessageContextDict.model_validate(requests_obj) + instance.requests._ctx_id = instance.id + responses_obj = value.get("responses", dict()) + if isinstance(responses_obj, Dict): + responses_obj = TypeAdapter(Dict[int, Message]).validate_python(responses_obj) + instance.responses = MessageContextDict.model_validate(responses_obj) + instance.responses._ctx_id = instance.id + return instance + else: + raise ValueError(f"Unknown type of Context value: {type(value).__name__}!") + + async def store(self) -> None: + """ + Store connected context in the context storage. + Depending on the context storage settings ("rewrite_existing" flag in particular), + either only write new and deleted values or also modify the changed ones. + All the context storage tables are updated asynchronously and simultaneously. 
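A minimal end-to-end sketch of this lifecycle, using the in-memory storage to stay self-contained (labels and message texts are illustrative):

```python
import asyncio

from chatsky.context_storages import MemoryContextStorage
from chatsky.core import AbsoluteNodeLabel, Context, Message

async def main():
    storage = MemoryContextStorage()
    await storage.connect()

    # Create a new context bound to the storage (pass `id=...` to load an existing one).
    ctx = await Context.connected(storage, AbsoluteNodeLabel(flow_name="flow", node_name="node"))

    ctx.current_turn_id += 1
    ctx.labels[ctx.current_turn_id] = AbsoluteNodeLabel(flow_name="flow", node_name="node")
    ctx.requests[ctx.current_turn_id] = Message(text="hi")
    ctx.responses[ctx.current_turn_id] = Message(text="hello")

    await ctx.store()   # writes main info, labels, requests and responses concurrently
    await ctx.delete()  # removes the context from the storage and disconnects it

asyncio.run(main())
```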
+ """ + + if self._storage is not None: + logger.debug(f"Storing context: {self.id}...") + self._updated_at = time_ns() + await gather( + self._storage.update_main_info( + self.id, + ContextInfo( + turn_id=self.current_turn_id, + created_at=self._created_at, + updated_at=self._updated_at, + misc=self.misc, + framework_data=self.framework_data, + ), + ), + self.labels.store(), + self.requests.store(), + self.responses.store(), + ) + logger.debug(f"Context stored: {self.id}") + else: + raise RuntimeError(f"{type(self).__name__} is not attached to any context storage!") diff --git a/chatsky/core/ctx_dict.py b/chatsky/core/ctx_dict.py new file mode 100644 index 000000000..0fd79705a --- /dev/null +++ b/chatsky/core/ctx_dict.py @@ -0,0 +1,389 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from asyncio import gather +from hashlib import sha256 +import logging +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + overload, + TYPE_CHECKING, +) + +from pydantic import BaseModel, PrivateAttr, TypeAdapter, model_serializer, model_validator + +from chatsky.core.message import Message +from chatsky.core.node_label import AbsoluteNodeLabel +from chatsky.utils.logging import collapse_num_list + +if TYPE_CHECKING: + from chatsky.context_storages.database import DBContextStorage + +K = TypeVar("K", bound=int) +V = TypeVar("V") + +logger = logging.getLogger(__name__) + + +def _get_hash(string: bytes) -> bytes: + return sha256(string).digest() + + +class ContextDict(ABC, BaseModel, Generic[K, V]): + """ + Dictionary-like structure for storing different dialog types in a context storage. + It holds all the possible keys, but may not store all the values locally. + Some of them might be loaded lazily upon querying. + """ + + _items: Dict[K, V] = PrivateAttr(default_factory=dict) + """ + Already loaded from storage items collection. + """ + + _hashes: Dict[K, int] = PrivateAttr(default_factory=dict) + """ + Hashes of the loaded items (as they were upon loading), only populated if `rewrite_existing` flag is enabled. + """ + + _keys: Set[K] = PrivateAttr(default_factory=set) + """ + All the item keys available in the storage. + """ + + _added: Set[K] = PrivateAttr(default_factory=set) + """ + Keys added localy (need to be synchronized with the storage). + """ + + _removed: Set[K] = PrivateAttr(default_factory=set) + """ + Keys removed localy (need to be synchronized with the storage). + """ + + _storage: Optional[DBContextStorage] = PrivateAttr(None) + """ + Context storage for item synchronization. + """ + + _ctx_id: str = PrivateAttr(default_factory=str) + """ + Corresponding context ID. + """ + + _field_name: str = PrivateAttr(default_factory=str) + """ + Name of the field that is represented by the given dict. + """ + + @property + @abstractmethod + def _value_type(self) -> TypeAdapter[Type[V]]: + raise NotImplementedError + + @classmethod + async def new(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict[K, V]": + """ + Create a new context dict, without connecting it to the context storage. + No keys or items will be loaded, but any newly added items will be available for synchronization. + Should be used when we are *sure* that context with given ID does not exist in the storage. + + :param storage: Context storage, where the new items will be added. + :param id: Newly created context ID. + :param field: Current dict field name. 
+ :return: New "disconnected" context dict. + """ + + instance = cls() + logger.debug(f"Disconnected context dict created for id {id} and field name: {field}") + instance._ctx_id = id + instance._field_name = field + instance._storage = storage + return instance + + @classmethod + async def connected(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict[K, V]": + """ + Create a new context dict, connecting it to the context storage. + All the keys and some items will be loaded, all the other items will be available for synchronization. + Also hashes will be calculated for the initially loaded items for modification tracking. + + :param storage: Context storage, keeping the current context. + :param id: Newly created context ID. + :param field: Current dict field name. + :return: New "connected" context dict. + """ + + logger.debug(f"Connected context dict created for {id}, {field}") + keys, items = await gather(storage.load_field_keys(id, field), storage.load_field_latest(id, field)) + val_key_items = [(k, v) for k, v in items if v is not None] + logger.debug(f"Context dict for {id}, {field} loaded: {collapse_num_list(keys)}") + instance = cls() + instance._storage = storage + instance._ctx_id = id + instance._field_name = field + instance._keys = set(keys) + instance._items = {k: instance._value_type.validate_json(v) for k, v in val_key_items} + instance._hashes = {k: _get_hash(v) for k, v in val_key_items} if storage.rewrite_existing else dict() + return instance + + async def _load_items(self, keys: List[K]) -> None: + """ + Load items for the given keys from the connected context storage. + Update the `_items` and `_hashes` fields if necessary. + NB! If not all the requested items are available, only the successfully loaded will be updated and no error will be raised. + + :param keys: The requested key array. + """ + + logger.debug( + f"Context dict for {self._ctx_id}, {self._field_name} loading extra items: {collapse_num_list(keys)}..." + ) + items = await self._storage.load_field_items(self._ctx_id, self._field_name, keys) + logger.debug( + f"Context dict for {self._ctx_id}, {self._field_name} extra items loaded: {collapse_num_list(keys)}" + ) + for key, value in items: + self._items[key] = self._value_type.validate_json(value) + if not self._storage.rewrite_existing: + self._hashes[key] = _get_hash(value) + + @overload + async def __getitem__(self, key: K) -> V: ... # noqa: E704 + + @overload + async def __getitem__(self, key: slice) -> List[V]: ... 
# noqa: E704 + + async def __getitem__(self, key): + if isinstance(key, int) and key < 0: + key = self.keys()[key] + if self._storage is not None: + if isinstance(key, slice): + await self._load_items( + [self.keys()[k] for k in range(len(self.keys()))[key] if k not in self._items.keys()] + ) + elif key not in self._items.keys(): + await self._load_items([key]) + if isinstance(key, slice): + return [self._items[k] for k in self.keys()[key]] + else: + return self._items[key] + + def __setitem__(self, key: Union[K, slice], value: Union[V, Sequence[V]]) -> None: + if isinstance(key, int) and key < 0: + key = self.keys()[key] + if isinstance(key, slice): + if isinstance(value, Sequence): + key_slice = self.keys()[key] + if len(key_slice) != len(value): + raise ValueError("Slices must have the same length!") + for k, v in zip(key_slice, value): + self[k] = v + else: + raise ValueError("Slice key must have sequence value!") + else: + self._keys.add(key) + self._added.add(key) + self._removed.discard(key) + self._items[key] = self._value_type.validate_python(value) + + def __delitem__(self, key: Union[K, slice]) -> None: + if isinstance(key, int) and key < 0: + key = self.keys()[key] + if isinstance(key, slice): + for k in self.keys()[key]: + del self[k] + else: + self._removed.add(key) + self._added.discard(key) + self._keys.discard(key) + del self._items[key] + + def __iter__(self) -> Sequence[K]: + return iter(self.keys() if self._storage is not None else self._items.keys()) + + def __len__(self) -> int: + return len(self.keys() if self._storage is not None else self._items.keys()) + + @overload + async def get(self, key: K, default=None) -> V: ... # noqa: E704 + + @overload + async def get(self, key: Iterable[K], default=None) -> List[V]: ... # noqa: E704 + + async def get(self, key, default=None): + """ + Get one or many items from the dict. + Asynchronously load missing ones, if context storage is connected. + Raise an error if any requested elements are still missing after. + + :param key: Key or slice for item retrieving. + :param default: Default value. + :return: One value or value list. 
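A hypothetical access sketch, assuming the dict was loaded with a limiting subscript so that older items are not in memory yet:

```python
from chatsky.core import Context, Message

async def inspect_requests(ctx: Context) -> None:
    # Indexing and get() are awaitable because missing items may be fetched
    # from the connected storage on demand.
    newest = await ctx.requests[-1]
    first = await ctx.requests.get(1, None)       # may trigger a storage read
    window = await ctx.requests[2:5]              # slice access loads whatever is missing
    print(newest, first, len(window))

    # Writes stay synchronous; they only reach the storage on store().
    ctx.requests[ctx.current_turn_id] = Message(text="edited locally")
```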
+ """ + + try: + return await self[key] + except KeyError: + if isinstance(key, Iterable): + return [self._items.get(k, default) for k in key] + else: + return default + + def __contains__(self, key: K) -> bool: + return key in self.keys() + + def keys(self) -> List[K]: + return sorted(self._keys) + + async def values(self) -> List[V]: + return await self[:] + + async def items(self) -> List[Tuple[K, V]]: + return [(k, v) for k, v in zip(self.keys(), await self.values())] + + async def pop(self, key: K, default=None) -> V: + try: + value = await self[key] + except KeyError: + return default + else: + del self[key] + return value + + async def popitem(self) -> Tuple[K, V]: + try: + key = next(iter(self)) + except StopIteration: + raise KeyError from None + value = await self[key] + del self[key] + return key, value + + def clear(self) -> None: + del self[:] + + async def update(self, other: Any = (), /, **kwds) -> None: + if isinstance(other, ContextDict): + await self.update(zip(other.keys(), await other.values())) + elif isinstance(other, Mapping): + for key in other: + self[key] = other[key] + elif hasattr(other, "keys"): + for key in other.keys(): + self[key] = other[key] + else: + for key, value in other: + self[key] = value + for key, value in kwds.items(): + self[key] = value + + async def setdefault(self, key: K, default=None) -> V: + try: + return await self[key] + except KeyError: + self[key] = default + return default + + def __eq__(self, value: object) -> bool: + if isinstance(value, ContextDict): + return self._items == value._items + elif isinstance(value, Dict): + return self._items == value + else: + return False + + def __repr__(self) -> str: + return ( + f"ContextDict(items={self._items}, " + f"keys={list(self.keys())}, " + f"hashes={self._hashes}, " + f"added={self._added}, " + f"removed={self._removed}, " + f"storage={self._storage}, " + f"ctx_id={self._ctx_id}, " + f"field_name={self._field_name})" + ) + + @model_validator(mode="wrap") + def _validate_model(value: Any, handler: Callable[[Any], "ContextDict"], _) -> "ContextDict": + if isinstance(value, ContextDict): + return value + elif isinstance(value, Dict): + instance = handler(dict()) + instance._items = value.copy() + instance._keys = set(value.keys()) + return instance + else: + raise ValueError(f"Unknown type of ContextDict value: {type(value).__name__}!") + + @model_serializer() + def _serialize_model(self) -> Dict[K, V]: + if self._storage is None: + return self._items + elif not self._storage.rewrite_existing: + result = dict() + for k, v in self._items.items(): + value = self._value_type.dump_json(v) + if _get_hash(value) != self._hashes.get(k, None): + result[k] = value.decode() + return result + else: + return {k: self._value_type.dump_json(self._items[k]).decode() for k in self._added} + + async def store(self) -> None: + """ + Synchronize dict state with the connected storage. + Update added and removed elements, also update modified ones if `rewrite_existing` flag is enabled. + Raise an error if no storage is connected. 
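The bookkeeping behind this relies on content hashes; a minimal standalone sketch of the idea, with plain dicts in place of the real `ContextDict`:

```python
from hashlib import sha256

def _get_hash(payload: bytes) -> bytes:
    return sha256(payload).digest()

# Digests taken when the items were loaded from the storage.
loaded = {1: b'{"text": "hi"}', 2: b'{"text": "bye"}'}
hashes = {key: _get_hash(value) for key, value in loaded.items()}

loaded[2] = b'{"text": "goodbye"}'  # local modification

# Only entries whose serialized form no longer matches the stored digest
# need to be written back.
dirty = [key for key, value in loaded.items() if _get_hash(value) != hashes.get(key)]
print(dirty)  # [2]
```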
+ """ + + if self._storage is not None: + logger.debug(f"Storing context dict for {self._ctx_id}, {self._field_name}...") + stored = [(k, e.encode()) for k, e in self.model_dump().items()] + await gather( + self._storage.update_field_items(self._ctx_id, self._field_name, stored), + self._storage.delete_field_keys(self._ctx_id, self._field_name, list(self._removed - self._added)), + ) + logger.debug( + f"Context dict for {self._ctx_id}, {self._field_name} stored: " + f"{collapse_num_list([k for k, _ in stored])}" + ) + self._added, self._removed = set(), set() + if not self._storage.rewrite_existing: + for k, v in self._items.items(): + self._hashes[k] = _get_hash(self._value_type.dump_json(v)) + else: + raise RuntimeError(f"{type(self).__name__} is not attached to any context storage!") + + +class LabelContextDict(ContextDict[int, AbsoluteNodeLabel]): + """ + Context dictionary for storing `AbsoluteNodeLabel` types. + """ + + @property + def _value_type(self) -> TypeAdapter[Type[AbsoluteNodeLabel]]: + return TypeAdapter(AbsoluteNodeLabel) + + +class MessageContextDict(ContextDict[int, Message]): + """ + Context dictionary for storing `Message` types. + """ + + @property + def _value_type(self) -> TypeAdapter[Type[Message]]: + return TypeAdapter(Message) diff --git a/chatsky/core/pipeline.py b/chatsky/core/pipeline.py index 20caa74ea..9b7dcb46a 100644 --- a/chatsky/core/pipeline.py +++ b/chatsky/core/pipeline.py @@ -11,14 +11,14 @@ import asyncio import logging from functools import cached_property -from typing import Union, List, Dict, Optional, Hashable +from typing import Union, List, Optional from pydantic import BaseModel, Field, model_validator, computed_field -from chatsky.context_storages import DBContextStorage from chatsky.core.script import Script from chatsky.core.context import Context from chatsky.core.message import Message +from chatsky.context_storages import DBContextStorage, MemoryContextStorage from chatsky.messengers.console import CLIMessengerInterface from chatsky.messengers.common import MessengerInterface from chatsky.slots.slots import GroupSlot @@ -84,7 +84,7 @@ class Pipeline(BaseModel, extra="forbid", arbitrary_types_allowed=True): It handles connections to interfaces that provide user requests and accept bot responses. """ - context_storage: Union[DBContextStorage, Dict] = Field(default_factory=dict) + context_storage: DBContextStorage = Field(default_factory=MemoryContextStorage) """ A :py:class:`~.DBContextStorage` instance for this pipeline or a dict to store dialog :py:class:`~.Context`. @@ -117,7 +117,7 @@ def __init__( default_priority: float = None, slots: GroupSlot = None, messenger_interface: MessengerInterface = None, - context_storage: Union[DBContextStorage, dict] = None, + context_storage: DBContextStorage = None, pre_services: ServiceGroupInitTypes = None, post_services: ServiceGroupInitTypes = None, before_handler: ComponentExtraHandlerInitTypes = None, @@ -223,7 +223,7 @@ def validate_fallback_label(self): return self async def _run_pipeline( - self, request: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None + self, request: Message, ctx_id: Optional[str] = None, update_ctx_misc: Optional[dict] = None ) -> Context: """ Method that should be invoked on user input. 
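For illustration, how a pipeline could now be given its storage explicitly; the script, label and storage URI below are placeholders, and the `shelve://` protocol string is an assumption about the factory:

```python
from chatsky.core.pipeline import Pipeline
from chatsky.context_storages import context_storage_factory

pipeline = Pipeline(
    script={"flow": {"node": {}}},             # placeholder script
    start_label=("flow", "node"),
    # Omitting `context_storage` falls back to the MemoryContextStorage default above.
    context_storage=context_storage_factory("shelve://dbs/contexts.shlv"),
)
# pipeline.run()  # connects the storage if needed, then starts the messenger interface
```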
@@ -243,12 +243,7 @@ async def _run_pipeline( """ logger.info(f"Running pipeline for context {ctx_id}.") logger.debug(f"Received request: {request}.") - if ctx_id is None: - ctx = Context.init(self.start_label) - elif isinstance(self.context_storage, DBContextStorage): - ctx = await self.context_storage.get_async(ctx_id, Context.init(self.start_label, id=ctx_id)) - else: - ctx = self.context_storage.get(ctx_id, Context.init(self.start_label, id=ctx_id)) + ctx = await Context.connected(self.context_storage, self.start_label, ctx_id) if update_ctx_misc is not None: ctx.misc.update(update_ctx_misc) @@ -259,16 +254,15 @@ async def _run_pipeline( ctx.framework_data.pipeline = self initialize_service_states(ctx, self.services_pipeline) - ctx.add_request(request) + ctx.current_turn_id = ctx.current_turn_id + 1 + + ctx.requests[ctx.current_turn_id] = request await self.services_pipeline(ctx) ctx.framework_data.service_states.clear() ctx.framework_data.pipeline = None - if isinstance(self.context_storage, DBContextStorage): - await self.context_storage.set_item_async(ctx_id, ctx) - else: - self.context_storage[ctx_id] = ctx + await ctx.store() return ctx @@ -282,11 +276,13 @@ def run(self): This method can be both blocking and non-blocking. It depends on current :py:attr:`messenger_interface` nature. Message interfaces that run in a loop block current thread. """ + if not self.context_storage.connected: + asyncio.run(self.context_storage.connect()) logger.info("Pipeline is accepting requests.") asyncio.run(self.messenger_interface.connect(self._run_pipeline)) def __call__( - self, request: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None + self, request: Message, ctx_id: Optional[str] = None, update_ctx_misc: Optional[dict] = None ) -> Context: """ Method that executes pipeline once. diff --git a/chatsky/core/service/actor.py b/chatsky/core/service/actor.py index 3e61c48d2..d6455504c 100644 --- a/chatsky/core/service/actor.py +++ b/chatsky/core/service/actor.py @@ -68,7 +68,7 @@ async def run_component(self, ctx: Context) -> None: logger.debug(f"Next label: {next_label}") - ctx.add_label(next_label) + ctx.labels[ctx.current_turn_id] = next_label response = Message() @@ -91,7 +91,7 @@ async def run_component(self, ctx: Context) -> None: except Exception as exc: logger.exception("Exception occurred during response processing.", exc_info=exc) - ctx.add_response(response) + ctx.responses[ctx.current_turn_id] = response @staticmethod async def _run_processing_parallel(processing: Dict[str, BaseProcessing], ctx: Context) -> None: diff --git a/chatsky/destinations/standard.py b/chatsky/destinations/standard.py index 59115a6e8..5694d0bde 100644 --- a/chatsky/destinations/standard.py +++ b/chatsky/destinations/standard.py @@ -12,7 +12,7 @@ from pydantic import Field -from chatsky.core.context import get_last_index, Context +from chatsky.core.context import Context from chatsky.core.node_label import NodeLabelInitTypes, AbsoluteNodeLabel from chatsky.core.script_function import BaseDestination @@ -33,15 +33,7 @@ class FromHistory(BaseDestination): """ async def call(self, ctx: Context) -> NodeLabelInitTypes: - index = get_last_index(ctx.labels) - shifted_index = index + self.position + 1 - result = ctx.labels.get(shifted_index) - if result is None: - raise KeyError( - f"No label with index {shifted_index!r}. " - f"Current label index: {index!r}; FromHistory.position: {self.position!r}." 
- ) - return result + return await ctx.labels[self.position] class Current(FromHistory): diff --git a/chatsky/stats/instrumentor.py b/chatsky/stats/instrumentor.py index 0005e8fcc..31dee897c 100644 --- a/chatsky/stats/instrumentor.py +++ b/chatsky/stats/instrumentor.py @@ -25,7 +25,6 @@ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter -from chatsky.core.context import get_last_index from chatsky.stats.utils import ( resource, get_extra_handler_name, @@ -161,7 +160,7 @@ async def __call__(self, wrapped, _, args, kwargs): pipeline_component = get_extra_handler_name(info) attributes = { "context_id": str(ctx.id), - "request_id": get_last_index(ctx.labels), + "request_id": ctx.current_turn_id, "pipeline_component": pipeline_component, } diff --git a/chatsky/utils/db_benchmark/basic_config.py b/chatsky/utils/db_benchmark/basic_config.py index bb1a10330..22afc83a5 100644 --- a/chatsky/utils/db_benchmark/basic_config.py +++ b/chatsky/utils/db_benchmark/basic_config.py @@ -15,7 +15,8 @@ from humanize import naturalsize from pympler import asizeof -from chatsky.core import Message, Context +from chatsky.core import Message, Context, AbsoluteNodeLabel +from chatsky.context_storages import MemoryContextStorage from chatsky.utils.db_benchmark.benchmark import BenchmarkConfig @@ -59,7 +60,8 @@ def get_message(message_dimensions: Tuple[int, ...]): return Message(misc=get_dict(message_dimensions)) -def get_context( +async def get_context( + db, dialog_len: int, message_dimensions: Tuple[int, ...], misc_dimensions: Tuple[int, ...], @@ -73,12 +75,16 @@ def get_context( :param misc_dimensions: A parameter used to generate misc field. See :py:func:`~.get_dict`. """ - return Context( - labels={i: (f"flow_{i}", f"node_{i}") for i in range(dialog_len)}, - requests={i: get_message(message_dimensions) for i in range(dialog_len)}, - responses={i: get_message(message_dimensions) for i in range(dialog_len)}, - misc=get_dict(misc_dimensions), - ) + ctx = await Context.connected(db, start_label=("flow", "node")) + ctx.current_turn_id = -1 + for i in range(dialog_len): + ctx.current_turn_id += 1 + ctx.labels[ctx.current_turn_id] = AbsoluteNodeLabel(flow_name=f"flow_{i}", node_name=f"node_{i}") + ctx.requests[ctx.current_turn_id] = get_message(message_dimensions) + ctx.responses[ctx.current_turn_id] = get_message(message_dimensions) + ctx.misc.update(get_dict(misc_dimensions)) + + return ctx class BasicBenchmarkConfig(BenchmarkConfig, frozen=True): @@ -121,15 +127,15 @@ class BasicBenchmarkConfig(BenchmarkConfig, frozen=True): See :py:func:`~.get_dict`. """ - def get_context(self) -> Context: + async def get_context(self, db) -> Context: """ Return context with `from_dialog_len`, `message_dimensions`, `misc_dimensions`. Wraps :py:func:`~.get_context`. """ - return get_context(self.from_dialog_len, self.message_dimensions, self.misc_dimensions) + return await get_context(db, self.from_dialog_len, self.message_dimensions, self.misc_dimensions) - def info(self): + async def info(self): """ Return fields of this instance and sizes of objects defined by this config. @@ -144,20 +150,34 @@ def info(self): - "misc_size" -- size of a misc field of a context. - "message_size" -- size of a misc field of a message. 
""" + + def remove_db_from_context(ctx: Context): + ctx._storage = None + ctx.requests._storage = None + ctx.responses._storage = None + ctx.labels._storage = None + + starting_context = await get_context( + MemoryContextStorage(), self.from_dialog_len, self.message_dimensions, self.misc_dimensions + ) + final_contex = await get_context( + MemoryContextStorage(), self.to_dialog_len, self.message_dimensions, self.misc_dimensions + ) + remove_db_from_context(starting_context) + remove_db_from_context(final_contex) return { "params": self.model_dump(), "sizes": { - "starting_context_size": naturalsize(asizeof.asizeof(self.get_context()), gnu=True), - "final_context_size": naturalsize( - asizeof.asizeof(get_context(self.to_dialog_len, self.message_dimensions, self.misc_dimensions)), - gnu=True, + "starting_context_size": naturalsize( + asizeof.asizeof(starting_context.model_dump(mode="python")), gnu=True ), + "final_context_size": naturalsize(asizeof.asizeof(final_contex.model_dump(mode="python")), gnu=True), "misc_size": naturalsize(asizeof.asizeof(get_dict(self.misc_dimensions)), gnu=True), "message_size": naturalsize(asizeof.asizeof(get_message(self.message_dimensions)), gnu=True), }, } - def context_updater(self, context: Context) -> Optional[Context]: + async def context_updater(self, context: Context) -> Optional[Context]: """ Update context to have `step_dialog_len` more labels, requests and responses, unless such dialog len would be equal to `to_dialog_len` or exceed than it, @@ -166,9 +186,12 @@ def context_updater(self, context: Context) -> Optional[Context]: start_len = len(context.labels) if start_len + self.step_dialog_len < self.to_dialog_len: for i in range(start_len, start_len + self.step_dialog_len): - context.add_label((f"flow_{i}", f"node_{i}")) - context.add_request(get_message(self.message_dimensions)) - context.add_response(get_message(self.message_dimensions)) + context.current_turn_id += 1 + context.labels[context.current_turn_id] = AbsoluteNodeLabel( + flow_name=f"flow_{i}", node_name=f"node_{i}" + ) + context.requests[context.current_turn_id] = get_message(self.message_dimensions) + context.responses[context.current_turn_id] = get_message(self.message_dimensions) return context else: return None diff --git a/chatsky/utils/db_benchmark/benchmark.py b/chatsky/utils/db_benchmark/benchmark.py index ccf60e6cc..2a70be291 100644 --- a/chatsky/utils/db_benchmark/benchmark.py +++ b/chatsky/utils/db_benchmark/benchmark.py @@ -22,12 +22,13 @@ from uuid import uuid4 from pathlib import Path from time import perf_counter -from typing import Tuple, List, Dict, Union, Optional, Callable, Any +from typing import Tuple, List, Dict, Union, Optional, Callable, Any, Awaitable import json import importlib from statistics import mean import abc from traceback import extract_tb, StackSummary +import asyncio from pydantic import BaseModel, Field from tqdm.auto import tqdm @@ -36,11 +37,11 @@ from chatsky.core import Context -def time_context_read_write( +async def time_context_read_write( context_storage: DBContextStorage, - context_factory: Callable[[], Context], + context_factory: Callable[[DBContextStorage], Awaitable[Context]], context_num: int, - context_updater: Optional[Callable[[Context], Optional[Context]]] = None, + context_updater: Optional[Callable[[Context], Awaitable[Optional[Context]]]] = None, ) -> Tuple[List[float], List[Dict[int, float]], List[Dict[int, float]]]: """ Benchmark `context_storage` by writing and reading `context`\\s generated by `context_factory` @@ -78,20 
+79,18 @@ def time_context_read_write( dialog_len of the context returned by `context_factory`. So if `context_updater` is None, all dictionaries will be empty. """ - context_storage.clear() + await context_storage.clear_all() write_times: List[float] = [] read_times: List[Dict[int, float]] = [] update_times: List[Dict[int, float]] = [] for _ in tqdm(range(context_num), desc="Iteration", leave=False): - context = context_factory() - - ctx_id = uuid4() + context = await context_factory(context_storage) # write operation benchmark write_start = perf_counter() - context_storage[ctx_id] = context + await context.store() write_times.append(perf_counter() - write_start) read_times.append({}) @@ -99,27 +98,27 @@ def time_context_read_write( # read operation benchmark read_start = perf_counter() - context = context_storage[ctx_id] + context = await Context.connected(context_storage, start_label=("flow", "node"), id=context.id) read_time = perf_counter() - read_start read_times[-1][len(context.labels)] = read_time if context_updater is not None: - context = context_updater(context) + context = await context_updater(context) while context is not None: update_start = perf_counter() - context_storage[ctx_id] = context + await context.store() update_time = perf_counter() - update_start update_times[-1][len(context.labels)] = update_time read_start = perf_counter() - context = context_storage[ctx_id] + context = await Context.connected(context_storage, start_label=("flow", "node"), id=context.id) read_time = perf_counter() - read_start read_times[-1][len(context.labels)] = read_time - context = context_updater(context) + context = await context_updater(context) - context_storage.clear() + await context_storage.clear_all() return write_times, read_times, update_times @@ -164,7 +163,7 @@ class BenchmarkConfig(BaseModel, abc.ABC, frozen=True): """ @abc.abstractmethod - def get_context(self) -> Context: + async def get_context(self, db: DBContextStorage) -> Context: """ Return context to benchmark read and write operations with. @@ -173,14 +172,14 @@ def get_context(self) -> Context: ... @abc.abstractmethod - def info(self) -> Dict[str, Any]: + async def info(self) -> Dict[str, Any]: """ Return a dictionary with information about this configuration. """ ... @abc.abstractmethod - def context_updater(self, context: Context) -> Optional[Context]: + async def context_updater(self, context: Context) -> Optional[Context]: """ Update context with new dialog turns or return `None` to stop updates. 
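Because `time_context_read_write` and the `BenchmarkConfig` hooks are now coroutines, a benchmark run has to be driven from an event loop. The sketch below shows the intended call pattern against the in-memory backend; the factory stands in for `BenchmarkConfig.get_context`, keyword names mirror this diff, and connection handling is simplified, so none of it should be read as the definitive benchmarking entry point.

    import asyncio

    from chatsky.context_storages import MemoryContextStorage
    from chatsky.core import Context
    from chatsky.utils.db_benchmark.benchmark import time_context_read_write


    async def make_context(db) -> Context:
        # Context factories now receive the storage and return a connected context.
        return await Context.connected(db, start_label=("flow", "node"))


    async def main() -> None:
        db = MemoryContextStorage()
        write_times, read_times, update_times = await time_context_read_write(
            db,
            make_context,
            context_num=5,          # number of contexts written and read back
            context_updater=None,   # no incremental updates in this sketch
        )
        print(f"mean write: {sum(write_times) / len(write_times):.6f}s")
        print(f"reads recorded: {len(read_times)}")


    asyncio.run(main())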
@@ -284,11 +283,13 @@ def get_complex_stats(results): def _run(self): try: - write_times, read_times, update_times = time_context_read_write( - self.db_factory.db(), - self.benchmark_config.get_context, - self.benchmark_config.context_num, - self.benchmark_config.context_updater, + write_times, read_times, update_times = asyncio.run( + time_context_read_write( + self.db_factory.db(), + self.benchmark_config.get_context, + self.benchmark_config.context_num, + self.benchmark_config.context_updater, + ) ) return { "success": True, @@ -366,7 +367,7 @@ def save_results_to_file( result["benchmarks"].append( { **case.model_dump(exclude={"benchmark_config"}), - "benchmark_config": case.benchmark_config.info(), + "benchmark_config": asyncio.run(case.benchmark_config.info()), **case.run(), } ) diff --git a/chatsky/utils/logging.py b/chatsky/utils/logging.py new file mode 100644 index 000000000..fd736117d --- /dev/null +++ b/chatsky/utils/logging.py @@ -0,0 +1,14 @@ +from typing import Union + + +def collapse_num_list(num_list: Union[list[int], list[float]]) -> str: + """ + Produce representation for a list of numbers while collapsing large lists. + + For lists with 10 or fewer items return the representation of the list. + Otherwise, return a string with the minimum and maximum items as well as the number of items. + """ + if len(num_list) > 10: + return f"{min(num_list)} .. {max(num_list)} ({len(num_list)} items)" + else: + return repr(num_list) diff --git a/chatsky/utils/testing/cleanup_db.py b/chatsky/utils/testing/cleanup_db.py index fdc8f4635..cf9c237b5 100644 --- a/chatsky/utils/testing/cleanup_db.py +++ b/chatsky/utils/testing/cleanup_db.py @@ -5,19 +5,15 @@ including JSON, MongoDB, Pickle, Redis, Shelve, SQL, and YDB databases. """ -import os +from typing import Any from chatsky.context_storages import ( JSONContextStorage, MongoContextStorage, - PickleContextStorage, RedisContextStorage, - ShelveContextStorage, SQLContextStorage, YDBContextStorage, - json_available, mongo_available, - pickle_available, redis_available, sqlite_available, postgres_available, @@ -26,16 +22,14 @@ ) -async def delete_json(storage: JSONContextStorage): +async def delete_file(storage: JSONContextStorage): """ Delete all data from a JSON context storage. :param storage: A JSONContextStorage object. """ - if not json_available: - raise Exception("Can't delete JSON database - JSON provider unavailable!") - if os.path.isfile(storage.path): - os.remove(storage.path) + if storage.path.exists(): + storage.path.unlink() async def delete_mongo(storage: MongoContextStorage): @@ -46,19 +40,8 @@ async def delete_mongo(storage: MongoContextStorage): """ if not mongo_available: raise Exception("Can't delete mongo database - mongo provider unavailable!") - await storage.collection.drop() - - -async def delete_pickle(storage: PickleContextStorage): - """ - Delete all data from a Pickle context storage. - - :param storage: A PickleContextStorage object. 
- """ - if not pickle_available: - raise Exception("Can't delete pickle database - pickle provider unavailable!") - if os.path.isfile(storage.path): - os.remove(storage.path) + for collection in [storage.main_table, storage.turns_table]: + await collection.drop() async def delete_redis(storage: RedisContextStorage): @@ -69,17 +52,8 @@ async def delete_redis(storage: RedisContextStorage): """ if not redis_available: raise Exception("Can't delete redis database - redis provider unavailable!") - await storage.clear_async() - - -async def delete_shelve(storage: ShelveContextStorage): - """ - Delete all data from a Shelve context storage. - - :param storage: A ShelveContextStorage object. - """ - if os.path.isfile(storage.path): - os.remove(storage.path) + await storage.clear_all() + await storage.database.aclose() async def delete_sql(storage: SQLContextStorage): @@ -94,8 +68,9 @@ async def delete_sql(storage: SQLContextStorage): raise Exception("Can't delete sqlite database - sqlite provider unavailable!") if storage.dialect == "mysql" and not mysql_available: raise Exception("Can't delete mysql database - mysql provider unavailable!") - async with storage.engine.connect() as conn: - await conn.run_sync(storage.table.drop, storage.engine) + async with storage.engine.begin() as conn: + for table in [storage.main_table, storage.turns_table]: + await conn.run_sync(table.drop, storage.engine) async def delete_ydb(storage: YDBContextStorage): @@ -107,7 +82,8 @@ async def delete_ydb(storage: YDBContextStorage): if not ydb_available: raise Exception("Can't delete ydb database - ydb provider unavailable!") - async def callee(session): - await session.drop_table("/".join([storage.database, storage.table_name])) + async def callee(session: Any) -> None: + for table in [storage.main_table, storage.turns_table]: + await session.drop_table("/".join([storage.database, table])) await storage.pool.retry_operation(callee) diff --git a/chatsky/utils/testing/common.py b/chatsky/utils/testing/common.py index c884a513f..5dd848c0e 100644 --- a/chatsky/utils/testing/common.py +++ b/chatsky/utils/testing/common.py @@ -47,8 +47,8 @@ def check_happy_path( Defaults to ``Message.__eq__``. :param printout: Whether to print the requests/responses during iteration. """ - ctx_id = uuid4() # get random ID for current context - for step_id, (request_raw, reference_response_raw) in enumerate(happy_path): + ctx_id = str(uuid4()) # get random ID for current context + for request_raw, reference_response_raw in happy_path: request = Message.model_validate(request_raw) reference_response = Message.model_validate(reference_response_raw) @@ -64,7 +64,7 @@ def check_happy_path( if not response_comparator(reference_response, actual_response): raise AssertionError( f"""check_happy_path failed -step id: {step_id} +current turn id: {ctx.current_turn_id} reference response: {reference_response} actual response: {actual_response} """ diff --git a/docs/source/user_guides/context_guide.rst b/docs/source/user_guides/context_guide.rst index 5c57edbd3..c5e83572c 100644 --- a/docs/source/user_guides/context_guide.rst +++ b/docs/source/user_guides/context_guide.rst @@ -154,17 +154,6 @@ Public methods * **pipeline**: Return ``Pipeline`` object that is used to process this context. This can be used to get ``Script``, ``start_label`` or ``fallback_label``. -Private methods -^^^^^^^^^^^^^^^ - -These methods should not be used outside of the internal workings. 
- -* **set_last_response** -* **set_last_request** -* **add_request** -* **add_response** -* **add_label** - Context storages ~~~~~~~~~~~~~~~~ diff --git a/pyproject.toml b/pyproject.toml index f0a941f7b..202f83f0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -219,7 +219,7 @@ asyncio_mode = "auto" concurrency = [ "thread", "greenlet", - ] +] [tool.coverage.report] @@ -227,4 +227,4 @@ concurrency = [ exclude_also = [ "if TYPE_CHECKING:", "raise NotImplementedError", - ] +] diff --git a/tests/conftest.py b/tests/conftest.py index dad455b74..730b3274c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,8 @@ import pytest +from chatsky import Pipeline, Context, AbsoluteNodeLabel + def pytest_report_header(config, start_path): print(f"allow_skip: {config.getoption('--allow-skip') }") @@ -68,3 +70,38 @@ def emit(self, record) -> bool: return logs return inner + + +@pytest.fixture +def pipeline(): + return Pipeline( + script={"flow": {"node1": {}, "node2": {}, "node3": {}}, "service": {"start": {}, "fallback": {}}}, + start_label=("service", "start"), + fallback_label=("service", "fallback"), + ) + + +@pytest.fixture +def context_factory(pipeline): + def _context_factory(forbidden_fields=None, start_label=None): + ctx = Context() + if start_label is not None: + ctx.labels[0] = AbsoluteNodeLabel.model_validate(start_label) + ctx.framework_data.pipeline = pipeline + if forbidden_fields is not None: + + class Forbidden: + def __init__(self, name): + self.name = name + + class ForbiddenError(Exception): + pass + + def __getattr__(self, item): + raise self.ForbiddenError(f"{self.name!r} is forbidden") + + for forbidden_field in forbidden_fields: + ctx.__setattr__(forbidden_field, Forbidden(forbidden_field)) + return ctx + + return _context_factory diff --git a/tests/context_storages/conftest.py b/tests/context_storages/conftest.py deleted file mode 100644 index b2739bfe5..000000000 --- a/tests/context_storages/conftest.py +++ /dev/null @@ -1,22 +0,0 @@ -import uuid - -from chatsky.core import Context -import pytest - - -@pytest.fixture(scope="function") -def testing_context(): - yield Context(id=112668) - - -@pytest.fixture(scope="function") -def testing_file(tmpdir_factory): - filename = tmpdir_factory.mktemp("data").join("file.db") - string_file = str(filename) - yield string_file - - -@pytest.fixture(scope="function") -def context_id(): - ctx_id = str(uuid.uuid4()) - yield ctx_id diff --git a/tests/context_storages/test_dbs.py b/tests/context_storages/test_dbs.py index db94446a6..0f10fa7a2 100644 --- a/tests/context_storages/test_dbs.py +++ b/tests/context_storages/test_dbs.py @@ -1,48 +1,50 @@ +import os +from platform import system +from socket import AF_INET, SOCK_STREAM, socket +from typing import Optional import asyncio +import random import pytest -import socket -import os -from platform import system from chatsky.context_storages import ( get_protocol_install_suggestion, + context_storage_factory, json_available, pickle_available, - ShelveContextStorage, - DBContextStorage, postgres_available, mysql_available, sqlite_available, redis_available, mongo_available, ydb_available, - context_storage_factory, ) - -from chatsky.core import Context, Pipeline from chatsky.utils.testing.cleanup_db import ( - delete_shelve, - delete_json, - delete_pickle, + delete_file, delete_mongo, delete_redis, delete_sql, delete_ydb, ) +from chatsky import Pipeline +from chatsky.context_storages import DBContextStorage +from chatsky.context_storages.database import _SUBSCRIPT_TYPE, ContextInfo 
+from chatsky.utils.testing import TOY_SCRIPT_KWARGS, HAPPY_PATH, check_happy_path + +from tests.test_utils import get_path_from_tests_to_current_dir -from chatsky.utils.testing import check_happy_path, TOY_SCRIPT_KWARGS, HAPPY_PATH +dot_path_to_addon = get_path_from_tests_to_current_dir(__file__, separator=".") -def ping_localhost(port: int, timeout=60): +def ping_localhost(port: int, timeout: int = 60) -> bool: try: - socket.setdefaulttimeout(timeout) - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.connect(("localhost", port)) + sock = socket(AF_INET, SOCK_STREAM) + sock.settimeout(timeout) + sock.connect(("localhost", port)) except OSError: return False else: - s.close() + sock.close() return True @@ -57,33 +59,6 @@ def ping_localhost(port: int, timeout=60): YDB_ACTIVE = ping_localhost(2136) -def generic_test(db, testing_context, context_id): - assert isinstance(db, DBContextStorage) - # perform cleanup - db.clear() - assert len(db) == 0 - # test write operations - db[context_id] = Context(id=context_id) - assert context_id in db - assert len(db) == 1 - db[context_id] = testing_context # overwriting a key - assert len(db) == 1 - # test read operations - new_ctx = db[context_id] - assert isinstance(new_ctx, Context) - assert {**new_ctx.model_dump(), "id": str(new_ctx.id)} == { - **testing_context.model_dump(), - "id": str(testing_context.id), - } - # test delete operations - del db[context_id] - assert context_id not in db - # test `get` method - assert db.get(context_id) is None - pipeline = Pipeline(**TOY_SCRIPT_KWARGS, context_storage=db) - check_happy_path(pipeline, happy_path=HAPPY_PATH) - - @pytest.mark.parametrize( ["protocol", "expected"], [ @@ -92,106 +67,262 @@ def generic_test(db, testing_context, context_id): ("false", ""), ], ) -def test_protocol_suggestion(protocol, expected): +def test_protocol_suggestion(protocol: str, expected: str) -> None: result = get_protocol_install_suggestion(protocol) assert result == expected -def test_shelve(testing_file, testing_context, context_id): - db = ShelveContextStorage(f"shelve://{testing_file}") - generic_test(db, testing_context, context_id) - asyncio.run(delete_shelve(db)) - - -@pytest.mark.skipif(not json_available, reason="JSON dependencies missing") -def test_json(testing_file, testing_context, context_id): - db = context_storage_factory(f"json://{testing_file}") - generic_test(db, testing_context, context_id) - asyncio.run(delete_json(db)) - - -@pytest.mark.skipif(not pickle_available, reason="Pickle dependencies missing") -def test_pickle(testing_file, testing_context, context_id): - db = context_storage_factory(f"pickle://{testing_file}") - generic_test(db, testing_context, context_id) - asyncio.run(delete_pickle(db)) - - -@pytest.mark.skipif(not MONGO_ACTIVE, reason="Mongodb server is not running") -@pytest.mark.skipif(not mongo_available, reason="Mongodb dependencies missing") -@pytest.mark.docker -def test_mongo(testing_context, context_id): - if system() == "Windows": - pytest.skip() - - db = context_storage_factory( - "mongodb://{}:{}@localhost:27017/{}".format( - os.environ["MONGO_INITDB_ROOT_USERNAME"], - os.environ["MONGO_INITDB_ROOT_PASSWORD"], - os.environ["MONGO_INITDB_ROOT_USERNAME"], - ) - ) - generic_test(db, testing_context, context_id) - asyncio.run(delete_mongo(db)) - - -@pytest.mark.skipif(not REDIS_ACTIVE, reason="Redis server is not running") -@pytest.mark.skipif(not redis_available, reason="Redis dependencies missing") -@pytest.mark.docker -def test_redis(testing_context, context_id): - db = 
context_storage_factory("redis://{}:{}@localhost:6379/{}".format("", os.environ["REDIS_PASSWORD"], "0")) - generic_test(db, testing_context, context_id) - asyncio.run(delete_redis(db)) - - -@pytest.mark.skipif(not POSTGRES_ACTIVE, reason="Postgres server is not running") -@pytest.mark.skipif(not postgres_available, reason="Postgres dependencies missing") -@pytest.mark.docker -def test_postgres(testing_context, context_id): - db = context_storage_factory( - "postgresql+asyncpg://{}:{}@localhost:5432/{}".format( - os.environ["POSTGRES_USERNAME"], - os.environ["POSTGRES_PASSWORD"], - os.environ["POSTGRES_DB"], - ) - ) - generic_test(db, testing_context, context_id) - asyncio.run(delete_sql(db)) - - -@pytest.mark.skipif(not sqlite_available, reason="Sqlite dependencies missing") -def test_sqlite(testing_file, testing_context, context_id): - separator = "///" if system() == "Windows" else "////" - db = context_storage_factory(f"sqlite+aiosqlite:{separator}{testing_file}") - generic_test(db, testing_context, context_id) - asyncio.run(delete_sql(db)) - - -@pytest.mark.skipif(not MYSQL_ACTIVE, reason="Mysql server is not running") -@pytest.mark.skipif(not mysql_available, reason="Mysql dependencies missing") -@pytest.mark.docker -def test_mysql(testing_context, context_id): - db = context_storage_factory( - "mysql+asyncmy://{}:{}@localhost:3307/{}".format( - os.environ["MYSQL_USERNAME"], - os.environ["MYSQL_PASSWORD"], - os.environ["MYSQL_DATABASE"], - ) - ) - generic_test(db, testing_context, context_id) - asyncio.run(delete_sql(db)) - - -@pytest.mark.skipif(not YDB_ACTIVE, reason="YQL server not running") -@pytest.mark.skipif(not ydb_available, reason="YDB dependencies missing") -@pytest.mark.docker -def test_ydb(testing_context, context_id): - db = context_storage_factory( - "{}{}".format( - os.environ["YDB_ENDPOINT"], - os.environ["YDB_DATABASE"], +@pytest.mark.parametrize( + "db_kwargs,db_teardown", + [ + pytest.param({"path": ""}, None, id="memory"), + pytest.param({"path": "shelve://{__testing_file__}"}, delete_file, id="shelve"), + pytest.param( + {"path": "json://{__testing_file__}"}, + delete_file, + id="json", + marks=[pytest.mark.skipif(not json_available, reason="Asynchronous file (JSON) dependencies missing")], ), - table_name="test", - ) - generic_test(db, testing_context, context_id) - asyncio.run(delete_ydb(db)) + pytest.param( + {"path": "pickle://{__testing_file__}"}, + delete_file, + id="pickle", + marks=[pytest.mark.skipif(not pickle_available, reason="Asynchronous file (pickle) dependencies missing")], + ), + pytest.param( + { + "path": "mongodb://{MONGO_INITDB_ROOT_USERNAME}:{MONGO_INITDB_ROOT_PASSWORD}@" + "localhost:27017/{MONGO_INITDB_ROOT_USERNAME}" + }, + delete_mongo, + id="mongo", + marks=[ + pytest.mark.docker, + pytest.mark.skipif(not MONGO_ACTIVE, reason="Mongodb server is not running"), + pytest.mark.skipif(not mongo_available, reason="Mongodb dependencies missing"), + ], + ), + pytest.param( + {"path": "redis://:{REDIS_PASSWORD}@localhost:6379/0"}, + delete_redis, + id="redis", + marks=[ + pytest.mark.docker, + pytest.mark.skipif(not REDIS_ACTIVE, reason="Redis server is not running"), + pytest.mark.skipif(not redis_available, reason="Redis dependencies missing"), + ], + ), + pytest.param( + {"path": "postgresql+asyncpg://{POSTGRES_USERNAME}:{POSTGRES_PASSWORD}@localhost:5432/{POSTGRES_DB}"}, + delete_sql, + id="postgres", + marks=[ + pytest.mark.docker, + pytest.mark.skipif(not POSTGRES_ACTIVE, reason="Postgres server is not running"), + pytest.mark.skipif(not 
postgres_available, reason="Postgres dependencies missing"), + ], + ), + pytest.param( + {"path": "sqlite+aiosqlite:{__separator__}{__testing_file__}"}, + delete_sql, + id="sqlite", + marks=[pytest.mark.skipif(not sqlite_available, reason="Sqlite dependencies missing")], + ), + pytest.param( + {"path": "mysql+asyncmy://{MYSQL_USERNAME}:{MYSQL_PASSWORD}@localhost:3307/{MYSQL_DATABASE}"}, + delete_sql, + id="mysql", + marks=[ + pytest.mark.docker, + pytest.mark.skipif(not MYSQL_ACTIVE, reason="Mysql server is not running"), + pytest.mark.skipif(not mysql_available, reason="Mysql dependencies missing"), + ], + ), + pytest.param( + {"path": "{YDB_ENDPOINT}{YDB_DATABASE}"}, + delete_ydb, + id="ydb", + marks=[ + pytest.mark.docker, + pytest.mark.skipif(not YDB_ACTIVE, reason="YQL server not running"), + pytest.mark.skipif(not ydb_available, reason="YDB dependencies missing"), + ], + ), + ], +) +class TestContextStorages: + @pytest.fixture + async def db(self, db_kwargs, db_teardown, tmpdir_factory): + kwargs = {"__separator__": "///" if system() == "Windows" else "////", **os.environ} + if "{__testing_file__}" in db_kwargs["path"]: + kwargs["__testing_file__"] = str(tmpdir_factory.mktemp("data").join("file.db")) + db_kwargs["path"] = db_kwargs["path"].format(**kwargs) + context_storage = context_storage_factory(**db_kwargs) + + yield context_storage + + if db_teardown is not None: + await db_teardown(context_storage) + + @pytest.fixture + async def add_context(self, db): + async def add_context(ctx_id: str): + await db.update_main_info(ctx_id, ContextInfo(turn_id=1, created_at=1, updated_at=1)) + await db.update_field_items(ctx_id, "labels", [(0, b"0")]) + + yield add_context + + @staticmethod + def configure_context_storage( + context_storage: DBContextStorage, + rewrite_existing: Optional[bool] = None, + labels_subscript: Optional[_SUBSCRIPT_TYPE] = None, + requests_subscript: Optional[_SUBSCRIPT_TYPE] = None, + responses_subscript: Optional[_SUBSCRIPT_TYPE] = None, + all_subscript: Optional[_SUBSCRIPT_TYPE] = None, + ) -> None: + if rewrite_existing is not None: + context_storage.rewrite_existing = rewrite_existing + if all_subscript is not None: + labels_subscript = requests_subscript = responses_subscript = all_subscript + if labels_subscript is not None: + context_storage._subscripts["labels"] = labels_subscript + if requests_subscript is not None: + context_storage._subscripts["requests"] = requests_subscript + if responses_subscript is not None: + context_storage._subscripts["responses"] = responses_subscript + + async def test_add_context(self, db: DBContextStorage, add_context): + # test the fixture + await add_context("1") + + async def test_get_main_info(self, db: DBContextStorage, add_context): + await add_context("1") + assert await db.load_main_info("1") == ContextInfo(turn_id=1, created_at=1, updated_at=1) + assert await db.load_main_info("2") is None + + async def test_update_main_info(self, db: DBContextStorage, add_context): + await add_context("1") + await add_context("2") + assert await db.load_main_info("1") == ContextInfo(turn_id=1, created_at=1, updated_at=1) + assert await db.load_main_info("2") == ContextInfo(turn_id=1, created_at=1, updated_at=1) + + await db.update_main_info("1", ContextInfo(turn_id=2, created_at=1, updated_at=3)) + assert await db.load_main_info("1") == ContextInfo(turn_id=2, created_at=1, updated_at=3) + assert await db.load_main_info("2") == ContextInfo(turn_id=1, created_at=1, updated_at=1) + + async def test_wrong_field_name(self, db: 
DBContextStorage): + with pytest.raises(ValueError, match="Invalid value 'non-existent' for argument 'field_name'!"): + await db.load_field_latest("1", "non-existent") + with pytest.raises(ValueError, match="Invalid value 'non-existent' for argument 'field_name'!"): + await db.load_field_keys("1", "non-existent") + with pytest.raises(ValueError, match="Invalid value 'non-existent' for argument 'field_name'!"): + await db.load_field_items("1", "non-existent", [1, 2]) + with pytest.raises(ValueError, match="Invalid value 'non-existent' for argument 'field_name'!"): + await db.update_field_items("1", "non-existent", [(1, b"2")]) + + async def test_field_get(self, db: DBContextStorage, add_context): + await add_context("1") + + assert await db.load_field_latest("1", "labels") == [(0, b"0")] + assert set(await db.load_field_keys("1", "labels")) == {0} + + assert await db.load_field_latest("1", "requests") == [] + assert set(await db.load_field_keys("1", "requests")) == set() + + async def test_field_load(self, db: DBContextStorage, add_context): + await add_context("1") + + await db.update_field_items("1", "requests", [(1, b"1"), (3, b"3"), (2, b"2"), (4, b"4")]) + + assert await db.load_field_items("1", "requests", [1, 2]) == [(1, b"1"), (2, b"2")] + assert await db.load_field_items("1", "requests", [4, 3]) == [(3, b"3"), (4, b"4")] + + async def test_field_update(self, db: DBContextStorage, add_context): + await add_context("1") + assert await db.load_field_latest("1", "labels") == [(0, b"0")] + assert await db.load_field_latest("1", "requests") == [] + + await db.update_field_items("1", "labels", [(0, b"1")]) + await db.update_field_items("1", "requests", [(4, b"4")]) + await db.update_field_items("1", "labels", [(2, b"2")]) + + assert await db.load_field_latest("1", "labels") == [(2, b"2"), (0, b"1")] + assert set(await db.load_field_keys("1", "labels")) == {0, 2} + assert await db.load_field_latest("1", "requests") == [(4, b"4")] + assert set(await db.load_field_keys("1", "requests")) == {4} + + async def test_int_key_field_subscript(self, db: DBContextStorage, add_context): + await add_context("1") + await db.update_field_items("1", "requests", [(2, b"2")]) + await db.update_field_items("1", "requests", [(1, b"1")]) + await db.update_field_items("1", "requests", [(0, b"0")]) + + self.configure_context_storage(db, requests_subscript=2) + assert await db.load_field_latest("1", "requests") == [(2, b"2"), (1, b"1")] + + self.configure_context_storage(db, requests_subscript="__all__") + assert await db.load_field_latest("1", "requests") == [(2, b"2"), (1, b"1"), (0, b"0")] + + await db.update_field_items("1", "requests", [(5, b"5")]) + + self.configure_context_storage(db, requests_subscript=2) + assert await db.load_field_latest("1", "requests") == [(5, b"5"), (2, b"2")] + + async def test_delete_field_key(self, db: DBContextStorage, add_context): + await add_context("1") + + await db.delete_field_keys("1", "labels", [0]) + + assert await db.load_field_latest("1", "labels") == [] + + async def test_raises_on_missing_field_keys(self, db: DBContextStorage, add_context): + await add_context("1") + + assert set(await db.load_field_items("1", "labels", [0, 1])) == {(0, b"0")} + assert set(await db.load_field_items("1", "requests", [0])) == set() + + async def test_delete_context(self, db: DBContextStorage, add_context): + await add_context("1") + await add_context("2") + + # test delete + await db.delete_context("1") + + assert await db.load_main_info("1") is None + assert await 
db.load_main_info("2") == ContextInfo(turn_id=1, created_at=1, updated_at=1) + + assert set(await db.load_field_keys("1", "labels")) == set() + assert set(await db.load_field_keys("2", "labels")) == {0} + + @pytest.mark.slow + async def test_concurrent_operations(self, db: DBContextStorage): + async def db_operations(key: int): + str_key = str(key) + key_misc = {f"{key}": key + 2} + await asyncio.sleep(random.random() / 100) + await db.update_main_info( + str_key, ContextInfo(turn_id=key, created_at=key + 1, updated_at=key, misc=key_misc) + ) + await asyncio.sleep(random.random() / 100) + assert await db.load_main_info(str_key) == ContextInfo( + turn_id=key, created_at=key + 1, updated_at=key, misc=key_misc + ) + + for idx in range(1, 20): + await db.update_field_items(str_key, "requests", [(0, bytes(2 * key + idx)), (idx, bytes(key + idx))]) + await asyncio.sleep(random.random() / 100) + keys = list(range(idx + 1)) + assert set(await db.load_field_keys(str_key, "requests")) == set(keys) + assert set(await db.load_field_items(str_key, "requests", keys)) == { + (0, bytes(2 * key + idx)), + *[(k, bytes(key + k)) for k in range(1, idx + 1)], + } + + operations = [db_operations(key * 2) for key in range(3)] + await asyncio.gather(*operations) + + async def test_pipeline(self, db: DBContextStorage) -> None: + # Test Pipeline workload on DB + pipeline = Pipeline(**TOY_SCRIPT_KWARGS, context_storage=db) + check_happy_path(pipeline, happy_path=HAPPY_PATH) diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 465404d6d..e69de29bb 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -1,40 +0,0 @@ -import pytest - -from chatsky.core import Pipeline -from chatsky.core import Context - - -@pytest.fixture -def pipeline(): - return Pipeline( - script={"flow": {"node1": {}, "node2": {}, "node3": {}}, "service": {"start": {}, "fallback": {}}}, - start_label=("service", "start"), - fallback_label=("service", "fallback"), - ) - - -@pytest.fixture -def context_factory(pipeline): - def _context_factory(forbidden_fields=None, add_start_label=True): - if add_start_label: - ctx = Context.init(("service", "start")) - else: - ctx = Context() - ctx.framework_data.pipeline = pipeline - if forbidden_fields is not None: - - class Forbidden: - def __init__(self, name): - self.name = name - - class ForbiddenError(Exception): - pass - - def __getattr__(self, item): - raise self.ForbiddenError(f"{self.name!r} is forbidden") - - for forbidden_field in forbidden_fields: - ctx.__setattr__(forbidden_field, Forbidden(forbidden_field)) - return ctx - - return _context_factory diff --git a/tests/core/test_actor.py b/tests/core/test_actor.py index c6050e4b0..989f84811 100644 --- a/tests/core/test_actor.py +++ b/tests/core/test_actor.py @@ -9,7 +9,6 @@ from chatsky.core.context import Context from chatsky.core.script import Script from chatsky.core import RESPONSE, TRANSITIONS, PRE_TRANSITION, PRE_RESPONSE -from chatsky.core.utils import initialize_service_states class TestRequestProcessing: @@ -25,44 +24,38 @@ async def test_normal_execution(self): } ) - ctx = Context.init(start_label=("flow", "node1")) - actor = Actor() - ctx.framework_data.pipeline = Pipeline( - parallelize_processing=True, + pipeline = Pipeline( script=script, - fallback_label=AbsoluteNodeLabel(flow_name="flow", node_name="fallback"), start_label=("flow", "node1"), + fallback_label=AbsoluteNodeLabel(flow_name="flow", node_name="fallback"), + parallelize_processing=True, ) - initialize_service_states(ctx, actor) - await actor(ctx) + 
ctx = await pipeline._run_pipeline(Message()) - assert ctx.labels == { + assert ctx.labels._items == { 0: AbsoluteNodeLabel(flow_name="flow", node_name="node1"), 1: AbsoluteNodeLabel(flow_name="flow", node_name="node2"), } - assert ctx.responses == {1: Message(text="node2")} + assert ctx.responses._items == {1: Message(text="node2")} async def test_fallback_node(self): script = Script.model_validate({"flow": {"node": {}, "fallback": {RESPONSE: "fallback"}}}) - ctx = Context.init(start_label=("flow", "node")) - actor = Actor() - ctx.framework_data.pipeline = Pipeline( - parallelize_processing=True, + pipeline = Pipeline( script=script, - fallback_label=AbsoluteNodeLabel(flow_name="flow", node_name="fallback"), start_label=("flow", "node"), + fallback_label=AbsoluteNodeLabel(flow_name="flow", node_name="fallback"), + parallelize_processing=True, ) - initialize_service_states(ctx, actor) - await actor(ctx) + ctx = await pipeline._run_pipeline(Message()) - assert ctx.labels == { + assert ctx.labels._items == { 0: AbsoluteNodeLabel(flow_name="flow", node_name="node"), 1: AbsoluteNodeLabel(flow_name="flow", node_name="fallback"), } - assert ctx.responses == {1: Message(text="fallback")} + assert ctx.responses._items == {1: Message(text="fallback")} @pytest.mark.parametrize( "default_priority,result", @@ -84,18 +77,16 @@ async def test_default_priority(self, default_priority, result): } ) - ctx = Context.init(start_label=("flow", "node1")) - actor = Actor() - ctx.framework_data.pipeline = Pipeline( - parallelize_processing=True, + pipeline = Pipeline( script=script, + start_label=("flow", "node1"), fallback_label=AbsoluteNodeLabel(flow_name="flow", node_name="fallback"), + parallelize_processing=True, default_priority=default_priority, - start_label=("flow", "node1"), ) - initialize_service_states(ctx, actor) - await actor(ctx) + ctx = await pipeline._run_pipeline(Message()) + assert ctx.last_label.node_name == result async def test_transition_exception_handling(self, log_event_catcher): @@ -107,17 +98,14 @@ async def call(self, ctx: Context) -> None: script = Script.model_validate({"flow": {"node": {PRE_TRANSITION: {"": MyProcessing()}}, "fallback": {}}}) - ctx = Context.init(start_label=("flow", "node")) - actor = Actor() - ctx.framework_data.pipeline = Pipeline( - parallelize_processing=True, + pipeline = Pipeline( script=script, - fallback_label=AbsoluteNodeLabel(flow_name="flow", node_name="fallback"), start_label=("flow", "node"), + fallback_label=AbsoluteNodeLabel(flow_name="flow", node_name="fallback"), + parallelize_processing=True, ) - initialize_service_states(ctx, actor) - await actor(ctx) + ctx = await pipeline._run_pipeline(Message()) assert ctx.last_label.node_name == "fallback" assert log_list[0].msg == "Exception occurred during transition processing." @@ -128,17 +116,13 @@ async def test_empty_response(self, log_event_catcher): script = Script.model_validate({"flow": {"node": {}}}) - ctx = Context.init(start_label=("flow", "node")) - actor = Actor() - ctx.framework_data.pipeline = Pipeline( - parallelize_processing=True, + pipeline = Pipeline( script=script, - fallback_label=AbsoluteNodeLabel(flow_name="flow", node_name="node"), start_label=("flow", "node"), + parallelize_processing=True, ) - initialize_service_states(ctx, actor) - await actor(ctx) + ctx = await pipeline._run_pipeline(Message()) assert ctx.responses == {1: Message()} assert log_list[-1].msg == "Node has empty response." 
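The updated tests address history by turn id and await reads instead of mutating plain dicts; user-defined script functions follow the same pattern. A short hedged sketch of a custom condition written against this access style (negative indices addressing the most recent turns are taken from the behaviour exercised in this diff, not from a documented contract):

    from chatsky.core import BaseCondition, Context


    class RepeatedRequest(BaseCondition):
        """True if the latest request repeats the previous one."""

        async def call(self, ctx: Context) -> bool:
            # `ctx.requests` is a context dictionary now: reads are awaited,
            # and negative indices address the most recent turns.
            if len(ctx.requests) < 2:
                return False
            current = await ctx.requests[-1]
            previous = await ctx.requests[-2]
            return current.text == previous.text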
@@ -152,17 +136,13 @@ async def call(self, ctx: Context) -> MessageInitTypes: script = Script.model_validate({"flow": {"node": {RESPONSE: MyResponse()}}}) - ctx = Context.init(start_label=("flow", "node")) - actor = Actor() - ctx.framework_data.pipeline = Pipeline( - parallelize_processing=True, + pipeline = Pipeline( script=script, - fallback_label=AbsoluteNodeLabel(flow_name="flow", node_name="node"), start_label=("flow", "node"), + parallelize_processing=True, ) - initialize_service_states(ctx, actor) - await actor(ctx) + ctx = await pipeline._run_pipeline(Message()) assert ctx.responses == {1: Message()} assert log_list[-1].msg == "Response was not produced." @@ -176,17 +156,13 @@ async def call(self, ctx: Context) -> None: script = Script.model_validate({"flow": {"node": {PRE_RESPONSE: {"": MyProcessing()}}}}) - ctx = Context.init(start_label=("flow", "node")) - actor = Actor() - ctx.framework_data.pipeline = Pipeline( - parallelize_processing=True, + pipeline = Pipeline( script=script, - fallback_label=AbsoluteNodeLabel(flow_name="flow", node_name="node"), start_label=("flow", "node"), + parallelize_processing=True, ) - initialize_service_states(ctx, actor) - await actor(ctx) + ctx = await pipeline._run_pipeline(Message()) assert ctx.responses == {1: Message()} assert log_list[0].msg == "Exception occurred during response processing." @@ -207,7 +183,7 @@ async def call(self, ctx: Context) -> None: procs = {"1": Proc1(), "2": Proc2()} - ctx = Context.init(start_label=("flow", "node")) + ctx = Context() ctx.framework_data.pipeline = Pipeline(parallelize_processing=True, script={"": {"": {}}}, start_label=("", "")) await Actor._run_processing(procs, ctx) diff --git a/tests/core/test_conditions.py b/tests/core/test_conditions.py index 4d1a3f33f..d6536ef30 100644 --- a/tests/core/test_conditions.py +++ b/tests/core/test_conditions.py @@ -1,6 +1,6 @@ import pytest -from chatsky.core import BaseCondition +from chatsky.core import BaseCondition, AbsoluteNodeLabel from chatsky.core.message import Message, CallbackQuery import chatsky.conditions as cnd @@ -17,7 +17,7 @@ class SubclassMessage(Message): @pytest.fixture def request_based_ctx(context_factory): ctx = context_factory(forbidden_fields=("labels", "responses", "misc")) - ctx.add_request(Message(text="text", misc={"key": "value"})) + ctx.requests[1] = Message(text="text", misc={"key": "value"}) return ctx @@ -101,8 +101,7 @@ async def test_neg(request_based_ctx, condition, result): async def test_has_last_labels(context_factory): - ctx = context_factory(forbidden_fields=("requests", "responses", "misc")) - ctx.add_label(("flow", "node1")) + ctx = context_factory(forbidden_fields=("requests", "responses", "misc"), start_label=("flow", "node1")) assert await cnd.CheckLastLabels(flow_labels=["flow"])(ctx) is True assert await cnd.CheckLastLabels(flow_labels=["flow1"])(ctx) is False @@ -110,7 +109,7 @@ async def test_has_last_labels(context_factory): assert await cnd.CheckLastLabels(labels=[("flow", "node1")])(ctx) is True assert await cnd.CheckLastLabels(labels=[("flow", "node2")])(ctx) is False - ctx.add_label(("service", "start")) + ctx.labels[1] = AbsoluteNodeLabel(flow_name="service", node_name="start") assert await cnd.CheckLastLabels(flow_labels=["flow"])(ctx) is False assert await cnd.CheckLastLabels(flow_labels=["flow"], last_n_indices=2)(ctx) is True @@ -121,8 +120,8 @@ async def test_has_last_labels(context_factory): async def test_has_callback_query(context_factory): ctx = context_factory(forbidden_fields=("labels", "responses", 
"misc")) - ctx.add_request( - Message(attachments=[CallbackQuery(query_string="text", extra="extra"), CallbackQuery(query_string="text1")]) + ctx.requests[1] = Message( + attachments=[CallbackQuery(query_string="text", extra="extra"), CallbackQuery(query_string="text1")] ) assert await cnd.HasCallbackQuery("text")(ctx) is True @@ -133,6 +132,6 @@ async def test_has_callback_query(context_factory): @pytest.mark.parametrize("cnd", [cnd.HasText(""), cnd.Regexp(""), cnd.HasCallbackQuery("")]) async def test_empty_text(context_factory, cnd): ctx = context_factory() - ctx.add_request(Message()) + ctx.requests[1] = Message() assert await cnd(ctx) is False diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 1ca0e9842..942f19a32 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -1,6 +1,6 @@ import pytest -from chatsky.core.context import get_last_index, Context, ContextError +from chatsky.core.context import Context, ContextError from chatsky.core.node_label import AbsoluteNodeLabel from chatsky.core.message import Message, MessageInitTypes from chatsky.core.script_function import BaseResponse, BaseProcessing @@ -8,109 +8,104 @@ from chatsky.core import RESPONSE, PRE_TRANSITION, PRE_RESPONSE -class TestGetLastIndex: - @pytest.mark.parametrize( - "dict,result", - [ - ({1: None, 5: None}, 5), - ({5: None, 1: None}, 5), - ], - ) - def test_normal(self, dict, result): - assert get_last_index(dict) == result - - def test_exception(self): - with pytest.raises(ValueError): - get_last_index({}) - - -def test_init(): - ctx1 = Context.init(AbsoluteNodeLabel(flow_name="flow", node_name="node")) - ctx2 = Context.init(AbsoluteNodeLabel(flow_name="flow", node_name="node")) - assert ctx1.labels == {0: AbsoluteNodeLabel(flow_name="flow", node_name="node")} - assert ctx1.requests == {} - assert ctx1.responses == {} - assert ctx1.id != ctx2.id - - ctx3 = Context.init(AbsoluteNodeLabel(flow_name="flow", node_name="node"), id="id") - assert ctx3.labels == {0: AbsoluteNodeLabel(flow_name="flow", node_name="node")} - assert ctx3.requests == {} - assert ctx3.responses == {} - assert ctx3.id == "id" - - class TestLabels: @pytest.fixture def ctx(self, context_factory): - return context_factory(forbidden_fields=["requests", "responses"], add_start_label=False) - - def test_raises_on_empty_labels(self, ctx): - with pytest.raises(ContextError): - ctx.add_label(("flow", "node")) + return context_factory(forbidden_fields=["requests", "responses"]) + def test_raises_on_empty_labels(self, ctx: Context): with pytest.raises(ContextError): ctx.last_label - def test_existing_labels(self, ctx): - ctx.labels = {5: AbsoluteNodeLabel.model_validate(("flow", "node1"))} + def test_existing_labels(self, ctx: Context): + ctx.labels[5] = ("flow", "node1") assert ctx.last_label == AbsoluteNodeLabel(flow_name="flow", node_name="node1") - ctx.add_label(("flow", "node2")) - assert ctx.labels == { - 5: AbsoluteNodeLabel(flow_name="flow", node_name="node1"), - 6: AbsoluteNodeLabel(flow_name="flow", node_name="node2"), - } + ctx.labels[6] = ("flow", "node2") + assert ctx.labels.keys() == [5, 6] assert ctx.last_label == AbsoluteNodeLabel(flow_name="flow", node_name="node2") class TestRequests: @pytest.fixture def ctx(self, context_factory): - return context_factory(forbidden_fields=["labels", "responses"], add_start_label=False) + return context_factory(forbidden_fields=["labels", "responses"]) - def test_existing_requests(self, ctx): - ctx.requests = {5: Message(text="text1")} + def 
test_existing_requests(self, ctx: Context): + ctx.requests[5] = Message(text="text1") assert ctx.last_request == Message(text="text1") - ctx.add_request("text2") - assert ctx.requests == {5: Message(text="text1"), 6: Message(text="text2")} + ctx.requests[6] = "text2" + assert ctx.requests.keys() == [5, 6] assert ctx.last_request == Message(text="text2") - def test_empty_requests(self, ctx): + def test_empty_requests(self, ctx: Context): with pytest.raises(ContextError): ctx.last_request - ctx.add_request("text") + ctx.requests[1] = "text" assert ctx.last_request == Message(text="text") - assert list(ctx.requests.keys()) == [1] + assert ctx.requests.keys() == [1] class TestResponses: @pytest.fixture def ctx(self, context_factory): - return context_factory(forbidden_fields=["labels", "requests"], add_start_label=False) + return context_factory(forbidden_fields=["labels", "requests"]) - def test_existing_responses(self, ctx): - ctx.responses = {5: Message(text="text1")} + def test_existing_responses(self, ctx: Context): + ctx.responses[5] = Message(text="text1") assert ctx.last_response == Message(text="text1") - ctx.add_response("text2") - assert ctx.responses == {5: Message(text="text1"), 6: Message(text="text2")} + ctx.responses[6] = "text2" + assert ctx.responses.keys() == [5, 6] assert ctx.last_response == Message(text="text2") - def test_empty_responses(self, ctx): - assert ctx.last_response is None + def test_empty_responses(self, ctx: Context): + with pytest.raises(ContextError): + ctx.last_response - ctx.add_response("text") + ctx.responses[1] = "text" assert ctx.last_response == Message(text="text") - assert list(ctx.responses.keys()) == [1] - + assert ctx.responses.keys() == [1] -def test_last_items_on_init(): - ctx = Context.init(("flow", "node")) - assert ctx.last_label == AbsoluteNodeLabel(flow_name="flow", node_name="node") - assert ctx.last_response is None - with pytest.raises(ContextError): - ctx.last_request +class TestTurns: + @pytest.fixture + def ctx(self, context_factory): + return context_factory() + + async def test_complete_turn(self, ctx: Context): + ctx.labels[5] = ("flow", "node5") + ctx.requests[5] = Message(text="text5") + ctx.responses[5] = Message(text="text5") + ctx.current_turn_id = 5 + + label, request, response = list(await ctx.turns(5))[0] + assert label == AbsoluteNodeLabel(flow_name="flow", node_name="node5") + assert request == Message(text="text5") + assert response == Message(text="text5") + + async def test_partial_turn(self, ctx: Context): + ctx.labels[6] = ("flow", "node6") + ctx.requests[6] = Message(text="text6") + ctx.current_turn_id = 6 + + label, request, response = list(await ctx.turns(6))[0] + assert label == AbsoluteNodeLabel(flow_name="flow", node_name="node6") + assert request == Message(text="text6") + assert response is None + + async def test_slice_turn(self, ctx: Context): + for i in range(2, 6): + ctx.labels[i] = ("flow", f"node{i}") + ctx.requests[i] = Message(text=f"text{i}") + ctx.responses[i] = Message(text=f"text{i}") + ctx.current_turn_id = i + + labels, requests, responses = zip(*(await ctx.turns(slice(2, 6)))) + for i in range(2, 6): + assert AbsoluteNodeLabel(flow_name="flow", node_name=f"node{i}") in labels + assert Message(text=f"text{i}") in requests + assert Message(text=f"text{i}") in responses async def test_pipeline_available(): diff --git a/tests/core/test_context_dict.py b/tests/core/test_context_dict.py new file mode 100644 index 000000000..78c15ce94 --- /dev/null +++ b/tests/core/test_context_dict.py @@ -0,0 
+1,160 @@ +import pytest + +from chatsky.context_storages import MemoryContextStorage +from chatsky.context_storages.database import ContextInfo, NameConfig +from chatsky.core.message import Message +from chatsky.core.ctx_dict import ContextDict, MessageContextDict + + +class TestContextDict: + @pytest.fixture(scope="function") + async def empty_dict(self) -> ContextDict: + # Empty (disconnected) context dictionary + return MessageContextDict() + + @pytest.fixture(scope="function") + async def attached_dict(self) -> ContextDict: + # Attached, but not backed by any data context dictionary + storage = MemoryContextStorage() + return await MessageContextDict.new(storage, "ID", NameConfig._requests_field) + + @pytest.fixture(scope="function") + async def prefilled_dict(self) -> ContextDict: + # Attached pre-filled context dictionary + ctx_id = "ctx1" + storage = MemoryContextStorage(rewrite_existing=False, partial_read_config={"requests": 1}) + await storage.update_main_info(ctx_id, ContextInfo(turn_id=0, created_at=0, updated_at=0)) + requests = [ + (1, Message("longer text", misc={"k": "v"}).model_dump_json().encode()), + (2, Message("text 2", misc={"1": 0, "2": 8}).model_dump_json().encode()), + ] + await storage.update_field_items(ctx_id, NameConfig._requests_field, requests) + return await MessageContextDict.connected(storage, ctx_id, NameConfig._requests_field) + + async def test_creation( + self, empty_dict: ContextDict, attached_dict: ContextDict, prefilled_dict: ContextDict + ) -> None: + # Checking creation correctness + for ctx_dict in [empty_dict, attached_dict, prefilled_dict]: + assert ctx_dict._storage is not None or ctx_dict == empty_dict + assert ctx_dict._added == ctx_dict._removed == set() + if ctx_dict != prefilled_dict: + assert ctx_dict._items == ctx_dict._hashes == dict() + assert ctx_dict._keys == set() + else: + assert len(ctx_dict._items) == len(ctx_dict._hashes) == 1 + assert ctx_dict._keys == {1, 2} + + async def test_get_set_del( + self, empty_dict: ContextDict, attached_dict: ContextDict, prefilled_dict: ContextDict + ) -> None: + for ctx_dict in [empty_dict, attached_dict, prefilled_dict]: + # Setting 1 item + message = Message("message") + ctx_dict[0] = message + assert await ctx_dict[0] == message + assert 0 in ctx_dict._keys + assert ctx_dict._added == {0} + assert ctx_dict._items[0] == message + # Setting several items + ctx_dict[1] = ctx_dict[2] = ctx_dict[3] = Message() + messages = (Message("1"), Message("2"), Message("3")) + ctx_dict[1:] = messages + assert await ctx_dict[1:] == list(messages) + assert ctx_dict._keys == {0, 1, 2, 3} + assert ctx_dict._added == {0, 1, 2, 3} + # Deleting item + del ctx_dict[0] + assert ctx_dict._keys == {1, 2, 3} + assert ctx_dict._added == {1, 2, 3} + assert ctx_dict._removed == {0} + # Getting deleted item + with pytest.raises(KeyError) as e: + _ = await ctx_dict[0] + assert e + # negative index + (await ctx_dict[-1]).text = "4" + assert (await ctx_dict[3]).text == "4" + + async def test_load_len_in_contains_keys_values(self, prefilled_dict: ContextDict) -> None: + # Checking keys + assert len(prefilled_dict) == 2 + assert prefilled_dict._keys == {1, 2} + assert prefilled_dict._added == set() + assert prefilled_dict.keys() == [1, 2] + assert 1 in prefilled_dict and 2 in prefilled_dict + assert set(prefilled_dict._items.keys()) == {2} + # Loading item + assert await prefilled_dict.get(100, None) is None + assert await prefilled_dict.get(1, None) is not None + assert prefilled_dict._added == set() + assert 
len(prefilled_dict._hashes) == 2 + assert len(prefilled_dict._items) == 2 + # Deleting loaded item + del prefilled_dict[1] + assert prefilled_dict._removed == {1} + assert len(prefilled_dict._items) == 1 + assert prefilled_dict._keys == {2} + assert 1 not in prefilled_dict + assert set(prefilled_dict.keys()) == {2} + # Checking remaining item + assert len(await prefilled_dict.values()) == 1 + assert len(prefilled_dict._items) == 1 + assert prefilled_dict._added == set() + + async def test_other_methods(self, prefilled_dict: ContextDict) -> None: + # Loading items + assert len(await prefilled_dict.items()) == 2 + # Popping first item + assert await prefilled_dict.pop(1, None) is not None + assert prefilled_dict._removed == {1} + assert len(prefilled_dict) == 1 + # Popping nonexistent item + assert await prefilled_dict.pop(100, None) is None + # Popping last item + assert (await prefilled_dict.popitem())[0] == 2 + assert prefilled_dict._removed == {1, 2} + # Updating dict with new values + await prefilled_dict.update({1: Message("some"), 2: Message("random")}) + assert set(prefilled_dict.keys()) == {1, 2} + # Adding default value to dict + message = Message("message") + assert await prefilled_dict.setdefault(3, message) == message + assert set(prefilled_dict.keys()) == {1, 2, 3} + # Clearing all the items + prefilled_dict.clear() + assert set(prefilled_dict.keys()) == set() + + async def test_eq_validate(self, empty_dict: ContextDict) -> None: + # Checking empty dict validation + assert empty_dict == MessageContextDict.model_validate(dict()) + # Checking non-empty dict validation + empty_dict[0] = Message("msg") + empty_dict._added = set() + assert empty_dict == MessageContextDict.model_validate({0: Message("msg")}) + + async def test_serialize_store( + self, empty_dict: ContextDict, attached_dict: ContextDict, prefilled_dict: ContextDict + ) -> None: + # Check all the dict types + for ctx_dict in [empty_dict, attached_dict, prefilled_dict]: + # Set overwriting existing keys to false + if ctx_dict._storage is not None: + ctx_dict._storage.rewrite_existing = False + # Adding an item + ctx_dict[0] = Message("message") + # Loading all pre-filled items + await ctx_dict.values() + # Changing one more item (might be pre-filled) + ctx_dict[2] = Message("another message") + # Removing the first added item + del ctx_dict[0] + # Checking only the changed keys were serialized + assert set(ctx_dict.model_dump(mode="json").keys()) == {"2"} + # Throw an error when storing a disconnected dict + if ctx_dict == empty_dict: + with pytest.raises(KeyError) as e: + await ctx_dict.store() + assert e + else: + await ctx_dict.store() diff --git a/tests/core/test_destinations.py b/tests/core/test_destinations.py index 5126c71aa..d66fe5d1f 100644 --- a/tests/core/test_destinations.py +++ b/tests/core/test_destinations.py @@ -7,7 +7,7 @@ @pytest.fixture def ctx(context_factory): - return context_factory(forbidden_fields=("requests", "responses", "misc")) + return context_factory(forbidden_fields=("requests", "responses", "misc"), start_label=("service", "start")) async def test_from_history(ctx): @@ -16,10 +16,10 @@ async def test_from_history(ctx): == await dst.Current()(ctx) == AbsoluteNodeLabel(flow_name="service", node_name="start") ) - with pytest.raises(KeyError): + with pytest.raises(IndexError): await dst.FromHistory(position=-2)(ctx) - ctx.add_label(("flow", "node1")) + ctx.labels[1] = ("flow", "node1") assert ( await dst.FromHistory(position=-1)(ctx) == await dst.Current()(ctx) @@ -30,10 +30,10 @@ async def
test_from_history(ctx): == await dst.Previous()(ctx) == AbsoluteNodeLabel(flow_name="service", node_name="start") ) - with pytest.raises(KeyError): + with pytest.raises(IndexError): await dst.FromHistory(position=-3)(ctx) - ctx.add_label(("flow", "node2")) + ctx.labels[2] = ("flow", "node2") assert await dst.Current()(ctx) == AbsoluteNodeLabel(flow_name="flow", node_name="node2") assert await dst.FromHistory(position=-3)(ctx) == AbsoluteNodeLabel(flow_name="service", node_name="start") @@ -78,19 +78,19 @@ def test_non_existent_node_exception(self, ctx): dst.get_next_node_in_flow(("flow", "node4"), ctx) async def test_forward(self, ctx): - ctx.add_label(("flow", "node2")) + ctx.labels[1] = ("flow", "node2") assert await dst.Forward()(ctx) == AbsoluteNodeLabel(flow_name="flow", node_name="node3") - ctx.add_label(("flow", "node3")) + ctx.labels[2] = ("flow", "node3") assert await dst.Forward(loop=True)(ctx) == AbsoluteNodeLabel(flow_name="flow", node_name="node1") with pytest.raises(IndexError): await dst.Forward(loop=False)(ctx) async def test_backward(self, ctx): - ctx.add_label(("flow", "node2")) + ctx.labels[1] = ("flow", "node2") assert await dst.Backward()(ctx) == AbsoluteNodeLabel(flow_name="flow", node_name="node1") - ctx.add_label(("flow", "node1")) + ctx.labels[2] = ("flow", "node1") assert await dst.Backward(loop=True)(ctx) == AbsoluteNodeLabel(flow_name="flow", node_name="node3") with pytest.raises(IndexError): await dst.Backward(loop=False)(ctx) diff --git a/tests/core/test_node_label.py b/tests/core/test_node_label.py index 8580f5f5f..04d59f934 100644 --- a/tests/core/test_node_label.py +++ b/tests/core/test_node_label.py @@ -1,11 +1,11 @@ import pytest from pydantic import ValidationError -from chatsky.core import NodeLabel, Context, AbsoluteNodeLabel, Pipeline +from chatsky.core import NodeLabel, AbsoluteNodeLabel, Pipeline -def test_init_from_single_string(): - ctx = Context.init(("flow", "node1")) +def test_init_from_single_string(context_factory): + ctx = context_factory(start_label=("flow", "node1")) ctx.framework_data.pipeline = Pipeline({"flow": {"node2": {}}}, ("flow", "node2")) node = AbsoluteNodeLabel.model_validate("node2", context={"ctx": ctx}) @@ -31,11 +31,11 @@ def test_init_from_incorrect_iterables(data, msg): AbsoluteNodeLabel.model_validate(data) -def test_init_from_node_label(): +def test_init_from_node_label(context_factory): with pytest.raises(ValidationError): AbsoluteNodeLabel.model_validate(NodeLabel(node_name="node")) - ctx = Context.init(("flow", "node1")) + ctx = context_factory(start_label=("flow", "node1")) ctx.framework_data.pipeline = Pipeline({"flow": {"node2": {}}}, ("flow", "node2")) node = AbsoluteNodeLabel.model_validate(NodeLabel(node_name="node2"), context={"ctx": ctx}) @@ -43,8 +43,8 @@ def test_init_from_node_label(): assert node == AbsoluteNodeLabel(flow_name="flow", node_name="node2") -def test_check_node_exists(): - ctx = Context.init(("flow", "node1")) +def test_check_node_exists(context_factory): + ctx = context_factory(start_label=("flow", "node1")) ctx.framework_data.pipeline = Pipeline({"flow": {"node2": {}}}, ("flow", "node2")) with pytest.raises(ValidationError, match="Cannot find node"): diff --git a/tests/core/test_script_function.py b/tests/core/test_script_function.py index 0d23126a1..f7101a21f 100644 --- a/tests/core/test_script_function.py +++ b/tests/core/test_script_function.py @@ -97,29 +97,20 @@ class TestNodeLabelValidation: def pipeline(self): return Pipeline(script={"flow1": {"node": {}}, "flow2": {"node": {}}}, 
start_label=("flow1", "node")) - @pytest.fixture - def context_flow_factory(self, pipeline): - def factory(flow_name: str): - ctx = Context.init((flow_name, "node")) - ctx.framework_data.pipeline = pipeline - return ctx - - return factory - @pytest.mark.parametrize("flow_name", ("flow1", "flow2")) - async def test_const_destination(self, context_flow_factory, flow_name): + async def test_const_destination(self, context_factory, flow_name): const_dst = ConstDestination.model_validate("node") - dst = await const_dst.wrapped_call(context_flow_factory(flow_name)) + dst = await const_dst.wrapped_call(context_factory(start_label=(flow_name, "node"))) assert dst.flow_name == flow_name @pytest.mark.parametrize("flow_name", ("flow1", "flow2")) - async def test_base_destination(self, context_flow_factory, flow_name): + async def test_base_destination(self, context_factory, flow_name): class MyDestination(BaseDestination): def call(self, ctx): return "node" - dst = await MyDestination().wrapped_call(context_flow_factory(flow_name)) + dst = await MyDestination().wrapped_call(context_factory(start_label=(flow_name, "node"))) assert dst.flow_name == flow_name @@ -135,7 +126,7 @@ async def test_const_object_immutability(): message = Message(text="text1") response = ConstResponse.model_validate(message) - response_result = await response.wrapped_call(Context.init(("", ""))) + response_result = await response.wrapped_call(Context()) response_result.text = "text2" diff --git a/tests/core/test_transition.py b/tests/core/test_transition.py index 350374c66..fef5822a1 100644 --- a/tests/core/test_transition.py +++ b/tests/core/test_transition.py @@ -53,8 +53,7 @@ async def call(self, ctx: Context) -> Union[float, bool, None]: ], ) async def test_get_next_label(context_factory, transitions, default_priority, result): - ctx = context_factory() - ctx.add_label(("flow", "node1")) + ctx = context_factory(start_label=("flow", "node1")) assert await get_next_label(ctx, transitions, default_priority) == ( AbsoluteNodeLabel.model_validate(result) if result is not None else None diff --git a/tests/pipeline/conftest.py b/tests/pipeline/conftest.py new file mode 100644 index 000000000..f6f7ed00e --- /dev/null +++ b/tests/pipeline/conftest.py @@ -0,0 +1,77 @@ +import asyncio + +import pytest + +from chatsky.core import Context, Pipeline +from chatsky.core.service import ServiceGroup, Service, ComponentExecutionState +from chatsky.core.service.extra import ComponentExtraHandler +from chatsky.core.utils import initialize_service_states + +from chatsky.utils.testing import TOY_SCRIPT + + +@pytest.fixture() +def make_test_service_group(): + def inner(running_order: list) -> ServiceGroup: + def interact(stage: str, run_order: list): + async def slow_service(_: Context): + run_order.append(stage) + await asyncio.sleep(0) + + return slow_service + + test_group = ServiceGroup( + components=[ + ServiceGroup( + name="InteractWithServiceA", + components=[ + interact("A1", running_order), + interact("A2", running_order), + interact("A3", running_order), + ], + ), + ServiceGroup( + name="InteractWithServiceB", + components=[ + interact("B1", running_order), + interact("B2", running_order), + interact("B3", running_order), + ], + ), + ServiceGroup( + name="InteractWithServiceC", + components=[ + interact("C1", running_order), + interact("C2", running_order), + interact("C3", running_order), + ], + ), + ], + ) + return test_group + + return inner + + +@pytest.fixture() +def run_test_group(context_factory): + async def inner(test_group: 
ServiceGroup) -> ComponentExecutionState: + ctx = context_factory(start_label=("greeting_flow", "start_node")) + pipeline = Pipeline(pre_services=test_group, script=TOY_SCRIPT, start_label=("greeting_flow", "start_node")) + initialize_service_states(ctx, pipeline.pre_services) + await pipeline.pre_services(ctx) + return test_group.get_state(ctx) + + return inner + + +@pytest.fixture() +def run_extra_handler(context_factory): + async def inner(extra_handler: ComponentExtraHandler) -> ComponentExecutionState: + ctx = context_factory(start_label=("greeting_flow", "start_node")) + service = Service(handler=lambda _: None, before_handler=extra_handler) + initialize_service_states(ctx, service) + await service(ctx) + return service.get_state(ctx) + + return inner diff --git a/tests/pipeline/test_service.py b/tests/pipeline/test_service.py index 463c0f166..dc59f4512 100644 --- a/tests/pipeline/test_service.py +++ b/tests/pipeline/test_service.py @@ -14,16 +14,9 @@ ) from chatsky.core.service.extra import BeforeHandler from chatsky.core.utils import initialize_service_states, finalize_service_group -from chatsky.utils.testing import TOY_SCRIPT -from .utils import run_test_group, make_test_service_group, run_extra_handler -@pytest.fixture -def empty_context(): - return Context.init(("", "")) - - -async def test_pipeline_component_order(empty_context): +async def test_pipeline_component_order(): logs = [] class MyProcessing(BaseProcessing): @@ -41,12 +34,11 @@ async def call(self, ctx: Context): pre_services=[MyProcessing(wait=0.02, text="A")], post_services=[MyProcessing(wait=0, text="C")], ) - initialize_service_states(empty_context, pipeline.services_pipeline) - await pipeline._run_pipeline(Message(""), empty_context.id) + await pipeline._run_pipeline(Message("")) assert logs == ["A", "B", "C"] -async def test_async_services(): +async def test_async_services(make_test_service_group, run_test_group): running_order = [] test_group = make_test_service_group(running_order) test_group.components[0].concurrent = True @@ -56,7 +48,7 @@ async def test_async_services(): assert running_order == ["A1", "B1", "A2", "B2", "A3", "B3", "C1", "C2", "C3"] -async def test_all_async_flag(): +async def test_all_async_flag(make_test_service_group, run_test_group): running_order = [] test_group = make_test_service_group(running_order) test_group.fully_concurrent = True @@ -65,7 +57,7 @@ async def test_all_async_flag(): assert running_order == ["A1", "B1", "C1", "A2", "B2", "C2", "A3", "B3", "C3"] -async def test_extra_handler_timeouts(): +async def test_extra_handler_timeouts(run_extra_handler): def bad_function(timeout: float, bad_func_completed: list): def inner(_: Context, __: ExtraHandlerRuntimeInfo) -> None: asyncio.run(asyncio.sleep(timeout)) @@ -85,7 +77,7 @@ def inner(_: Context, __: ExtraHandlerRuntimeInfo) -> None: assert test_list == [True] -async def test_extra_handler_function_signatures(): +async def test_extra_handler_function_signatures(run_extra_handler): def one_parameter_func(_: Context) -> None: pass @@ -106,7 +98,7 @@ def no_parameters_func() -> None: # Checking that async functions can be run as extra_handlers. 
-async def test_async_extra_handler_func(): +async def test_async_extra_handler_func(run_extra_handler): def append_list(record: list): async def async_func(_: Context, __: ExtraHandlerRuntimeInfo): record.append("Value") @@ -187,7 +179,7 @@ def test_raise_on_name_collision(): # 'fully_concurrent' flag will try to run all services simultaneously, but the 'wait' option # makes it so that A waits for B, which waits for C. So "C" is first, "A" is last. -async def test_waiting_for_service_to_finish_condition(): +async def test_waiting_for_service_to_finish_condition(make_test_service_group, run_test_group): running_order = [] test_group = make_test_service_group(running_order) test_group.fully_concurrent = True @@ -198,7 +190,7 @@ async def test_waiting_for_service_to_finish_condition(): assert running_order == ["C1", "C2", "C3", "B1", "B2", "B3", "A1", "A2", "A3"] -async def test_bad_service(): +async def test_bad_service(run_test_group): def bad_service_func(_: Context) -> None: raise Exception("Custom exception") @@ -206,14 +198,15 @@ def bad_service_func(_: Context) -> None: assert await run_test_group(test_group) == ComponentExecutionState.FAILED -async def test_service_not_run(empty_context): +async def test_service_not_run(context_factory): service = Service(handler=lambda ctx: None, start_condition=False) + empty_context = context_factory() initialize_service_states(empty_context, service) await service(empty_context) assert service.get_state(empty_context) == ComponentExecutionState.NOT_RUN -async def test_inherited_extra_handlers_for_service_groups_with_conditions(): +async def test_inherited_extra_handlers_for_service_groups_with_conditions(make_test_service_group, run_test_group): def extra_handler_func(counter: list): def inner(_: Context) -> None: counter.append("Value") @@ -231,12 +224,10 @@ def condition_func(path: str): service = Service(handler=lambda _: None, name="service") test_group.components.append(service) - ctx = Context.init(("greeting_flow", "start_node")) - pipeline = Pipeline(pre_services=test_group, script=TOY_SCRIPT, start_label=("greeting_flow", "start_node")) + finalize_service_group(test_group, ".pre") test_group.add_extra_handler(ExtraHandlerType.BEFORE, extra_handler_func(counter_list), condition_func) - initialize_service_states(ctx, test_group) - await pipeline.pre_services(ctx) + await run_test_group(test_group) # One for original ServiceGroup, one for each of the defined paths in the condition function. 
assert counter_list == ["Value"] * 3 diff --git a/tests/pipeline/test_update_ctx_misc.py b/tests/pipeline/test_update_ctx_misc.py index fb5251c1d..a3f6cbee0 100644 --- a/tests/pipeline/test_update_ctx_misc.py +++ b/tests/pipeline/test_update_ctx_misc.py @@ -7,7 +7,7 @@ @pytest.mark.asyncio async def test_update_ctx_misc(): class MyCondition(BaseCondition): - async def call(self, ctx: Context) -> bool: + def call(self, ctx: Context) -> bool: return ctx.misc["condition"] toy_script = { diff --git a/tests/pipeline/utils.py b/tests/pipeline/utils.py deleted file mode 100644 index 70246d826..000000000 --- a/tests/pipeline/utils.py +++ /dev/null @@ -1,64 +0,0 @@ -import asyncio - -from chatsky.core import Context, Pipeline -from chatsky.core.service import ServiceGroup, Service, ComponentExecutionState -from chatsky.core.service.extra import ComponentExtraHandler -from chatsky.core.utils import initialize_service_states - -from chatsky.utils.testing import TOY_SCRIPT - - -def make_test_service_group(running_order: list) -> ServiceGroup: - - def interact(stage: str, run_order: list): - async def slow_service(_: Context): - run_order.append(stage) - await asyncio.sleep(0) - - return slow_service - - test_group = ServiceGroup( - components=[ - ServiceGroup( - name="InteractWithServiceA", - components=[ - interact("A1", running_order), - interact("A2", running_order), - interact("A3", running_order), - ], - ), - ServiceGroup( - name="InteractWithServiceB", - components=[ - interact("B1", running_order), - interact("B2", running_order), - interact("B3", running_order), - ], - ), - ServiceGroup( - name="InteractWithServiceC", - components=[ - interact("C1", running_order), - interact("C2", running_order), - interact("C3", running_order), - ], - ), - ], - ) - return test_group - - -async def run_test_group(test_group: ServiceGroup) -> ComponentExecutionState: - ctx = Context.init(("greeting_flow", "start_node")) - pipeline = Pipeline(pre_services=test_group, script=TOY_SCRIPT, start_label=("greeting_flow", "start_node")) - initialize_service_states(ctx, pipeline.pre_services) - await pipeline.pre_services(ctx) - return test_group.get_state(ctx) - - -async def run_extra_handler(extra_handler: ComponentExtraHandler) -> ComponentExecutionState: - ctx = Context.init(("greeting_flow", "start_node")) - service = Service(handler=lambda _: None, before_handler=extra_handler) - initialize_service_states(ctx, service) - await service(ctx) - return service.get_state(ctx) diff --git a/tests/slots/conftest.py b/tests/slots/conftest.py index 9cdcc4eec..d241bcb54 100644 --- a/tests/slots/conftest.py +++ b/tests/slots/conftest.py @@ -1,6 +1,6 @@ import pytest -from chatsky.core import Message, TRANSITIONS, RESPONSE, Context, Pipeline, Transition as Tr +from chatsky.core import Message, TRANSITIONS, RESPONSE, Pipeline, Transition from chatsky.slots.slots import SlotNotExtracted @@ -14,14 +14,14 @@ def patch_exception_equality(monkeypatch): @pytest.fixture(scope="function") def pipeline(): - script = {"flow": {"node": {RESPONSE: Message(), TRANSITIONS: [Tr(dst="node")]}}} + script = {"flow": {"node": {RESPONSE: Message(), TRANSITIONS: [Transition(dst="node")]}}} pipeline = Pipeline(script=script, start_label=("flow", "node")) return pipeline @pytest.fixture(scope="function") -def context(pipeline): - ctx = Context.init(("flow", "node")) - ctx.add_request(Message(text="Hi")) +def context(pipeline, context_factory): + ctx = context_factory(start_label=("flow", "node")) + ctx.requests[1] = Message(text="Hi") 
ctx.framework_data.pipeline = pipeline return ctx diff --git a/tests/slots/test_slot_functions.py b/tests/slots/test_slot_functions.py index 83c6b0171..f9aca2980 100644 --- a/tests/slots/test_slot_functions.py +++ b/tests/slots/test_slot_functions.py @@ -29,9 +29,9 @@ def root_slot(): @pytest.fixture -def context(root_slot): - ctx = Context.init(("", "")) - ctx.add_request("text") +def context(root_slot, context_factory): + ctx = context_factory(start_label=("", "")) + ctx.requests[1] = "text" ctx.framework_data.slot_manager = SlotManager() ctx.framework_data.slot_manager.set_root_slot(root_slot) return ctx diff --git a/tests/slots/test_slot_manager.py b/tests/slots/test_slot_manager.py index 33885b360..d24fba4c6 100644 --- a/tests/slots/test_slot_manager.py +++ b/tests/slots/test_slot_manager.py @@ -92,7 +92,7 @@ def faulty_func(_): @pytest.fixture(scope="function") def context_with_request(context): new_ctx = context.model_copy(deep=True) - new_ctx.add_request(Message(text="I am Bot. My email is bot@bot")) + new_ctx.requests[2] = Message(text="I am Bot. My email is bot@bot") return new_ctx diff --git a/tests/slots/test_slot_partial_extraction.py b/tests/slots/test_slot_partial_extraction.py index 234287c17..4bb1b8c97 100644 --- a/tests/slots/test_slot_partial_extraction.py +++ b/tests/slots/test_slot_partial_extraction.py @@ -1,3 +1,4 @@ +from chatsky.core.context import Context from chatsky.slots import RegexpSlot, GroupSlot from chatsky.slots.slots import SlotManager from chatsky.core import Message @@ -33,9 +34,11 @@ @pytest.fixture(scope="function") -def context_with_request(context): +def context_with_request(context_factory): def inner(request): - context.add_request(Message(request)) + context = context_factory() + context.requests[context.current_turn_id] = Message(request) + context.current_turn_id += 1 return context return inner diff --git a/tests/slots/test_slot_types.py b/tests/slots/test_slot_types.py index a21cbd896..17d103498 100644 --- a/tests/slots/test_slot_types.py +++ b/tests/slots/test_slot_types.py @@ -36,7 +36,7 @@ ], ) async def test_regexp(user_request, regexp, expected, context): - context.add_request(user_request) + context.requests[1] = user_request slot = RegexpSlot(regexp=regexp) result = await slot.get_value(context) assert result == expected @@ -58,7 +58,7 @@ async def test_regexp(user_request, regexp, expected, context): ], ) async def test_function(user_request, func, expected, context): - context.add_request(user_request) + context.requests[1] = user_request slot = FunctionSlot(func=func) result = await slot.get_value(context) assert result == expected @@ -125,7 +125,7 @@ def func(msg: Message): ], ) async def test_group_slot_extraction(user_request, slot, expected, is_extracted, context): - context.add_request(user_request) + context.requests[1] = user_request result = await slot.get_value(context) assert result == expected assert result.__slot_extracted__ == is_extracted diff --git a/tests/stats/test_defaults.py b/tests/stats/test_defaults.py index 331a21e42..f3069ddae 100644 --- a/tests/stats/test_defaults.py +++ b/tests/stats/test_defaults.py @@ -2,7 +2,6 @@ import pytest -from chatsky.core import Context from chatsky.core.service import Service from chatsky.core.service.types import ExtraHandlerRuntimeInfo @@ -12,8 +11,8 @@ pytest.skip(allow_module_level=True, reason="One of the Opentelemetry packages is missing.") -async def test_get_current_label(): - context = Context.init(("a", "b")) +async def test_get_current_label(context_factory): + 
context = context_factory(start_label=("a", "b")) runtime_info = ExtraHandlerRuntimeInfo( stage="BEFORE", component=Service(handler=lambda ctx: None, path="-", name="-", timeout=None, concurrent=False), @@ -22,7 +21,7 @@ async def test_get_current_label(): assert result == {"flow": "a", "node": "b", "label": "a: b"} -async def test_otlp_integration(tracer_exporter_and_provider, log_exporter_and_provider): +async def test_otlp_integration(tracer_exporter_and_provider, log_exporter_and_provider, context_factory): _, tracer_provider = tracer_exporter_and_provider log_exporter, logger_provider = log_exporter_and_provider tutorial_module = importlib.import_module("tutorials.stats.1_extractor_functions") @@ -32,7 +31,8 @@ async def test_otlp_integration(tracer_exporter_and_provider, log_exporter_and_p stage="BEFORE", component=Service(handler=lambda ctx: None, path="-", name="-", timeout=None, concurrent=False), ) - _ = await default_extractors.get_current_label(Context.init(("a", "b")), runtime_info) + ctx = context_factory(start_label=("a", "b")) + _ = await default_extractors.get_current_label(ctx, runtime_info) tracer_provider.force_flush() logger_provider.force_flush() assert len(log_exporter.get_finished_logs()) > 0 diff --git a/tests/stats/test_instrumentation.py b/tests/stats/test_instrumentation.py index 1b2418179..e26b5fb66 100644 --- a/tests/stats/test_instrumentation.py +++ b/tests/stats/test_instrumentation.py @@ -1,8 +1,8 @@ import pytest from chatsky import Context -from chatsky.core.service import Service, ExtraHandlerRuntimeInfo, ComponentExecutionState, ServiceGroup -from tests.pipeline.utils import run_test_group +from chatsky.core.service import Service, ExtraHandlerRuntimeInfo +from chatsky.core.utils import initialize_service_states try: from chatsky.stats import default_extractors @@ -43,7 +43,7 @@ def test_keyword_arguments(): assert instrumentor._tracer_provider is not get_tracer_provider() -async def test_failed_stats_collection(log_event_catcher): +async def test_failed_stats_collection(log_event_catcher, context_factory): chatsky_instrumentor = OtelInstrumentor.from_url("grpc://localhost:4317") chatsky_instrumentor.instrument() @@ -51,10 +51,13 @@ async def test_failed_stats_collection(log_event_catcher): async def bad_stats_collector(_: Context, __: ExtraHandlerRuntimeInfo): raise Exception - service = Service(handler=lambda _: None, before_handler=bad_stats_collector) + service = Service(handler=lambda _: None, before_handler=bad_stats_collector, path=".service") log_list = log_event_catcher(logger=instrumentor_logger, level="ERROR") - assert await run_test_group(ServiceGroup(components=[service])) == ComponentExecutionState.FINISHED + ctx = context_factory() + initialize_service_states(ctx, service) + + await service(ctx) assert len(log_list) == 1 diff --git a/tests/utils/test_benchmark.py b/tests/utils/test_benchmark.py index 89444b064..c05c4fdf9 100644 --- a/tests/utils/test_benchmark.py +++ b/tests/utils/test_benchmark.py @@ -30,30 +30,45 @@ def test_get_dict(): } -def test_get_context(): +def test_db_factory(tmp_path: Path): + factory = bm.DBFactory(uri=f"json://{tmp_path}/json.json") + + db = factory.db() + + assert isinstance(db, JSONContextStorage) + + +@pytest.fixture +def context_storage(tmp_path: Path) -> JSONContextStorage: + factory = bm.DBFactory(uri=f"json://{tmp_path}/json.json") + + return factory.db() + + +async def test_get_context(context_storage: JSONContextStorage): random.seed(42) - context = get_context(2, (1, 2), (2, 3)) - assert context == 
Context( - id=context.id, - labels={0: ("flow_0", "node_0"), 1: ("flow_1", "node_1")}, - requests={0: Message(misc={"0": ">e"}), 1: Message(misc={"0": "3 "})}, - responses={0: Message(misc={"0": "zv"}), 1: Message(misc={"0": "sh"})}, - misc={"0": " d]", "1": " (b"}, + context = await get_context(context_storage, 2, (1, 2), (2, 3)) + copy_ctx = await Context.connected(context_storage, ("flow", "node")) + await copy_ctx.labels.update({0: ("flow_0", "node_0"), 1: ("flow_1", "node_1")}) + await copy_ctx.requests.update({0: Message(misc={"0": ">e"}), 1: Message(misc={"0": "zv"})}) + await copy_ctx.responses.update({0: Message(misc={"0": "3 "}), 1: Message(misc={"0": "sh"})}) + copy_ctx.misc.update({"0": " d]", "1": " (b"}) + assert context.model_dump(exclude={"id", "current_turn_id"}) == copy_ctx.model_dump( + exclude={"id", "current_turn_id"} ) -def test_benchmark_config(monkeypatch: pytest.MonkeyPatch): +async def test_benchmark_config(context_storage: JSONContextStorage, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(random, "choice", lambda x: ".") config = bm.BasicBenchmarkConfig( from_dialog_len=1, to_dialog_len=5, message_dimensions=(2, 2), misc_dimensions=(3, 3, 3) ) - context = config.get_context() - actual_context = get_context(1, (2, 2), (3, 3, 3)) - actual_context.id = context.id - assert context == actual_context + context = await config.get_context(context_storage) + actual_context = await get_context(context_storage, 1, (2, 2), (3, 3, 3)) + assert context.model_dump(exclude={"id"}) == actual_context.model_dump(exclude={"id"}) - info = config.info() + info = await config.info() for size in ("starting_context_size", "final_context_size", "misc_size", "message_size"): assert isinstance(info["sizes"][size], str) @@ -62,7 +77,7 @@ def test_benchmark_config(monkeypatch: pytest.MonkeyPatch): contexts = [context] while contexts[-1] is not None: - contexts.append(context_updater(deepcopy(contexts[-1]))) + contexts.append(await context_updater(deepcopy(contexts[-1]))) assert len(contexts) == 5 @@ -70,12 +85,11 @@ def test_benchmark_config(monkeypatch: pytest.MonkeyPatch): if context is not None: assert len(context.labels) == len(context.requests) == len(context.responses) == index + 1 - actual_context = get_context(index + 1, (2, 2), (3, 3, 3)) - actual_context.id = context.id - assert context == actual_context + actual_context = await get_context(context_storage, index + 1, (2, 2), (3, 3, 3)) + assert context.model_dump(exclude={"id"}) == actual_context.model_dump(exclude={"id"}) -def test_context_updater_with_steps(monkeypatch: pytest.MonkeyPatch): +async def test_context_updater_with_steps(context_storage: JSONContextStorage, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(random, "choice", lambda x: ".") config = bm.BasicBenchmarkConfig( @@ -84,10 +98,10 @@ def test_context_updater_with_steps(monkeypatch: pytest.MonkeyPatch): context_updater = config.context_updater - contexts = [config.get_context()] + contexts = [await config.get_context(context_storage)] while contexts[-1] is not None: - contexts.append(context_updater(deepcopy(contexts[-1]))) + contexts.append(await context_updater(deepcopy(contexts[-1]))) assert len(contexts) == 5 @@ -95,27 +109,11 @@ def test_context_updater_with_steps(monkeypatch: pytest.MonkeyPatch): if context is not None: assert len(context.labels) == len(context.requests) == len(context.responses) == index - actual_context = get_context(index, (2, 2), (3, 3, 3)) - actual_context.id = context.id - assert context == actual_context + actual_context 
= await get_context(context_storage, index, (2, 2), (3, 3, 3)) + assert context.model_dump(exclude={"id"}) == actual_context.model_dump(exclude={"id"}) -def test_db_factory(tmp_path: Path): - factory = bm.DBFactory(uri=f"json://{tmp_path}/json.json") - - db = factory.db() - - assert isinstance(db, JSONContextStorage) - - -@pytest.fixture -def context_storage(tmp_path: Path) -> JSONContextStorage: - factory = bm.DBFactory(uri=f"json://{tmp_path}/json.json") - - return factory.db() - - -def test_time_context_read_write(context_storage: JSONContextStorage): +async def test_time_context_read_write(context_storage: JSONContextStorage): config = bm.BasicBenchmarkConfig( context_num=5, from_dialog_len=1, @@ -125,12 +123,10 @@ def test_time_context_read_write(context_storage: JSONContextStorage): misc_dimensions=(3, 3, 3), ) - results = bm.time_context_read_write( + results = await bm.time_context_read_write( context_storage, config.get_context, config.context_num, config.context_updater ) - assert len(context_storage) == 0 - assert len(results) == 3 write, read, update = results @@ -148,7 +144,7 @@ def test_time_context_read_write(context_storage: JSONContextStorage): assert all([isinstance(update_time, float) and update_time > 0 for update_time in update_item.values()]) -def test_time_context_read_write_without_updates(context_storage: JSONContextStorage): +async def test_time_context_read_write_without_updates(context_storage: JSONContextStorage): config = bm.BasicBenchmarkConfig( context_num=5, from_dialog_len=1, @@ -158,7 +154,7 @@ def test_time_context_read_write_without_updates(context_storage: JSONContextSto misc_dimensions=(3, 3, 3), ) - results = bm.time_context_read_write( + results = await bm.time_context_read_write( context_storage, config.get_context, config.context_num, @@ -169,7 +165,7 @@ def test_time_context_read_write_without_updates(context_storage: JSONContextSto assert list(read[0].keys()) == [1] assert list(update[0].keys()) == [] - results = bm.time_context_read_write( + results = await bm.time_context_read_write( context_storage, config.get_context, config.context_num, diff --git a/tutorials/context_storages/1_basics.py b/tutorials/context_storages/1_basics.py index 4dbc2335a..d5e38a226 100644 --- a/tutorials/context_storages/1_basics.py +++ b/tutorials/context_storages/1_basics.py @@ -27,7 +27,7 @@ pathlib.Path("dbs").mkdir(exist_ok=True) db = context_storage_factory("json://dbs/file.json") # db = context_storage_factory("pickle://dbs/file.pkl") -# db = context_storage_factory("shelve://dbs/file") +# db = context_storage_factory("shelve://dbs/file.shlv") pipeline = Pipeline(**TOY_SCRIPT_KWARGS, context_storage=db) diff --git a/tutorials/context_storages/8_partial_updates.py b/tutorials/context_storages/8_partial_updates.py new file mode 100644 index 000000000..188eb1e49 --- /dev/null +++ b/tutorials/context_storages/8_partial_updates.py @@ -0,0 +1,134 @@ +# %% [markdown] +""" +# 8. Partial context updates + +The following tutorial shows the advanced usage +of context storage and context storage schema. 
+""" + +# %pip install chatsky + +# %% +from pathlib import Path + +from chatsky import Pipeline +from chatsky.context_storages import context_storage_factory +from chatsky.context_storages.database import NameConfig +from chatsky.utils.testing.common import check_happy_path, is_interactive_mode +from chatsky.utils.testing.toy_script import TOY_SCRIPT_KWARGS, HAPPY_PATH + +# %% +Path("dbs").mkdir(exist_ok=True) +db = context_storage_factory("shelve://dbs/partly.shlv") + +pipeline = Pipeline(**TOY_SCRIPT_KWARGS, context_storage=db) + +# %% [markdown] +""" +Most of the `Context` fields, that might grow in size uncontrollably, +are stored in a special structure, `ContextDict`. +This structure can be used for fine-grained access to the underlying +database, partial and asynchronous element loading. +In particular, this is relevant for `labels`, `requests` and `responses` +fields, while `misc` and `framework_data` are always loaded +completely and synchronously. + +How does that partial field writing work? +In most cases, every context storage +operates two "tables", "dictionaries", "files", etc. +One of them is called `MAIN` and contains all the "primitive" `Context` +data (and also the data that will be read and written completely every time, +serialized) - that includes context `id`, `current_turn_id`, `_created_at`, +`_updated_at`, `misc` and `framework_data` fields. +The other one is called `TURNS` and contains triplets of the data generated on +each conversation step: `label`, `request` and `response`. + +Whenever a context is loaded, only one item from `MAIN` teble +and zero to few items from `TURNS` table are loaded synchronously. +More items from `TURNS` table can be loaded later asynchronously on demand. +""" + +# %% [markdown] +""" +Database table layout and default behavior is controlled by +some special fields of the `DBContextStorage` class. + +All the table and field names are stored in a special `NameConfig` +static class. +""" + +# %% +print({k: v for k, v in vars(NameConfig).items() if not k.startswith("__") and not callable(v)}) + +# %% [markdown] +""" +Another property worth mentioning is `_subscripts`: +this property controls the number of *last* dictionary items +that will be read and written +(the items are ordered by keys, ascending) - default value is 3. +In order to read *all* items at once, the property +can also be set to "__all__" literal. +In order to read only a specific subset of keys, the property +can be set to a set of desired integers. +""" + +# %% +# All items will be read. +db._subscripts[NameConfig._requests_field] = "__all__" + +# %% +# 5 last items will be read. +db._subscripts[NameConfig._requests_field] = 5 + +# %% +# Items 1, 3, 5 and 7 will be read. +db._subscripts[NameConfig._requests_field] = {1, 3, 5, 7} + +# %% [markdown] +""" +Last but not least, comes `rewrite_existing` boolean flag. +In order to understand it, let's explore `ContextDict` class more. + +`ContextDict` provides dict-like access to its elements, however +by default not all of them might be loaded from the very beginning. +Usually, only the keys are loaded completely, while values loading is +controlled by `subscript` mentioned above. + +Still, `ContextDict` allows accessing items that are not yet loaded +(they are loaded lazily), as well as deleting and overwriting them. +Once `ContextDict` is serialized, it always includes information about +all the added and removed elements. 
+As for the modified elements, that is where the `rewrite_existing` flag +comes into play: if it is set to `True`, modifications are included; +otherwise they are discarded. + +NB! Keeping track of the modified elements comes at the price of calculating +and comparing their hashes, so in performance-critical environments +it may be worth disabling this feature. +""" + +# %% +# Any modifications done to the elements already present in storage +# will be preserved. +db.rewrite_existing = True + +# %% [markdown] +""" +A few more words about the `labels`, `requests` and `responses` fields: +exactly one label, one request and one response are added on every dialog turn, +and they are numbered consecutively. + +The framework ensures this ordering is preserved: each of them can be modified or +replaced with `None`, but never deleted or removed completely. +""" + +# %% +if __name__ == "__main__": + check_happy_path(pipeline, HAPPY_PATH) + # This is a function for automatic + # tutorial running (testing) with HAPPY_PATH + + # This runs tutorial in interactive mode if not in IPython env + # and if `DISABLE_INTERACTIVE_MODE` is not set + if is_interactive_mode(): + pipeline.run() # This runs tutorial in interactive mode diff --git a/tutorials/context_storages/8_db_benchmarking.py b/tutorials/context_storages/9_db_benchmarking.py similarity index 99% rename from tutorials/context_storages/8_db_benchmarking.py rename to tutorials/context_storages/9_db_benchmarking.py index 16a8df959..c635bda70 100644 --- a/tutorials/context_storages/8_db_benchmarking.py +++ b/tutorials/context_storages/9_db_benchmarking.py @@ -1,6 +1,6 @@ # %% [markdown] """ -# 8. Context storage benchmarking +# 9. Context storage benchmarking This tutorial shows how to benchmark context storages. diff --git a/utils/stats/sample_data_provider.py b/utils/stats/sample_data_provider.py index f1d990589..9ea3315ad 100644 --- a/utils/stats/sample_data_provider.py +++ b/utils/stats/sample_data_provider.py @@ -103,7 +103,9 @@ async def main(n_iterations: int = 100, n_workers: int = 4): ctxs = asyncio.Queue() parallel_iterations = n_iterations // n_workers for _ in range(n_workers): - await ctxs.put(Context.init(("root", "start"))) + ctx = Context() + ctx.last_label = ("root", "start") + await ctxs.put(ctx) for _ in tqdm(range(parallel_iterations)): await asyncio.gather(*(worker(ctxs) for _ in range(n_workers)))