Skip to content

Commit

Permalink
removing the tokens concept.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfmezger committed Jun 2, 2024
1 parent 8ec45ea commit 1328380
Show file tree
Hide file tree
Showing 13 changed files with 33 additions and 106 deletions.
3 changes: 2 additions & 1 deletion agent/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ async def embedd_text(embedding: EmbeddTextRequest, llm_backend: LLMBackend) ->
"""
logger.info("Embedding Text")

# TODO: REWORK THE TOKEN
token = validate_token(token=llm_backend.token, llm_backend=llm_backend, aleph_alpha_key=ALEPH_ALPHA_API_KEY, openai_key=OPENAI_API_KEY)

service = LLMContext(LLMStrategyFactory.get_strategy(strategy_type=llm_backend.llm_provider, token=token, collection_name=llm_backend.collection_name))
Expand Down Expand Up @@ -195,7 +197,6 @@ def search(search: SearchParams, llm_backend: LLMBackend) -> list[SearchResponse
"""
logger.info("Searching for Documents")
llm_backend.token = validate_token(token=llm_backend.token, llm_backend=llm_backend, aleph_alpha_key=ALEPH_ALPHA_API_KEY, openai_key=OPENAI_API_KEY)

service = LLMContext(LLMStrategyFactory.get_strategy(strategy_type=llm_backend.llm_provider, token=llm_backend.token, collection_name=llm_backend.collection_name))

Expand Down
3 changes: 1 addition & 2 deletions agent/backend/LLMBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ class LLMBase(ABC):
"""The LLM Base Strategy."""

@abstractmethod
def __init__(self, token: str | None, collection_name: str | None) -> None:
def __init__(self, collection_name: str | None) -> None:
"""Init the LLM Base."""
self.token = token
self.collection_name = collection_name

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions agent/backend/LLMStrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def __init__(self, llm: LLMBase) -> None:
"""Init the Context."""
self.llm = llm

def change_strategy(self, strategy_type: str, token: str, collection_name: str) -> None:
def change_strategy(self, strategy_type: str, collection_name: str) -> None:
"""Changes the strategy using the Factory."""
self.llm = LLMStrategyFactory.get_strategy(strategy_type=strategy_type, token=token, collection_name=collection_name)
self.llm = LLMStrategyFactory.get_strategy(strategy_type=strategy_type, collection_name=collection_name)

def search(self, search: SearchParams) -> list:
"""Wrapper for the search."""
Expand Down
12 changes: 4 additions & 8 deletions agent/backend/aleph_alpha_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,10 @@ class AlephAlphaService(LLMBase):
"""Aleph Alpha Strategy implementation."""

@load_config(location="config/main.yml")
def __init__(self, cfg: DictConfig, collection_name: str, token: str) -> None:
def __init__(self, cfg: DictConfig, collection_name: str) -> None:
"""Initialize the Aleph Alpha Service."""
super().__init__(token=token, collection_name=collection_name)
super().__init__(collection_name=collection_name)
"""Initialize the Aleph Alpha Service."""
if token:
os.environ["ALEPH_ALPHA_API_KEY"] = token

self.cfg = cfg

if collection_name:
Expand All @@ -61,7 +58,6 @@ def __init__(self, cfg: DictConfig, collection_name: str, token: str) -> None:

embedding = AlephAlphaAsymmetricSemanticEmbedding(
model=self.cfg.aleph_alpha_embeddings.model_name,
aleph_alpha_api_key=self.aleph_alpha_token,
normalize=self.cfg.aleph_alpha_embeddings.normalize,
compress_to_size=self.cfg.aleph_alpha_embeddings.compress_to_size,
)
Expand All @@ -72,7 +68,7 @@ def __init__(self, cfg: DictConfig, collection_name: str, token: str) -> None:

def get_tokenizer(self) -> None:
"""Initialize the tokenizer."""
client = Client(token=self.aleph_alpha_token)
client = Client(token=os.getenv("ALEPH_ALPHA_API_KEY"))
self.tokenizer = client.tokenizer("luminous-base")

def count_tokens(self, text: str) -> int:
Expand Down Expand Up @@ -115,7 +111,7 @@ def summarize_text(self, text: str) -> str:
"""
# TODO: rewrite because deprecated.
client = Client(token=self.aleph_alpha_token)
client = Client(token=os.getenv("ALEPH_ALPHA_API_KEY"))
document = Document.from_text(text=text)
request = SummarizationRequest(document=document)
response = client.summarize(request=request)
Expand Down
8 changes: 2 additions & 6 deletions agent/backend/cohere_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Cohere Backend."""
import os

from dotenv import load_dotenv
from langchain_cohere import ChatCohere, CohereEmbeddings
Expand Down Expand Up @@ -30,12 +29,9 @@ class CohereService(LLMBase):
"""Wrapper for cohere llms."""

@load_config(location="config/main.yml")
def __init__(self, cfg: DictConfig, collection_name: str | None, token: str | None) -> None:
def __init__(self, cfg: DictConfig, collection_name: str | None) -> None:
"""Init the Cohere Service."""
super().__init__(token=token, collection_name=collection_name)

if token:
os.environ["COHERE_API_KEY"] = token
super().__init__(collection_name=collection_name)

self.cfg = cfg

Expand Down
10 changes: 3 additions & 7 deletions agent/backend/gpt4all_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,11 @@ class GPT4AllService(LLMBase):
"""GPT4ALL Backend Service."""

@load_config(location="config/main.yml")
def __init__(self, cfg: DictConfig, collection_name: str, token: str | None) -> None:
def __init__(self, cfg: DictConfig, collection_name: str) -> None:
"""Init the GPT4ALL Service."""
self.cfg = cfg
self.token = token
super().__init__(collection_name=collection_name)

if collection_name:
self.collection_name = collection_name
else:
self.collection_name = self.cfg.qdrant.collection_name_gpt4all
self.cfg = cfg

embedding = GPT4AllEmbeddings(model_name="nomic-embed-text-v1.5.f16.gguf")

Expand Down
6 changes: 3 additions & 3 deletions agent/backend/ollama_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class OllamaService(LLMBase):
"""Wrapper for Ollama llms."""

@load_config(location="config/main.yml")
def __init__(self, cfg: DictConfig, collection_name: str | None, token: str | None) -> None:
def __init__(self, cfg: DictConfig, collection_name: str | None) -> None:
"""Init the Ollama Service."""
super().__init__(token=token, collection_name=collection_name)
super().__init__(collection_name=collection_name)

self.cfg = cfg

Expand Down Expand Up @@ -122,7 +122,7 @@ def summarize_text(self, text: str) -> str:
if __name__ == "__main__":
query = "Was ist Attention?"

Ollama_service = OllamaService(collection_name="", token="")
Ollama_service = OllamaService(collection_name="")

Ollama_service.embed_documents(directory="tests/resources/")

Expand Down
7 changes: 2 additions & 5 deletions agent/backend/open_ai_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,11 @@ class OpenAIService(LLMBase):
"""OpenAI Backend Service."""

@load_config(location="config/main.yml")
def __init__(self, cfg: DictConfig, collection_name: str, token: str) -> None:
def __init__(self, cfg: DictConfig, collection_name: str) -> None:
"""Init the OpenAI Service."""
super().__init__(token=token, collection_name=collection_name)
super().__init__(collection_name=collection_name)

"""Openai Service."""
if token:
os.environ["COHERE_API_KEY"] = token

self.cfg = cfg

if collection_name:
Expand Down
58 changes: 0 additions & 58 deletions agent/utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from loguru import logger

from agent.data_model.internal_model import RetrievalResults
from agent.data_model.request_data_model import LLMProvider

# add new languages to detect here
languages = [Language.ENGLISH, Language.GERMAN]
Expand Down Expand Up @@ -132,63 +131,6 @@ def load_prompt_template(prompt_name: str, task: str) -> PromptTemplate:
return prompt_template


def get_token(token: str | None, llm_provider: str | LLMProvider | None, aleph_alpha_key: str | None, openai_key: str | None) -> str:
    """Resolve the API token for a provider, preferring the token sent with the request.

    Args:
    ----
        token (str | None): Token from the REST request. ``None``, ``""`` and the literal
            ``"string"`` are all treated as "no token supplied".
        llm_provider (str | LLMProvider | None): Provider the token is needed for. A string is
            converted with ``LLMProvider.normalize`` before it is compared.
        aleph_alpha_key (str | None): Fallback key used when the provider is Aleph Alpha.
        openai_key (str | None): Fallback key used for every other provider.

    Returns:
    -------
        str: The request token if one was supplied, otherwise the matching fallback key.

    Raises:
    ------
        ValueError: If neither a request token nor the matching fallback key is available.

    """
    # Normalise first so a string provider is converted even when a request token is present.
    if isinstance(llm_provider, str):
        llm_provider = LLMProvider.normalize(llm_provider)

    # NOTE(review): "string" is presumably the placeholder an API client sends when the field
    # is left at its schema default -- confirm against the request model.
    if token in ("string", ""):
        token = None

    if token:
        return token

    # No usable request token: fall back to the configured key for the provider.
    env_token = aleph_alpha_key if llm_provider == LLMProvider.ALEPH_ALPHA else openai_key
    if not env_token:
        msg = "No token provided."
        raise ValueError(msg)
    return env_token


def validate_token(token: str | None, llm_backend: str | LLMProvider, aleph_alpha_key: str | None, openai_key: str | None) -> str:
    """Return the token to use for a request, or the ``"gpt4all"`` placeholder.

    Args:
    ----
        token (str | None): Token from the request.
        llm_backend (str | LLMProvider): The provider itself, or a request object that exposes
            the provider through an ``llm_provider`` attribute.
        aleph_alpha_key (str | None): Key from the .env file.
        openai_key (str | None): Key from the .env file.

    Returns:
    -------
        str: ``"gpt4all"`` when the provider is gpt4all, otherwise the result of ``get_token``.

    Raises:
    ------
        ValueError: Raised by ``get_token`` when no token is available for the provider.

    """
    # Accept both call shapes: the provider directly, or an object carrying it as `llm_provider`.
    provider = getattr(llm_backend, "llm_provider", llm_backend)

    # Compare on `.value` when the provider has one, so an enum member and a plain string are
    # handled alike. Assumes the member's value is the provider string -- TODO confirm.
    if getattr(provider, "value", provider) == "gpt4all":
        return "gpt4all"

    return get_token(token, provider, aleph_alpha_key, openai_key)


def convert_qdrant_result_to_retrieval_results(docs: list) -> list[RetrievalResults]:
"""Converts the Qdrant result to a list of tuples.
Expand Down
2 changes: 1 addition & 1 deletion config/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ qdrant:
# COHERE CONFIG
cohere_embeddings:
embedding_model_name: "embed-multilingual-v3.0"
size: 2048
size: 1024

# OLLAMA CONFIG
ollama_embeddings:
Expand Down
6 changes: 6 additions & 0 deletions prompts/chat/aleph_alpha_chat.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
### Instruction:
{{question}} If there's no answer, say "NO_ANSWER_IN_TEXT".
### Input:
Text:{{context}}
Question:{{question}}
### Response:
6 changes: 0 additions & 6 deletions prompts/en/aleph_alpha_qa.j2

This file was deleted.

14 changes: 7 additions & 7 deletions tests/unit_tests/test_utility.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Tests for the utility functions."""
from agent.utils.utility import generate_prompt, validate_token
from agent.utils.utility import generate_prompt, set_token


def test_generate_prompt() -> None:
Expand Down Expand Up @@ -37,30 +37,30 @@ def test_combine_text_from_list() -> None:

def test_validate_token() -> None:
"""Test that validate_token returns the correct token."""
token = validate_token(token="example_token", llm_backend="openai", aleph_alpha_key="example_key_a", openai_key="example_key_o")
token = set_token(token="example_token", llm_backend="openai", aleph_alpha_key="example_key_a", openai_key="example_key_o")

assert token == "example_token"

token = validate_token(token="", llm_backend="aleph-alpha", aleph_alpha_key="example_key_a", openai_key="example_key_o")
token = set_token(token="", llm_backend="aleph-alpha", aleph_alpha_key="example_key_a", openai_key="example_key_o")

assert token == "example_key_a"

token = validate_token(token="", llm_backend="openai", aleph_alpha_key="example_key_a", openai_key="example_key_o")
token = set_token(token="", llm_backend="openai", aleph_alpha_key="example_key_a", openai_key="example_key_o")

assert token == "example_key_o"

token = validate_token(token=None, llm_backend="openai", aleph_alpha_key="example_key_a", openai_key="example_key_o")
token = set_token(token=None, llm_backend="openai", aleph_alpha_key="example_key_a", openai_key="example_key_o")

assert token == "example_key_o"

token = validate_token(token="", llm_backend="gpt4all", aleph_alpha_key="example_key_a", openai_key="example_key_o")
token = set_token(token="", llm_backend="gpt4all", aleph_alpha_key="example_key_a", openai_key="example_key_o")

assert token == "gpt4all"

from agent.data_model.request_data_model import LLMProvider

backend = LLMProvider.ALEPH_ALPHA

token = validate_token(token="", llm_backend=backend, aleph_alpha_key="example_key_a", openai_key="example_key_o")
token = set_token(token="", llm_backend=backend, aleph_alpha_key="example_key_a", openai_key="example_key_o")

assert token == "example_key_a"

0 comments on commit 1328380

Please sign in to comment.