Commit 6fc5b8a: Reformatted and cleaned up deprecated code

AlexejPenner committed Dec 11, 2024 (1 parent: a5c8be6)

Showing 24 changed files with 321 additions and 275 deletions.
14 changes: 6 additions & 8 deletions llm-complete-guide/gh_action_rag.py
@@ -21,12 +21,10 @@

import click
import yaml
from zenml.enums import PluginSubType

from pipelines.llm_index_and_evaluate import llm_index_and_evaluate
from zenml.client import Client
from zenml import Model
from zenml.exceptions import ZenKeyError
from zenml.client import Client
from zenml.enums import PluginSubType


@click.command(
@@ -89,7 +87,7 @@ def main(
zenml_model_name: Optional[str] = "zenml-docs-qa-rag",
zenml_model_version: Optional[str] = None,
):
"""
"""
Executes the pipeline to train a basic RAG model.
Args:
@@ -108,14 +106,14 @@ def main(
config = yaml.safe_load(file)

# Read the model version from a file in the root of the repo
# called "ZENML_VERSION.txt".
# called "ZENML_VERSION.txt".
if zenml_model_version == "staging":
postfix = "-rc0"
elif zenml_model_version == "production":
postfix = ""
else:
postfix = "-dev"

if Path("ZENML_VERSION.txt").exists():
with open("ZENML_VERSION.txt", "r") as file:
zenml_model_version = file.read().strip()
@@ -177,7 +175,7 @@ def main(
service_account_id=service_account_id,
auth_window=0,
flavor="builtin",
action_type=PluginSubType.PIPELINE_RUN
action_type=PluginSubType.PIPELINE_RUN,
).id
client.create_trigger(
name="Production Trigger LLM-Complete",
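For reference, the action-and-trigger registration that this file's last hunk reformats follows the pattern sketched below. Only the create_action keyword arguments and the trigger name are visible in the diff; the service-account setup, the configuration payload, and the action_id wiring are assumptions about ZenML's Client API, not code from this commit.

from zenml.client import Client
from zenml.enums import PluginSubType

client = Client()

# Assumed setup: a service account whose credentials the trigger runs under.
service_account_id = client.create_service_account(
    name="github-action-sa",  # hypothetical name
).id

# Mirrors the kwargs visible in the hunk above; the configuration payload
# is a placeholder, since the diff does not show it.
action_id = client.create_action(
    name="llm-complete-rag-action",  # hypothetical name
    flavor="builtin",
    action_type=PluginSubType.PIPELINE_RUN,
    configuration={},  # assumed shape
    service_account_id=service_account_id,
    auth_window=0,
).id

# The diff shows this call with further arguments cut off; action_id is an
# assumed keyword linking the trigger to the action created above.
client.create_trigger(
    name="Production Trigger LLM-Complete",
    action_id=action_id,
)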
4 changes: 2 additions & 2 deletions llm-complete-guide/pipelines/__init__.py
@@ -19,7 +19,7 @@
from pipelines.generate_chunk_questions import generate_chunk_questions
from pipelines.llm_basic_rag import llm_basic_rag
from pipelines.llm_eval import llm_eval
from pipelines.rag_deployment import rag_deployment
from pipelines.llm_index_and_evaluate import llm_index_and_evaluate
from pipelines.local_deployment import local_deployment
from pipelines.prod_deployment import production_deployment
from pipelines.prod_deployment import production_deployment
from pipelines.rag_deployment import rag_deployment
1 change: 0 additions & 1 deletion llm-complete-guide/pipelines/finetune_embeddings.py
@@ -12,7 +12,6 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

from constants import EMBEDDINGS_MODEL_NAME_ZENML
from steps.finetune_embeddings import (
evaluate_base_model,
evaluate_finetuned_model,
1 change: 0 additions & 1 deletion llm-complete-guide/pipelines/llm_basic_rag.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from litellm import config_path

from steps.populate_index import (
generate_embeddings,
3 changes: 2 additions & 1 deletion llm-complete-guide/pipelines/llm_index_and_evaluate.py
@@ -15,9 +15,10 @@
# limitations under the License.
#

from pipelines import llm_basic_rag, llm_eval
from zenml import pipeline

from pipelines import llm_basic_rag, llm_eval


@pipeline
def llm_index_and_evaluate() -> None:
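The file this hunk touches composes two existing pipelines into one. The imports and signature appear in the diff above; the two calls in the body below are an assumption about what the elided lines do, shown only to illustrate the pipeline-of-pipelines pattern.

from pipelines import llm_basic_rag, llm_eval
from zenml import pipeline


@pipeline
def llm_index_and_evaluate() -> None:
    # Assumed body: index the documents first, then run the evaluation.
    llm_basic_rag()
    llm_eval()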
1 change: 0 additions & 1 deletion llm-complete-guide/pipelines/local_deployment.py
@@ -1,6 +1,5 @@
from steps.bento_builder import bento_builder
from steps.bento_deployment import bento_deployment
from steps.visualize_chat import create_chat_interface
from zenml import pipeline


5 changes: 2 additions & 3 deletions llm-complete-guide/pipelines/prod_deployment.py
@@ -22,12 +22,11 @@


@pipeline(enable_cache=False)
def production_deployment(
):
def production_deployment():
"""Model deployment pipeline.
This pipeline deploys a trained model for future inference.
"""
bento_model_image = bento_dockerizer()
deployment_info = k8s_deployment(bento_model_image)
create_chat_interface(deployment_info)
create_chat_interface(deployment_info)
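Invoking the reformatted pipeline above would look roughly like the sketch below, following the with_options(config_path=...) pattern that run.py uses elsewhere in this commit; the YAML path is a placeholder.

from pipelines.prod_deployment import production_deployment

# Hypothetical entry point; "configs/production.yaml" is a placeholder path.
production_deployment.with_options(config_path="configs/production.yaml")()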
20 changes: 11 additions & 9 deletions llm-complete-guide/run.py
@@ -47,14 +47,14 @@
generate_synthetic_data,
llm_basic_rag,
llm_eval,
rag_deployment,
llm_index_and_evaluate,
local_deployment,
production_deployment,
rag_deployment,
)
from structures import Document
from zenml.materializers.materializer_registry import materializer_registry
from zenml import Model
from zenml.materializers.materializer_registry import materializer_registry

logger = get_logger(__name__)

@@ -150,7 +150,7 @@
"env",
default="local",
help="The environment to use for the completion.",
)
)
def main(
pipeline: str,
query_text: Optional[str] = None,
@@ -186,9 +186,9 @@ def main(
}
},
}

# Read the model version from a file in the root of the repo
# called "ZENML_VERSION.txt".
# called "ZENML_VERSION.txt".
if zenml_model_version == "staging":
postfix = "-rc0"
elif zenml_model_version == "production":
@@ -200,8 +200,8 @@ def main(
with open("ZENML_VERSION.txt", "r") as file:
zenml_version = file.read().strip()
zenml_version += postfix
#zenml_model_version = file.read().strip()
#zenml_model_version += postfix
# zenml_model_version = file.read().strip()
# zenml_model_version += postfix
else:
raise RuntimeError(
"No model version file found. Please create a file called ZENML_VERSION.txt in the root of the repo with the model version."
@@ -294,7 +294,9 @@ def main(

elif pipeline == "embeddings":
finetune_embeddings.with_options(
model=zenml_model, config_path=config_path, **embeddings_finetune_args
model=zenml_model,
config_path=config_path,
**embeddings_finetune_args,
)()

elif pipeline == "chunks":
@@ -309,4 +311,4 @@ def main(
materializer_registry.register_materializer_type(
Document, DocumentMaterializer
)
main()
main()
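The stage-to-postfix logic reformatted in this file is duplicated almost verbatim in gh_action_rag.py above. A shared helper would keep the two copies in sync; the sketch below is a suggestion rather than code from the commit, though the postfix values and the error message come straight from the diff.

from pathlib import Path
from typing import Optional


def resolve_model_version(stage: Optional[str]) -> str:
    """Read the base version from ZENML_VERSION.txt and append the stage
    postfix used in this repo: "-rc0" for staging, nothing for production,
    "-dev" for everything else."""
    postfix = {"staging": "-rc0", "production": ""}.get(stage, "-dev")
    version_file = Path("ZENML_VERSION.txt")
    if not version_file.exists():
        raise RuntimeError(
            "No model version file found. Please create a file called "
            "ZENML_VERSION.txt in the root of the repo with the model version."
        )
    return version_file.read_text().strip() + postfix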
104 changes: 64 additions & 40 deletions llm-complete-guide/service.py
@@ -1,14 +1,11 @@
import asyncio
from typing import Any, AsyncGenerator, Dict
from typing import AsyncGenerator

import bentoml
import litellm
import numpy as np
from constants import (
EMBEDDINGS_MODEL_ID_FINE_TUNED,
MODEL_NAME_MAP,
OPENAI_MODEL,
SECRET_NAME,
SECRET_NAME_ELASTICSEARCH,
)
from elasticsearch import Elasticsearch
@@ -29,30 +26,43 @@
http={
"cors": {
"enabled": True,
"access_control_allow_origins": ["https://cloud.zenml.io"], # Add your allowed origins
"access_control_allow_methods": ["GET", "OPTIONS", "POST", "HEAD", "PUT"],
"access_control_allow_origins": [
"https://cloud.zenml.io"
], # Add your allowed origins
"access_control_allow_methods": [
"GET",
"OPTIONS",
"POST",
"HEAD",
"PUT",
],
"access_control_allow_credentials": True,
"access_control_allow_headers": ["*"],
# "access_control_allow_origin_regex": "https://.*\.my_org\.com", # Optional regex
"access_control_max_age": 1200,
"access_control_expose_headers": ["Content-Length"],
}
}
},
)
class RAGService:
"""RAG service for generating responses using LLM and RAG."""

def __init__(self):
"""Initialize the RAG service."""
# Initialize embeddings model
self.embeddings_model = SentenceTransformer(EMBEDDINGS_MODEL)

# Initialize reranker
self.reranker = Reranker("flashrank")

# Initialize Elasticsearch client
client = Client()
es_host = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values["elasticsearch_host"]
es_api_key = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values["elasticsearch_api_key"]
es_host = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values[
"elasticsearch_host"
]
es_api_key = client.get_secret(
SECRET_NAME_ELASTICSEARCH
).secret_values["elasticsearch_api_key"]
self.es_client = Elasticsearch(es_host, api_key=es_api_key)

def get_embeddings(self, text: str) -> np.ndarray:
@@ -62,40 +72,52 @@ def get_embeddings(self, text: str) -> np.ndarray:
embeddings = embeddings[0]
return embeddings

def get_similar_docs(self, query_embedding: np.ndarray, n: int = 20) -> list:
def get_similar_docs(
self, query_embedding: np.ndarray, n: int = 20
) -> list:
"""Get similar documents for the given query embedding."""
if query_embedding.ndim == 2:
query_embedding = query_embedding[0]

response = self.es_client.search(index="zenml_docs", knn={
"field": "embedding",
"query_vector": query_embedding.tolist(),
"num_candidates": 50,
"k": n
})


response = self.es_client.search(
index="zenml_docs",
knn={
"field": "embedding",
"query_vector": query_embedding.tolist(),
"num_candidates": 50,
"k": n,
},
)

docs = []
for hit in response["hits"]["hits"]:
docs.append({
"content": hit["_source"]["content"],
"url": hit["_source"]["url"],
"parent_section": hit["_source"]["parent_section"]
})
docs.append(
{
"content": hit["_source"]["content"],
"url": hit["_source"]["url"],
"parent_section": hit["_source"]["parent_section"],
}
)
return docs

def rerank_documents(self, query: str, documents: list) -> list:
"""Rerank documents using the reranker."""
docs_texts = [f"{doc['content']} PARENT SECTION: {doc['parent_section']}" for doc in documents]
docs_texts = [
f"{doc['content']} PARENT SECTION: {doc['parent_section']}"
for doc in documents
]
results = self.reranker.rank(query=query, docs=docs_texts)

reranked_docs = []
for result in results.results:
index_val = result.doc_id
doc = documents[index_val]
reranked_docs.append((result.text, doc["url"]))
return reranked_docs[:5]

async def get_completion(self, messages: list, model: str, temperature: float, max_tokens: int) -> AsyncGenerator[str, None]:
async def get_completion(
self, messages: list, model: str, temperature: float, max_tokens: int
) -> AsyncGenerator[str, None]:
"""Handle the completion request and streaming response."""
try:
response = await litellm.acompletion(
@@ -104,9 +126,9 @@ async def get_completion(self, messages: list, model: str, temperature: float, m
temperature=temperature,
max_tokens=max_tokens,
api_key=get_openai_api_key(),
stream=True
stream=True,
)

async for chunk in response:
if chunk.choices and chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
@@ -124,16 +146,16 @@ async def generate(
try:
# Get embeddings for query
query_embedding = self.get_embeddings(query)

# Retrieve similar documents
similar_docs = self.get_similar_docs(query_embedding, n=20)

# Rerank documents
reranked_docs = self.rerank_documents(query, similar_docs)

# Prepare context from reranked documents
context = "\n\n".join([doc[0] for doc in reranked_docs])

# Prepare system message
system_message = """
You are a friendly chatbot. \
@@ -149,15 +171,17 @@ async def generate(
{"role": "system", "content": system_message},
{"role": "user", "content": query},
{
"role": "assistant",
"content": f"Please use the following relevant ZenML documentation to answer the query: \n{context}"
}
"role": "assistant",
"content": f"Please use the following relevant ZenML documentation to answer the query: \n{context}",
},
]

# Get completion from LLM using the new async method
model = MODEL_NAME_MAP.get(OPENAI_MODEL, OPENAI_MODEL)
async for chunk in self.get_completion(messages, model, temperature, max_tokens):
async for chunk in self.get_completion(
messages, model, temperature, max_tokens
):
yield chunk

except Exception as e:
yield f"Error occurred: {str(e)}"
yield f"Error occurred: {str(e)}"
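To smoke-test the reformatted service end to end, a streaming call through BentoML's HTTP client looks roughly like the sketch below. It assumes the service is already running locally (for example via `bentoml serve service:RAGService`); the port and query are placeholders, and the keyword arguments mirror the generate signature shown in the diff rather than anything guaranteed by this commit.

import bentoml

# Hypothetical smoke test against a locally served RAGService.
with bentoml.SyncHTTPClient("http://localhost:3000") as client:
    for chunk in client.generate(
        query="How do I create a ZenML pipeline?",
        temperature=0.4,
        max_tokens=1000,
    ):
        print(chunk, end="", flush=True)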