Skip to content

Commit

Permalink
more small changes
Browse files: browse the repository at this point in the history
  • Loading branch information
drf7 committed Jan 22, 2025
1 parent 1355db2 commit b6fe2cb
Showing 1 changed file with 25 additions and 18 deletions.
43 changes: 25 additions & 18 deletions nuclia/lib/nuclia_nua_chat.py → nuclia/lib/nua_chat.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,26 +4,24 @@
from base64 import b64decode

try:
from litellm import (
CustomLLM,
ModelResponse,
Choices,
Message,
)

from litellm import CustomLLM
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.utils import ModelResponse, Choices, Message
except ImportError:
raise ImportError(
"The 'litellm' library is required to use this functionality. "
"Install it with: pip install nuclia[litellm]"
)

# Nuclia (sync & async)
from nuclia.lib.nua import NuaClient, AsyncNuaClient
from nuclia.sdk.predict import NucliaPredict, AsyncNucliaPredict
from nuclia.lib.nua import NuaClient
from nuclia.sdk.predict import NucliaPredict
from nuclia.lib.nua_responses import ChatModel, UserPrompt
from nuclia_models.predict.generative_responses import (
GenerativeFullResponse,
)
from typing import Callable, Optional, Union

import httpx


class NucliaNuaChat(CustomLLM):
Expand All @@ -47,13 +45,6 @@ def __init__(self, token: str):
)
self.predict_sync = NucliaPredict()

self.nc_async = AsyncNuaClient(
region=self.region_base_url,
token=self.token,
account="", # Not needed for current implementation, required by the client
)
self.predict_async = AsyncNucliaPredict()

@staticmethod
def _parse_token(token: str):
parts = token.split(".")
Expand Down Expand Up @@ -85,7 +76,23 @@ def _process_messages(self, messages: list[dict[str, str]]) -> tuple[str, str]:
return formatted_system, formatted_user

def completion(
self, *args, model: str, messages: list[dict[str, str]], **kwargs
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[HTTPHandler] = None,
) -> ModelResponse:
if not self.predict_sync or not self.nc_sync:
raise RuntimeError("Sync clients not initialized.")
Expand Down

0 comments on commit b6fe2cb

Please sign in to comment.