diff --git a/nuclia/lib/nuclia_nua_chat.py b/nuclia/lib/nua_chat.py similarity index 78% rename from nuclia/lib/nuclia_nua_chat.py rename to nuclia/lib/nua_chat.py index 5333d0e..d55fc05 100644 --- a/nuclia/lib/nuclia_nua_chat.py +++ b/nuclia/lib/nua_chat.py @@ -4,26 +4,24 @@ from base64 import b64decode try: - from litellm import ( - CustomLLM, - ModelResponse, - Choices, - Message, - ) - + from litellm import CustomLLM + from litellm.llms.custom_httpx.http_handler import HTTPHandler + from litellm.utils import ModelResponse, Choices, Message except ImportError: raise ImportError( "The 'litellm' library is required to use this functionality. " "Install it with: pip install nuclia[litellm]" ) -# Nuclia (sync & async) -from nuclia.lib.nua import NuaClient, AsyncNuaClient -from nuclia.sdk.predict import NucliaPredict, AsyncNucliaPredict +from nuclia.lib.nua import NuaClient +from nuclia.sdk.predict import NucliaPredict from nuclia.lib.nua_responses import ChatModel, UserPrompt from nuclia_models.predict.generative_responses import ( GenerativeFullResponse, ) +from typing import Callable, Optional, Union + +import httpx class NucliaNuaChat(CustomLLM): @@ -47,13 +45,6 @@ def __init__(self, token: str): ) self.predict_sync = NucliaPredict() - self.nc_async = AsyncNuaClient( - region=self.region_base_url, - token=self.token, - account="", # Not needed for current implementation, required by the client - ) - self.predict_async = AsyncNucliaPredict() - @staticmethod def _parse_token(token: str): parts = token.split(".") @@ -85,7 +76,23 @@ def _process_messages(self, messages: list[dict[str, str]]) -> tuple[str, str]: return formatted_system, formatted_user def completion( - self, *args, model: str, messages: list[dict[str, str]], **kwargs + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + optional_params: dict, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers={}, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[HTTPHandler] = None, ) -> ModelResponse: if not self.predict_sync or not self.nc_sync: raise RuntimeError("Sync clients not initialized.")