diff --git a/src/typesense/client.py b/src/typesense/client.py index cde957b..f60acd0 100644 --- a/src/typesense/client.py +++ b/src/typesense/client.py @@ -46,6 +46,7 @@ from typesense.keys import Keys from typesense.metrics import Metrics from typesense.multi_search import MultiSearch +from typesense.nl_search_models import NLSearchModels from typesense.operations import Operations from typesense.stemming import Stemming from typesense.stopwords import Stopwords @@ -107,6 +108,7 @@ def __init__(self, config_dict: ConfigDict) -> None: self.stopwords = Stopwords(self.api_call) self.metrics = Metrics(self.api_call) self.conversations_models = ConversationsModels(self.api_call) + self.nl_search_models = NLSearchModels(self.api_call) def typed_collection( self, diff --git a/src/typesense/nl_search_model.py b/src/typesense/nl_search_model.py new file mode 100644 index 0000000..49aaab1 --- /dev/null +++ b/src/typesense/nl_search_model.py @@ -0,0 +1,108 @@ +""" +This module provides functionality for managing individual NL search models in Typesense. + +Classes: + - NLSearchModel: Handles operations related to a specific NL search model. + +Methods: + - __init__: Initializes the NLSearchModel object. + - _endpoint_path: Constructs the API endpoint path for this specific NL search model. + - retrieve: Retrieves the details of this specific NL search model. + - update: Updates this specific NL search model. + - delete: Deletes this specific NL search model. + +The NLSearchModel class interacts with the Typesense API to manage operations on a +specific NL search model. It provides methods to retrieve, update, +and delete individual models. + +This module uses type hinting and is compatible with Python 3.11+ as well as earlier +versions through the use of the typing_extensions library. 
+""" + +from typesense.api_call import ApiCall +from typesense.types.nl_search_model import ( + NLSearchModelDeleteSchema, + NLSearchModelSchema, + NLSearchModelUpdateSchema, +) + + +class NLSearchModel: + """ + Class for managing individual NL search models in Typesense. + + This class provides methods to interact with a specific NL search model, + including retrieving, updating, and deleting it. + + Attributes: + model_id (str): The ID of the NL search model. + api_call (ApiCall): The API call object for making requests. + """ + + def __init__(self, api_call: ApiCall, model_id: str) -> None: + """ + Initialize the NLSearchModel object. + + Args: + api_call (ApiCall): The API call object for making requests. + model_id (str): The ID of the NL search model. + """ + self.model_id = model_id + self.api_call = api_call + + def retrieve(self) -> NLSearchModelSchema: + """ + Retrieve this specific NL search model. + + Returns: + NLSearchModelSchema: The schema containing the NL search model details. + """ + response = self.api_call.get( + self._endpoint_path, + as_json=True, + entity_type=NLSearchModelSchema, + ) + return response + + def update(self, model: NLSearchModelUpdateSchema) -> NLSearchModelSchema: + """ + Update this specific NL search model. + + Args: + model (NLSearchModelUpdateSchema): + The schema containing the updated model details. + + Returns: + NLSearchModelSchema: The schema containing the updated NL search model. + """ + response: NLSearchModelSchema = self.api_call.put( + self._endpoint_path, + body=model, + entity_type=NLSearchModelSchema, + ) + return response + + def delete(self) -> NLSearchModelDeleteSchema: + """ + Delete this specific NL search model. + + Returns: + NLSearchModelDeleteSchema: The schema containing the deletion response. 
+ """ + response: NLSearchModelDeleteSchema = self.api_call.delete( + self._endpoint_path, + entity_type=NLSearchModelDeleteSchema, + ) + return response + + @property + def _endpoint_path(self) -> str: + """ + Construct the API endpoint path for this specific NL search model. + + Returns: + str: The constructed endpoint path. + """ + from typesense.nl_search_models import NLSearchModels + + return "/".join([NLSearchModels.resource_path, self.model_id]) diff --git a/src/typesense/nl_search_models.py b/src/typesense/nl_search_models.py new file mode 100644 index 0000000..d184add --- /dev/null +++ b/src/typesense/nl_search_models.py @@ -0,0 +1,117 @@ +""" +This module provides functionality for managing NL search models in Typesense. + +Classes: + - NLSearchModels: Handles operations related to NL search models. + +Methods: + - __init__: Initializes the NLSearchModels object. + - __getitem__: Retrieves or creates an NLSearchModel object for a given model_id. + - create: Creates a new NL search model. + - retrieve: Retrieves all NL search models. + +Attributes: + - resource_path: The API resource path for NL search models operations. + +The NLSearchModels class interacts with the Typesense API to manage +NL search model operations. + +It provides methods to create and retrieve NL search models, as well as access +individual NLSearchModel objects. + +This module uses type hinting and is compatible with Python 3.11+ as well as earlier +versions through the use of the typing_extensions library. +""" + +import sys + +from typesense.api_call import ApiCall +from typesense.types.nl_search_model import ( + NLSearchModelCreateSchema, + NLSearchModelSchema, + NLSearchModelsRetrieveSchema, +) + +if sys.version_info > (3, 11): + import typing +else: + import typing_extensions as typing + +from typesense.nl_search_model import NLSearchModel + + +class NLSearchModels(object): + """ + Class for managing NL search models in Typesense. 
+ + This class provides methods to interact with NL search models, including + creating, retrieving, and accessing individual models. + + Attributes: + resource_path (str): The API resource path for NL search models operations. + api_call (ApiCall): The API call object for making requests. + nl_search_models (Dict[str, NLSearchModel]): + A dictionary of NLSearchModel objects. + """ + + resource_path: typing.Final[str] = "/nl_search_models" + + def __init__(self, api_call: ApiCall) -> None: + """ + Initialize the NLSearchModels object. + + Args: + api_call (ApiCall): The API call object for making requests. + """ + self.api_call = api_call + self.nl_search_models: typing.Dict[str, NLSearchModel] = {} + + def __getitem__(self, model_id: str) -> NLSearchModel: + """ + Get or create an NLSearchModel object for a given model_id. + + Args: + model_id (str): The ID of the NL search model. + + Returns: + NLSearchModel: The NLSearchModel object for the given ID. + """ + if model_id not in self.nl_search_models: + self.nl_search_models[model_id] = NLSearchModel( + self.api_call, + model_id, + ) + return self.nl_search_models[model_id] + + def create(self, model: NLSearchModelCreateSchema) -> NLSearchModelSchema: + """ + Create a new NL search model. + + Args: + model (NLSearchModelCreateSchema): + The schema for creating the NL search model. + + Returns: + NLSearchModelSchema: The created NL search model. + """ + response = self.api_call.post( + endpoint=NLSearchModels.resource_path, + entity_type=NLSearchModelSchema, + as_json=True, + body=model, + ) + return response + + def retrieve(self) -> NLSearchModelsRetrieveSchema: + """ + Retrieve all NL search models. + + Returns: + NLSearchModelsRetrieveSchema: A list of all NL search models. 
+ """ + response: NLSearchModelsRetrieveSchema = self.api_call.get( + endpoint=NLSearchModels.resource_path, + entity_type=NLSearchModelsRetrieveSchema, + as_json=True, + ) + return response diff --git a/src/typesense/types/document.py b/src/typesense/types/document.py index 416ca7e..a0b63b4 100644 --- a/src/typesense/types/document.py +++ b/src/typesense/types/document.py @@ -569,6 +569,23 @@ class CachingParameters(typing.TypedDict): cache_ttl: typing.NotRequired[int] +class NLLanguageParameters(typing.TypedDict): + """ + Parameters regarding natural language search queries. + + Attributes: + nl_query_prompt_cache_ttl (int): The duration (in seconds) that determines how long the schema prompts are cached. + nl_query (bool): Whether to use natural language in the query or not. + nl_model_id (str): The ID of the natural language model to use for the query. + nl_query_debug (bool): Whether to return the raw LLM response or not. + """ + + nl_query_prompt_cache_ttl: typing.NotRequired[int] + nl_query: typing.NotRequired[bool] + nl_model_id: typing.NotRequired[str] + nl_query_debug: typing.NotRequired[bool] + + class SearchParameters( RequiredSearchParameters, QueryParameters, @@ -580,6 +597,7 @@ class SearchParameters( ResultsParameters, TypoToleranceParameters, CachingParameters, + NLLanguageParameters, ): """Parameters for searching documents.""" @@ -823,6 +841,38 @@ class Conversation(typing.TypedDict): query: str +class LLMResponse(typing.TypedDict): + """ + Schema for a raw LLM response. + + Attributes: + content (str): Content of the LLM response. + extraction_method (str): Extraction method of the LLM response (e.g. "regex"). + model (str): Model used to generate the response. + """ + + content: str + extraction_method: str + model: str + + +class ParsedNLQuery(typing.TypedDict): + """ + Schema for a parsed natural language query. + + Attributes: + parse_time_ms (int): Parse time in milliseconds. 
+ generated_params (SearchParameters): Generated parameters. + augmented_params (SearchParameters): Augmented parameters. + llm_response (LLMResponse): Raw LLM response. + """ + + parse_time_ms: int + generated_params: SearchParameters + augmented_params: SearchParameters + llm_response: typing.NotRequired[LLMResponse] + + class SearchResponse(typing.Generic[TDoc], typing.TypedDict): """ Schema for a search response. @@ -838,6 +888,7 @@ class SearchResponse(typing.Generic[TDoc], typing.TypedDict): hits (list[Hit[TDoc]]): List of hits in the search results. grouped_hits (list[GroupedHit[TDoc]]): List of grouped hits in the search results. conversation (Conversation): Conversation in the search results. + parsed_nl_query (ParsedNLQuery): Information about the natural language query """ facet_counts: typing.List[SearchResponseFacetCountSchema] @@ -850,6 +901,7 @@ class SearchResponse(typing.Generic[TDoc], typing.TypedDict): hits: typing.List[Hit[TDoc]] grouped_hits: typing.NotRequired[typing.List[GroupedHit[TDoc]]] conversation: typing.NotRequired[Conversation] + parsed_nl_query: typing.NotRequired[ParsedNLQuery] class DeleteSingleDocumentParameters(typing.TypedDict): diff --git a/src/typesense/types/nl_search_model.py b/src/typesense/types/nl_search_model.py new file mode 100644 index 0000000..5ad4570 --- /dev/null +++ b/src/typesense/types/nl_search_model.py @@ -0,0 +1,140 @@ +"""NLSearchModel types for Typesense Python Client.""" + +import sys + +if sys.version_info >= (3, 11): + import typing +else: + import typing_extensions as typing + + +class NLSearchModelBase(typing.TypedDict): + """ + Base schema with all possible fields for NL search models. + + Attributes: + model_name (str): Name of the LLM model. + api_key (str): The LLM service's API Key. + api_url (str): The API URL for the LLM service. + max_bytes (int): The maximum number of bytes to send to the LLM. + temperature (float): The temperature parameter for the LLM. 
+ system_prompt (str): The system prompt for the LLM. + top_p (float): The top_p parameter (Google-specific). + top_k (int): The top_k parameter (Google-specific). + stop_sequences (list[str]): Stop sequences for the LLM (Google-specific). + api_version (str): API version (Google-specific). + project_id (str): GCP project ID (GCP Vertex AI specific). + access_token (str): Access token for GCP (GCP Vertex AI specific). + refresh_token (str): Refresh token for GCP (GCP Vertex AI specific). + client_id (str): Client ID for GCP (GCP Vertex AI specific). + client_secret (str): Client secret for GCP (GCP Vertex AI specific). + region (str): Region for GCP (GCP Vertex AI specific). + max_output_tokens (int): Maximum output tokens (GCP Vertex AI specific). + account_id (str): Account ID (Cloudflare specific). + """ + + model_name: str + api_key: typing.NotRequired[str] + api_url: typing.NotRequired[str] + max_bytes: typing.NotRequired[int] + temperature: typing.NotRequired[float] + system_prompt: typing.NotRequired[str] + # Google-specific parameters + top_p: typing.NotRequired[float] + top_k: typing.NotRequired[int] + stop_sequences: typing.NotRequired[typing.List[str]] + api_version: typing.NotRequired[str] + # GCP Vertex AI specific + project_id: typing.NotRequired[str] + access_token: typing.NotRequired[str] + refresh_token: typing.NotRequired[str] + client_id: typing.NotRequired[str] + client_secret: typing.NotRequired[str] + region: typing.NotRequired[str] + max_output_tokens: typing.NotRequired[int] + # Cloudflare specific + account_id: typing.NotRequired[str] + + +class NLSearchModelCreateSchema(NLSearchModelBase): + """ + Schema for creating a new NL search model. + + Attributes: + id (str): The custom ID of the model. + """ + + id: typing.NotRequired[str] + + +class NLSearchModelUpdateSchema(typing.TypedDict): + """ + Schema for updating an NL search model; all fields are optional. + + Attributes: + model_name (str): Name of the LLM model. 
+ api_key (str): The LLM service's API Key. + api_url (str): The API URL for the LLM service. + max_bytes (int): The maximum number of bytes to send to the LLM. + temperature (float): The temperature parameter for the LLM. + system_prompt (str): The system prompt for the LLM. + top_p (float): The top_p parameter (Google-specific). + top_k (int): The top_k parameter (Google-specific). + stop_sequences (list[str]): Stop sequences for the LLM (Google-specific). + api_version (str): API version (Google-specific). + project_id (str): GCP project ID (GCP Vertex AI specific). + access_token (str): Access token for GCP (GCP Vertex AI specific). + refresh_token (str): Refresh token for GCP (GCP Vertex AI specific). + client_id (str): Client ID for GCP (GCP Vertex AI specific). + client_secret (str): Client secret for GCP (GCP Vertex AI specific). + region (str): Region for GCP (GCP Vertex AI specific). + max_output_tokens (int): Maximum output tokens (GCP Vertex AI specific). + account_id (str): Account ID (Cloudflare specific). + """ + + model_name: typing.NotRequired[str] + api_key: typing.NotRequired[str] + api_url: typing.NotRequired[str] + max_bytes: typing.NotRequired[int] + temperature: typing.NotRequired[float] + system_prompt: typing.NotRequired[str] + # Google-specific parameters + top_p: typing.NotRequired[float] + top_k: typing.NotRequired[int] + stop_sequences: typing.NotRequired[typing.List[str]] + api_version: typing.NotRequired[str] + # GCP Vertex AI specific + project_id: typing.NotRequired[str] + access_token: typing.NotRequired[str] + refresh_token: typing.NotRequired[str] + client_id: typing.NotRequired[str] + client_secret: typing.NotRequired[str] + region: typing.NotRequired[str] + max_output_tokens: typing.NotRequired[int] + # Cloudflare specific + account_id: typing.NotRequired[str] + + +class NLSearchModelDeleteSchema(typing.TypedDict): + """ + Schema for deleting an NL search model. + + Attributes: + id (str): The ID of the model. 
+ """ + + id: str + + +class NLSearchModelSchema(NLSearchModelBase): + """ + Schema for an NL search model. + + Attributes: + id (str): The ID of the model. + """ + + id: str + + +NLSearchModelsRetrieveSchema = typing.List[NLSearchModelSchema] diff --git a/tests/fixtures/nl_search_model_fixtures.py b/tests/fixtures/nl_search_model_fixtures.py new file mode 100644 index 0000000..4949b98 --- /dev/null +++ b/tests/fixtures/nl_search_model_fixtures.py @@ -0,0 +1,78 @@ +"""Fixtures for the NL search model tests.""" + +import os + +import pytest +import requests +from dotenv import load_dotenv + +from typesense.api_call import ApiCall +from typesense.nl_search_model import NLSearchModel +from typesense.nl_search_models import NLSearchModels + +load_dotenv() + + +@pytest.fixture(scope="function", name="delete_all_nl_search_models") +def clear_typesense_nl_search_models() -> None: + """Remove all nl_search_models from the Typesense server.""" + url = "http://localhost:8108/nl_search_models" + headers = {"X-TYPESENSE-API-KEY": "xyz"} + + # Get the list of models + response = requests.get(url, headers=headers, timeout=3) + response.raise_for_status() + + nl_search_models = response.json() + + # Delete each NL search model + for nl_search_model in nl_search_models: + model_id = nl_search_model.get("id") + delete_url = f"{url}/{model_id}" + delete_response = requests.delete(delete_url, headers=headers, timeout=3) + delete_response.raise_for_status() + + +@pytest.fixture(scope="function", name="create_nl_search_model") +def create_nl_search_model_fixture() -> str: + """Create an NL search model in the Typesense server.""" + url = "http://localhost:8108/nl_search_models" + headers = {"X-TYPESENSE-API-KEY": "xyz"} + nl_search_model_data = { + "api_key": os.environ.get("OPEN_AI_KEY", "test-api-key"), + "max_bytes": 16384, + "model_name": "openai/gpt-3.5-turbo", + "system_prompt": "This is a system prompt for NL search", + } + + response = requests.post( + url, + headers=headers, + 
json=nl_search_model_data, + timeout=3, + ) + + response.raise_for_status() + + model_id: str = response.json()["id"] + return model_id + + +@pytest.fixture(scope="function", name="fake_nl_search_models") +def fake_nl_search_models_fixture(fake_api_call: ApiCall) -> NLSearchModels: + """Return an NLSearchModels object with test values.""" + return NLSearchModels(fake_api_call) + + +@pytest.fixture(scope="function", name="fake_nl_search_model") +def fake_nl_search_model_fixture(fake_api_call: ApiCall) -> NLSearchModel: + """Return an NLSearchModel object with test values.""" + return NLSearchModel(fake_api_call, "nl_search_model_id") + + +@pytest.fixture(scope="function", name="actual_nl_search_models") +def actual_nl_search_models_fixture( + actual_api_call: ApiCall, +) -> NLSearchModels: + """Return an NLSearchModels object using a real API.""" + return NLSearchModels(actual_api_call) diff --git a/tests/nl_search_model_test.py b/tests/nl_search_model_test.py new file mode 100644 index 0000000..d47a536 --- /dev/null +++ b/tests/nl_search_model_test.py @@ -0,0 +1,99 @@ +"""Tests for the NLSearchModel class.""" + +from __future__ import annotations + +import pytest +from dotenv import load_dotenv + +from tests.utils.object_assertions import ( + assert_match_object, + assert_object_lists_match, + assert_to_contain_keys, +) +from typesense.api_call import ApiCall +from typesense.nl_search_model import NLSearchModel +from typesense.nl_search_models import NLSearchModels + +load_dotenv() + + +def test_init(fake_api_call: ApiCall) -> None: + """Test that the NLSearchModel object is initialized correctly.""" + nl_search_model = NLSearchModel( + fake_api_call, + "nl_search_model_id", + ) + + assert nl_search_model.model_id == "nl_search_model_id" + assert_match_object(nl_search_model.api_call, fake_api_call) + assert_object_lists_match( + nl_search_model.api_call.node_manager.nodes, + fake_api_call.node_manager.nodes, + ) + assert_match_object( + 
nl_search_model.api_call.config.nearest_node, + fake_api_call.config.nearest_node, + ) + assert ( + nl_search_model._endpoint_path # noqa: WPS437 + == "/nl_search_models/nl_search_model_id" + ) + + +@pytest.mark.open_ai +def test_actual_retrieve( + actual_nl_search_models: NLSearchModels, + delete_all_nl_search_models: None, + create_nl_search_model: str, +) -> None: + """Test it can retrieve an NL search model from Typesense Server.""" + response = actual_nl_search_models[create_nl_search_model].retrieve() + + assert_to_contain_keys( + response, + ["id", "model_name", "system_prompt", "max_bytes", "api_key"], + ) + assert response.get("id") == create_nl_search_model + + +@pytest.mark.open_ai +def test_actual_update( + actual_nl_search_models: NLSearchModels, + delete_all_nl_search_models: None, + create_nl_search_model: str, +) -> None: + """Test that it can update an NL search model from Typesense Server.""" + response = actual_nl_search_models[create_nl_search_model].update( + {"system_prompt": "This is a new system prompt for NL search"}, + ) + + assert_to_contain_keys( + response, + [ + "id", + "model_name", + "system_prompt", + "max_bytes", + "api_key", + ], + ) + + assert response.get("system_prompt") == "This is a new system prompt for NL search" + assert response.get("id") == create_nl_search_model + + +@pytest.mark.open_ai +def test_actual_delete( + actual_nl_search_models: NLSearchModels, + delete_all_nl_search_models: None, + create_nl_search_model: str, +) -> None: + """Test that it can delete an NL search model from Typesense Server.""" + response = actual_nl_search_models[create_nl_search_model].delete() + + assert_to_contain_keys( + response, + ["id"], + ) + + assert response.get("id") == create_nl_search_model diff --git a/tests/nl_search_models_test.py b/tests/nl_search_models_test.py new file mode 100644 index 0000000..1558b39 --- /dev/null +++ b/tests/nl_search_models_test.py @@ -0,0 +1,117 @@ +"""Tests for the NLSearchModels class.""" + +from 
__future__ import annotations + +import os +import sys + +import pytest + +if sys.version_info >= (3, 11): + import typing +else: + import typing_extensions as typing + +from tests.utils.object_assertions import ( + assert_match_object, + assert_object_lists_match, + assert_to_contain_keys, + assert_to_contain_object, +) +from typesense.api_call import ApiCall +from typesense.nl_search_models import NLSearchModels +from typesense.types.nl_search_model import NLSearchModelSchema + + +def test_init(fake_api_call: ApiCall) -> None: + """Test that the NLSearchModels object is initialized correctly.""" + nl_search_models = NLSearchModels(fake_api_call) + + assert_match_object(nl_search_models.api_call, fake_api_call) + assert_object_lists_match( + nl_search_models.api_call.node_manager.nodes, + fake_api_call.node_manager.nodes, + ) + assert_match_object( + nl_search_models.api_call.config.nearest_node, + fake_api_call.config.nearest_node, + ) + + assert not nl_search_models.nl_search_models + + +def test_get_missing_nl_search_model( + fake_nl_search_models: NLSearchModels, +) -> None: + """Test that the NLSearchModels object can get a missing nl_search_model.""" + nl_search_model = fake_nl_search_models["nl_search_model_id"] + + assert_match_object( + nl_search_model.api_call, + fake_nl_search_models.api_call, + ) + assert_object_lists_match( + nl_search_model.api_call.node_manager.nodes, + fake_nl_search_models.api_call.node_manager.nodes, + ) + assert_match_object( + nl_search_model.api_call.config.nearest_node, + fake_nl_search_models.api_call.config.nearest_node, + ) + assert ( + nl_search_model._endpoint_path # noqa: WPS437 + == "/nl_search_models/nl_search_model_id" + ) + + +def test_get_existing_nl_search_model( + fake_nl_search_models: NLSearchModels, +) -> None: + """Test that the NLSearchModels object can get an existing nl_search_model.""" + nl_search_model = fake_nl_search_models["nl_search_model_id"] + fetched_nl_search_model = 
fake_nl_search_models["nl_search_model_id"] + + assert len(fake_nl_search_models.nl_search_models) == 1 + + assert nl_search_model is fetched_nl_search_model + + +@pytest.mark.open_ai +def test_actual_create( + actual_nl_search_models: NLSearchModels, +) -> None: + """Test that it can create an NL search model on Typesense Server.""" + response = actual_nl_search_models.create( + { + "api_key": os.environ.get("OPEN_AI_KEY", "test-api-key"), + "max_bytes": 16384, + "model_name": "openai/gpt-3.5-turbo", + "system_prompt": "This is meant for testing purposes", + }, + ) + + assert_to_contain_keys( + response, + ["id", "api_key", "max_bytes", "model_name", "system_prompt"], + ) + + +@pytest.mark.open_ai +def test_actual_retrieve( + actual_nl_search_models: NLSearchModels, + delete_all_nl_search_models: None, + create_nl_search_model: str, +) -> None: + """Test that it can retrieve NL search models from Typesense Server.""" + response = actual_nl_search_models.retrieve() + assert len(response) == 1 + assert_to_contain_object( + response[0], + { + "id": create_nl_search_model, + }, + ) + assert_to_contain_keys( + response[0], + ["id", "api_key", "max_bytes", "model_name", "system_prompt"], + )