Skip to content

Commit

Permalink
llama hybrid retriever nb updates
Browse files Browse the repository at this point in the history
  • Loading branch information
joshreini1 committed Feb 23, 2024
1 parent 5004bfa commit ae7d1ac
Showing 1 changed file with 21 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"metadata": {},
"outputs": [],
"source": [
"# ! pip install trulens_eval==0.24.0 llama_index==0.10.11 openai pypdf torch sentence-transformers"
"# ! pip install trulens_eval==0.24.0 llama_index==0.10.11 llama-index-readers-file llama-index-llms-openai llama-index-retrievers-bm25 openai pypdf torch sentence-transformers"
]
},
{
Expand Down Expand Up @@ -82,37 +82,31 @@
"outputs": [],
"source": [
"from llama_index.core import (\n",
" VectorStoreIndex,\n",
" StorageContext,\n",
" SimpleDirectoryReader,\n",
" StorageContext,\n",
" VectorStoreIndex,\n",
")\n",
"from llama_index.legacy import ServiceContext\n",
"from llama_index.retrievers.bm25 import BM25Retriever\n",
"from llama_index.core.retrievers import VectorIndexRetriever\n",
"from llama_index.core.node_parser import SentenceSplitter\n",
"from llama_index.llms.openai import OpenAI\n",
"\n",
"splitter = SentenceSplitter(chunk_size=1024)\n",
"\n",
"# load documents\n",
"documents = SimpleDirectoryReader(\n",
" input_files=[\"IPCC_AR6_WGII_Chapter03.pdf\"]\n",
").load_data()\n",
"\n",
"# initialize service context (set chunk size)\n",
"# -- here, we set a smaller chunk size, to allow for more effective re-ranking\n",
"llm = OpenAI(model=\"gpt-3.5-turbo\")\n",
"service_context = ServiceContext.from_defaults(chunk_size=256, llm=llm)\n",
"nodes = service_context.node_parser.get_nodes_from_documents(documents)\n",
"nodes = splitter.get_nodes_from_documents(documents)\n",
"\n",
"# initialize storage context (by default it's in-memory)\n",
"storage_context = StorageContext.from_defaults()\n",
"storage_context.docstore.add_documents(nodes)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"storage_context.docstore.add_documents(nodes)\n",
"\n",
"index = VectorStoreIndex(\n",
" nodes, storage_context=storage_context, service_context=service_context\n",
" nodes=nodes,\n",
" storage_context=storage_context,\n",
")"
]
},
Expand All @@ -129,13 +123,11 @@
"metadata": {},
"outputs": [],
"source": [
"from llama_index.retrievers import BM25Retriever\n",
"\n",
"# retireve the top 10 most similar nodes using embeddings\n",
"vector_retriever = index.as_retriever(similarity_top_k=10)\n",
"vector_retriever = VectorIndexRetriever(index)\n",
"\n",
"# retireve the top 10 most similar nodes using bm25\n",
"bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=10)"
"bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2)"
]
},
{
Expand All @@ -151,7 +143,7 @@
"metadata": {},
"outputs": [],
"source": [
"from llama_index.retrievers import BaseRetriever\n",
"from llama_index.core.retrievers import BaseRetriever\n",
"\n",
"class HybridRetriever(BaseRetriever):\n",
" def __init__(self, vector_retriever, bm25_retriever):\n",
Expand Down Expand Up @@ -190,7 +182,7 @@
"metadata": {},
"outputs": [],
"source": [
"from llama_index.postprocessor import SentenceTransformerRerank\n",
"from llama_index.core.postprocessor import SentenceTransformerRerank\n",
"\n",
"reranker = SentenceTransformerRerank(top_n=4, model=\"BAAI/bge-reranker-base\")"
]
Expand All @@ -201,12 +193,11 @@
"metadata": {},
"outputs": [],
"source": [
"from llama_index.query_engine import RetrieverQueryEngine\n",
"from llama_index.core.query_engine import RetrieverQueryEngine\n",
"\n",
"query_engine = RetrieverQueryEngine.from_args(\n",
" retriever=hybrid_retriever,\n",
" node_postprocessors=[reranker],\n",
" service_context=service_context,\n",
" node_postprocessors=[reranker]\n",
")"
]
},
Expand All @@ -216,7 +207,7 @@
"metadata": {},
"outputs": [],
"source": [
"tru.start_dashboard()"
"tru.run_dashboard()"
]
},
{
Expand Down Expand Up @@ -307,9 +298,7 @@
"outputs": [],
"source": [
"with tru_recorder as recording:\n",
" response = query_engine.query(\n",
" \"What is the impact of climate change on the ocean?\"\n",
")"
" response = query_engine.query(\"What is the impact of climate change on the ocean?\")"
]
},
{
Expand Down

0 comments on commit ae7d1ac

Please sign in to comment.