From 4e3aa805dc9af80160127c0eb0b5c75e30da4798 Mon Sep 17 00:00:00 2001 From: James Osmond Date: Tue, 16 Apr 2024 11:03:30 +0100 Subject: [PATCH] Removed empty import --- hackathon/streamlit/utils.py | 38 ++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/hackathon/streamlit/utils.py b/hackathon/streamlit/utils.py index a48328b..c423016 100644 --- a/hackathon/streamlit/utils.py +++ b/hackathon/streamlit/utils.py @@ -13,7 +13,9 @@ import plotly.graph_objects as go import streamlit as st from folium.plugins import StripePattern -from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings +from langchain.embeddings.sentence_transformer import ( + SentenceTransformerEmbeddings, +) from streamlit.runtime import get_instance from streamlit.runtime.scriptrunner import get_script_run_ctx from streamlit_folium import folium_static @@ -33,9 +35,10 @@ S3_LOADER_FILE_NAME, VECTOR_STORE_CONFIG, ) -from hackathon.llm.chain_config import ( -) +# from hackathon.llm.chain_config import ( + +# ) from hackathon.llm.llm import LLama2, SagemakerHostedLLM from hackathon.llm.llm_handler import LLMRunner from hackathon.loader.chunker import TextChunker @@ -74,10 +77,14 @@ def get_password(): # Create a Secrets Manager client session = boto3.session.Session() - client = session.client(service_name="secretsmanager", region_name=AWS_REGION) + client = session.client( + service_name="secretsmanager", region_name=AWS_REGION + ) try: - get_secret_value_response = client.get_secret_value(SecretId=secret_name) + get_secret_value_response = client.get_secret_value( + SecretId=secret_name + ) except Exception as e: # For a list of exceptions thrown, see # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html @@ -128,12 +135,15 @@ def initialise_llm_runner(): if VECTOR_STORE_CONFIG == "chroma": vector_store = ChromaStore( - embedding_function=st_embedder, collection_name=OPENSEARCH_INDEX_NAME + embedding_function=st_embedder, + collection_name=OPENSEARCH_INDEX_NAME, ) else: if "skills_os_client" not in st.session_state: st.session_state["skills_os_client"] = OpensearchClient( - OPENSEARCH_SKILLS_INDEX_NAME, OPENSEARCH_ENDPOINT_NAME, AWS_REGION + OPENSEARCH_SKILLS_INDEX_NAME, + OPENSEARCH_ENDPOINT_NAME, + AWS_REGION, ) if "vacancy_os_client" not in st.session_state: st.session_state["vacancy_os_client"] = OpensearchClient( @@ -141,7 +151,9 @@ def initialise_llm_runner(): ) vector_store = OpenSearchStore( - st_embedder, OPENSEARCH_INDEX_NAME, st.session_state["vacancy_os_client"] + st_embedder, + OPENSEARCH_INDEX_NAME, + st.session_state["vacancy_os_client"], ) if LLM_MODEL == "local_llm": @@ -156,11 +168,11 @@ def initialise_llm_runner(): llm_runner = LLMRunner( llm=llm, vectorstore=vector_store, - chain_configs=[ - ], + chain_configs=[], ) st.session_state["runner"] = llm_runner + def initialise_vector_store_loader(): if LLM_MODEL == "local_llm": st_embedder = SentenceTransformerEmbeddings( @@ -179,7 +191,9 @@ def initialise_vector_store_loader(): OPENSEARCH_INDEX_NAME, OPENSEARCH_ENDPOINT_NAME, AWS_REGION ) vector_store = OpensearchClientStore( - st_embedder, OPENSEARCH_INDEX_NAME, st.session_state["vacancy_os_client"] + st_embedder, + OPENSEARCH_INDEX_NAME, + st.session_state["vacancy_os_client"], ) if LOADER_CONFIG == "file_loader": @@ -197,9 +211,9 @@ def initialise_vector_store_loader(): chunker=TextChunker(chunk_size=1000, overlap=10), ) + def safe_literal_eval(x): try: return literal_eval(x) except (SyntaxError, ValueError): return None -