
Feature/46 retrieve performance improvement #52

Open · wants to merge 5 commits into master
7 changes: 7 additions & 0 deletions interface/streamlit_app.py
@@ -28,6 +28,12 @@ def summarize_total_tokens(data):
return total_tokens


use_reranker = st.checkbox(
"리랭킹(Reranking) 기능 사용",
value=False,
help="검색 결과의 정확도를 높이기 위한 리랭킹 기능을 사용합니다.",
)

# Run when the button is clicked
if st.button("쿼리 실행"):
# Compile the graph and run the query
@@ -38,6 +44,7 @@ def summarize_total_tokens(data):
"messages": [HumanMessage(content=user_query)],
"user_database_env": user_database_env,
"best_practice_query": "",
"use_rerank": use_reranker,
}
)
total_tokens = summarize_total_tokens(res["messages"])
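For orientation, here is a minimal sketch of how the new checkbox value reaches the graph state; note the Streamlit variable is `use_reranker` while the state key is `use_rerank`. The compiled `graph` object and the `user_query` / `user_database_env` inputs are assumptions based on surrounding app code that is not shown in this hunk.

```python
from langchain_core.messages import HumanMessage

# Sketch only: `graph`, `user_query`, and `user_database_env` are assumed to be
# defined elsewhere in interface/streamlit_app.py.
res = graph.invoke(
    {
        "messages": [HumanMessage(content=user_query)],
        "user_database_env": user_database_env,
        "best_practice_query": "",
        "use_rerank": use_reranker,  # value of the new checkbox, defaults to False
    }
)
total_tokens = summarize_total_tokens(res["messages"])
```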
11 changes: 8 additions & 3 deletions llm_utils/chains.py
@@ -40,23 +40,28 @@ def create_query_refiner_chain(llm):
예시:
사용자가 "유저 이탈 원인이 궁금해요"라고 했다면,
재질문 형식이 아니라
"최근 1개월 간의 접속·결제 로그를 기준으로,
"접속·결제 로그를 기준으로,
주로 어떤 사용자가 어떤 과정을 거쳐 이탈하는지를 분석해야 한다"처럼
분석 방향이 명확해진 질문 한 문장(또는 한 문단)으로 정리해 주세요.

최종 출력 형식 예시:
------------------------------
구체화된 질문:
"최근 1개월 동안 고액 결제 경험이 있는 유저가
"고액 결제 경험이 있는 유저가
행동 로그에서 이탈 전 어떤 패턴을 보였는지 분석"

가정한 조건:
- 최근 1개월치 행동 로그와 결제 로그 중심
- 행동 로그와 결제 로그 중심
- 고액 결제자(월 결제액 10만 원 이상) 그룹 대상으로 한정
------------------------------
""",
),
MessagesPlaceholder(variable_name="user_input"),
(
"system",
"다음은 사용자의 실제 사용 가능한 테이블 및 컬럼 정보입니다:",
),
MessagesPlaceholder(variable_name="searched_tables"),
(
"system",
"""
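A hedged sketch of how the refiner chain is expected to be called after this change, mirroring `query_refiner_node` in `llm_utils/graph.py` below; the `llm` object and the sample table dictionary are illustrative assumptions, not part of this diff.

```python
import json

# Sketch only: `llm` is any chat model the project already supports.
query_refiner_chain = create_query_refiner_chain(llm)

# Table/column info as produced by the GET_TABLE_INFO node (shape illustrative).
searched_tables = {
    "users": {"table_description": "유저 마스터 테이블", "user_id": "유저 식별자"},
}

refined = query_refiner_chain.invoke(
    {
        "user_input": ["유저 이탈 원인이 궁금해요"],
        "user_database_env": ["bigquery"],
        "best_practice_query": [""],
        # Serialized exactly as query_refiner_node does below.
        "searched_tables": [json.dumps(searched_tables)],
    }
)
```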
57 changes: 12 additions & 45 deletions llm_utils/graph.py
@@ -14,6 +14,7 @@
)

from llm_utils.tools import get_info_from_db
from llm_utils.retrieval import search_tables

# Node identifiers
QUERY_REFINER = "query_refiner"
@@ -31,6 +32,7 @@ class QueryMakerState(TypedDict):
best_practice_query: str
refined_input: str
generated_query: str
use_rerank: bool


# Node function: QUERY_REFINER node
@@ -40,6 +42,7 @@ def query_refiner_node(state: QueryMakerState):
"user_input": [state["messages"][0].content],
"user_database_env": [state["user_database_env"]],
"best_practice_query": [state["best_practice_query"]],
"searched_tables": [json.dumps(state["searched_tables"])],
}
)
state["messages"].append(res)
@@ -48,43 +51,10 @@


def get_table_info_node(state: QueryMakerState):
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
try:
db = FAISS.load_local(
os.getcwd() + "/table_info_db",
embeddings,
allow_dangerous_deserialization=True,
)
except:
documents = get_info_from_db()
db = FAISS.from_documents(documents, embeddings)
db.save_local(os.getcwd() + "/table_info_db")
print("table_info_db not found")
doc_res = db.similarity_search(state["messages"][-1].content)
documents_dict = {}

for doc in doc_res:
lines = doc.page_content.split("\n")

# Extract table name and description
table_name, table_desc = lines[0].split(": ", 1)

# Extract column information
columns = {}
if len(lines) > 2 and lines[1].strip() == "Columns:":
for line in lines[2:]:
if ": " in line:
col_name, col_desc = line.split(": ", 1)
columns[col_name.strip()] = col_desc.strip()

# Store in the dictionary
documents_dict[table_name] = {
"table_description": table_desc.strip(),
**columns,  # add column information
}
# Perform the table search using the use_rerank value from the state
documents_dict = search_tables(
state["messages"][0].content, use_rerank=state["use_rerank"]
)
state["searched_tables"] = documents_dict

return state
@@ -134,19 +104,16 @@ def query_maker_node_with_db_guide(state: QueryMakerState):

# Create and configure the StateGraph
builder = StateGraph(QueryMakerState)
builder.set_entry_point(QUERY_REFINER)
builder.set_entry_point(GET_TABLE_INFO)

# Add nodes
builder.add_node(QUERY_REFINER, query_refiner_node)
builder.add_node(GET_TABLE_INFO, get_table_info_node)
# builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide
builder.add_node(
QUERY_MAKER, query_maker_node_with_db_guide
) # query_maker_node_with_db_guide
builder.add_node(QUERY_REFINER, query_refiner_node)
builder.add_node(QUERY_MAKER, query_maker_node_with_db_guide)

# Set up the default edges
builder.add_edge(QUERY_REFINER, GET_TABLE_INFO)
builder.add_edge(GET_TABLE_INFO, QUERY_MAKER)
builder.add_edge(GET_TABLE_INFO, QUERY_REFINER)
builder.add_edge(QUERY_REFINER, QUERY_MAKER)

# Finish after the QUERY_MAKER node
builder.add_edge(QUERY_MAKER, END)
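To make the reordering concrete: the entry point moves from QUERY_REFINER to GET_TABLE_INFO, so table information is retrieved before the question is refined (GET_TABLE_INFO → QUERY_REFINER → QUERY_MAKER → END). Below is a hedged sketch of compiling and invoking the reordered graph; the compile call itself lives outside this hunk, and the input values are illustrative.

```python
from langchain_core.messages import HumanMessage

# Sketch only: compile and run the reordered pipeline.
graph = builder.compile()

res = graph.invoke(
    {
        "messages": [HumanMessage(content="고액 결제 유저의 이탈 패턴을 분석해줘")],
        "user_database_env": "bigquery",
        "best_practice_query": "",
        "use_rerank": True,  # forwarded to search_tables() in GET_TABLE_INFO
    }
)
print(res["messages"][-1].content)  # generated SQL (assumed to be the last message)
```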
94 changes: 94 additions & 0 deletions llm_utils/retrieval.py
@@ -0,0 +1,94 @@
import os
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from .tools import get_info_from_db


def get_vector_db():
"""벡터 데이터베이스를 로드하거나 생성합니다."""
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
try:
db = FAISS.load_local(
os.getcwd() + "/table_info_db",
embeddings,
allow_dangerous_deserialization=True,
)
except:
documents = get_info_from_db()
db = FAISS.from_documents(documents, embeddings)
db.save_local(os.getcwd() + "/table_info_db")
print("table_info_db not found")
return db


def load_reranker_model():
"""한국어 reranker 모델을 로드하거나 다운로드합니다."""
local_model_path = os.path.join(os.getcwd(), "ko_reranker_local")

# Load the model from local storage if it exists; otherwise download and save it
if os.path.exists(local_model_path) and os.path.isdir(local_model_path):
print("🔄 ko-reranker 모델 로컬에서 로드 중...")
else:
print("⬇️ ko-reranker 모델 다운로드 및 저장 중...")
model = AutoModelForSequenceClassification.from_pretrained(
"Dongjin-kr/ko-reranker"
)
tokenizer = AutoTokenizer.from_pretrained("Dongjin-kr/ko-reranker")
model.save_pretrained(local_model_path)
tokenizer.save_pretrained(local_model_path)

return HuggingFaceCrossEncoder(model_name=local_model_path)


def get_retriever(use_rerank=False):
"""검색기를 생성합니다. use_rerank가 True이면 reranking을 적용합니다."""
db = get_vector_db()
retriever = db.as_retriever(search_kwargs={"k": 10})

if use_rerank:
model = load_reranker_model()
compressor = CrossEncoderReranker(model=model, top_n=3)
return ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=retriever
)
else:
return retriever


def search_tables(query, use_rerank=False):
"""쿼리에 맞는 테이블 정보를 검색합니다."""
if use_rerank:
retriever = get_retriever(use_rerank=True)
doc_res = retriever.invoke(query)
else:
db = get_vector_db()
doc_res = db.similarity_search(query, k=10)

# Convert the results into a dictionary
documents_dict = {}
for doc in doc_res:
lines = doc.page_content.split("\n")

# Extract table name and description
table_name, table_desc = lines[0].split(": ", 1)

# Extract column information
columns = {}
if len(lines) > 2 and lines[1].strip() == "Columns:":
for line in lines[2:]:
if ": " in line:
col_name, col_desc = line.split(": ", 1)
columns[col_name.strip()] = col_desc.strip()

# Store in the dictionary
documents_dict[table_name] = {
"table_description": table_desc.strip(),
**columns,  # add column information
}

return documents_dict
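A short usage sketch for the new module. An `OPENAI_API_KEY` is required for the embeddings, and the first reranked call downloads `Dongjin-kr/ko-reranker` into `ko_reranker_local/`; the query strings here are illustrative.

```python
from llm_utils.retrieval import search_tables

# Plain FAISS similarity search over table_info_db (top 10, no reranking).
tables = search_tables("월별 결제 금액 집계", use_rerank=False)

# Retrieve 10 candidates, then keep the top 3 scored by the
# Dongjin-kr/ko-reranker cross-encoder.
reranked = search_tables("월별 결제 금액 집계", use_rerank=True)

print(list(reranked.keys()))  # matched table names
```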
2 changes: 2 additions & 0 deletions requirements.txt
@@ -11,3 +11,5 @@ pre_commit==4.1.0
setuptools
wheel
twine
langchain-huggingface==0.1.2
transformers==4.51.2
2 changes: 2 additions & 0 deletions setup.py
@@ -24,6 +24,8 @@
"streamlit==1.41.1",
"python-dotenv==1.0.1",
"faiss-cpu==1.10.0",
"transformers==4.51.2",
"langchain-huggingface==0.1.2",
],
entry_points={
"console_scripts": [
Binary file removed table_info_db/index.faiss
Binary file removed table_info_db/index.pkl