Skip to content

Commit

Permalink
Removed empty import
Browse files Browse the repository at this point in the history
  • Loading branch information
James-Osmond committed Apr 16, 2024
1 parent 0a97b79 commit 4e3aa80
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions hackathon/streamlit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import plotly.graph_objects as go
import streamlit as st
from folium.plugins import StripePattern
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.embeddings.sentence_transformer import (
SentenceTransformerEmbeddings,
)
from streamlit.runtime import get_instance
from streamlit.runtime.scriptrunner import get_script_run_ctx
from streamlit_folium import folium_static
Expand All @@ -33,9 +35,10 @@
S3_LOADER_FILE_NAME,
VECTOR_STORE_CONFIG,
)
from hackathon.llm.chain_config import (

)
# from hackathon.llm.chain_config import (

# )
from hackathon.llm.llm import LLama2, SagemakerHostedLLM
from hackathon.llm.llm_handler import LLMRunner
from hackathon.loader.chunker import TextChunker
Expand Down Expand Up @@ -74,10 +77,14 @@ def get_password():

# Create a Secrets Manager client
session = boto3.session.Session()
client = session.client(service_name="secretsmanager", region_name=AWS_REGION)
client = session.client(
service_name="secretsmanager", region_name=AWS_REGION
)

try:
get_secret_value_response = client.get_secret_value(SecretId=secret_name)
get_secret_value_response = client.get_secret_value(
SecretId=secret_name
)
except Exception as e:
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
Expand Down Expand Up @@ -128,20 +135,25 @@ def initialise_llm_runner():

if VECTOR_STORE_CONFIG == "chroma":
vector_store = ChromaStore(
embedding_function=st_embedder, collection_name=OPENSEARCH_INDEX_NAME
embedding_function=st_embedder,
collection_name=OPENSEARCH_INDEX_NAME,
)
else:
if "skills_os_client" not in st.session_state:
st.session_state["skills_os_client"] = OpensearchClient(
OPENSEARCH_SKILLS_INDEX_NAME, OPENSEARCH_ENDPOINT_NAME, AWS_REGION
OPENSEARCH_SKILLS_INDEX_NAME,
OPENSEARCH_ENDPOINT_NAME,
AWS_REGION,
)
if "vacancy_os_client" not in st.session_state:
st.session_state["vacancy_os_client"] = OpensearchClient(
OPENSEARCH_INDEX_NAME, OPENSEARCH_ENDPOINT_NAME, AWS_REGION
)

vector_store = OpenSearchStore(
st_embedder, OPENSEARCH_INDEX_NAME, st.session_state["vacancy_os_client"]
st_embedder,
OPENSEARCH_INDEX_NAME,
st.session_state["vacancy_os_client"],
)

if LLM_MODEL == "local_llm":
Expand All @@ -156,11 +168,11 @@ def initialise_llm_runner():
llm_runner = LLMRunner(
llm=llm,
vectorstore=vector_store,
chain_configs=[
],
chain_configs=[],
)
st.session_state["runner"] = llm_runner


def initialise_vector_store_loader():
if LLM_MODEL == "local_llm":
st_embedder = SentenceTransformerEmbeddings(
Expand All @@ -179,7 +191,9 @@ def initialise_vector_store_loader():
OPENSEARCH_INDEX_NAME, OPENSEARCH_ENDPOINT_NAME, AWS_REGION
)
vector_store = OpensearchClientStore(
st_embedder, OPENSEARCH_INDEX_NAME, st.session_state["vacancy_os_client"]
st_embedder,
OPENSEARCH_INDEX_NAME,
st.session_state["vacancy_os_client"],
)

if LOADER_CONFIG == "file_loader":
Expand All @@ -197,9 +211,9 @@ def initialise_vector_store_loader():
chunker=TextChunker(chunk_size=1000, overlap=10),
)


def safe_literal_eval(x):
try:
return literal_eval(x)
except (SyntaxError, ValueError):
return None

0 comments on commit 4e3aa80

Please sign in to comment.