Building retrievers with multiple docs (web pages or files) #13

Merged: 4 commits, Jul 1, 2024
4 changes: 2 additions & 2 deletions assets/images/coverage.svg
14 changes: 14 additions & 0 deletions denser_retriever/settings.py
@@ -2,6 +2,8 @@

from pydantic import BaseModel
import yaml
import os
from dotenv import load_dotenv


class RetrieverSettings(BaseModel):
@@ -26,6 +28,18 @@ def from_yaml(yaml_file: str = "config.yaml") -> RetrieverSettings:

def _from_yaml(yaml_file: str) -> RetrieverSettings:
    data = yaml.safe_load(open(yaml_file))
    # Load environment variables
    load_dotenv()
    data["keyword"]["es_host"] = os.getenv("ES_HOST", data["keyword"]["es_host"])
    data["keyword"]["es_passwd"] = os.getenv(
        "ES_PASSWD", data["keyword"]["es_passwd"]
    )
    data["vector"]["milvus_host"] = os.getenv(
        "MILVUS_HOST", data["vector"]["milvus_host"]
    )
    data["vector"]["milvus_passwd"] = os.getenv(
        "MILVUS_PASSWD", data["vector"]["milvus_passwd"]
    )
    return RetrieverSettings(**data)
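For reference, a minimal sketch of how the override added above behaves (not part of the diff); the variable names come from the code, the host values are placeholders:

import os
from dotenv import load_dotenv

# A .env file in the working directory (python-dotenv's default lookup), e.g.:
#   ES_HOST=https://my-es-host:9200      # placeholder
#   MILVUS_HOST=my-milvus-host           # placeholder
load_dotenv()

# os.getenv prefers the environment value and falls back to the YAML value,
# so config.yaml keeps working when no environment variable is set.
es_host = os.getenv("ES_HOST", "host-from-config.yaml")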


20 changes: 17 additions & 3 deletions experiments/index_and_query_from_docs.py
@@ -1,10 +1,23 @@
from langchain_community.document_loaders import TextLoader
import os
from langchain_text_splitters import RecursiveCharacterTextSplitter
from denser_retriever.utils import save_HF_docs_as_denser_passages
from denser_retriever.retriever_general import RetrieverGeneral
from utils_data import load_document
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Generate text chunks
documents = TextLoader("tests/test_data/state_of_the_union.txt").load()
file_paths = [
    "tests/test_data/state_of_the_union.txt",
    "tests/test_data/dpr.pdf",
    "https://example.com/index.html",
]
documents = []
for file_path in file_paths:
    documents.extend(load_document(file_path))

text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
texts = text_splitter.split_documents(documents)
passage_file = "passages.jsonl"
@@ -19,4 +32,5 @@
# Query
query = "What did the president say about Ketanji Brown Jackson"
passages, docs = retriever_denser.retrieve(query, {})
print(passages)
logger.info(passages)
os.remove(passage_file)
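Since load_document (see experiments/utils_data.py below) raises ValueError for unsupported extensions, a script that mixes many sources may prefer to skip bad inputs rather than abort. A minimal sketch, not part of the PR, reusing the file_paths and logger names from the script above:

documents = []
for file_path in file_paths:
    try:
        documents.extend(load_document(file_path))
    except ValueError as e:
        # Unsupported extension: log it and keep going with the other sources.
        logger.warning(f"Skipping {file_path}: {e}")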
27 changes: 13 additions & 14 deletions experiments/index_and_query_from_webpage.py
@@ -1,20 +1,18 @@
import bs4
from langchain_community.document_loaders import WebBaseLoader
import os
from langchain_text_splitters import RecursiveCharacterTextSplitter
from denser_retriever.utils import save_HF_docs_as_denser_passages
from denser_retriever.retriever_general import RetrieverGeneral
from utils_data import CustomWebBaseLoader
import logging

# Load, chunk and index the contents of the blog to create a retriever.
loader = WebBaseLoader(
    web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
    bs_kwargs=dict(
        parse_only=bs4.SoupStrainer(
            class_=("post-content", "post-title", "post-header")
        )
    ),
)
docs = loader.load()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load, chunk and index the contents of all webpages under an url to create a retriever.
base_url = "https://denser.ai"
loader = CustomWebBaseLoader(base_url)
docs = loader.load()
logger.info(f"Total docs: {len(docs)}")
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts = text_splitter.split_documents(docs)
passage_file = "passages.jsonl"
@@ -25,6 +23,7 @@
retriever_denser.ingest(passage_file)

# Query
query = "What is Task Decomposition?"
query = "What use cases does Denser AI support?"
passages, docs = retriever_denser.retrieve(query, {})
print(passages)
logger.info(passages)
os.remove(passage_file)
65 changes: 64 additions & 1 deletion experiments/utils_data.py
@@ -1,7 +1,70 @@
import os

import logging
import requests
from bs4 import BeautifulSoup
from langchain_community.document_loaders import WebBaseLoader, TextLoader, PyPDFLoader
from denser_retriever.utils import standardize_normalize, min_max_normalize

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_all_urls_from_domain(base_url):
    """Recursively get all URLs under the given domain."""
    urls_to_visit = {base_url}
    visited_urls = set()
    domain_urls = set()

    while urls_to_visit:
        url = urls_to_visit.pop()
        if url in visited_urls:
            continue

        visited_urls.add(url)
        try:
            response = requests.get(url)
            if response.status_code == 200:
                domain_urls.add(url)
                soup = BeautifulSoup(response.content, "html.parser")
                logger.info(f"Processing URL: {url}")
                for link in soup.find_all("a", href=True):
                    full_url = requests.compat.urljoin(base_url, link["href"])
                    if base_url in full_url and full_url not in visited_urls:
                        urls_to_visit.add(full_url)
        except requests.RequestException as e:
            print(f"Failed to fetch {url}: {e}")

    return domain_urls
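A usage sketch for the crawler above (not part of the diff); the base URL mirrors the one used in index_and_query_from_webpage.py. Note that the crawl has no depth or page limit, so it visits every same-domain link it can reach:

urls = get_all_urls_from_domain("https://denser.ai")
logger.info(f"Discovered {len(urls)} URLs")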


class CustomWebBaseLoader(WebBaseLoader):
    def __init__(self, base_url):
        self.base_url = base_url
        self.urls = get_all_urls_from_domain(base_url)

    def load(self):
        all_docs = []
        for url in self.urls:
            loader = WebBaseLoader(url)
            docs = loader.load()
            all_docs.extend(docs)
        return all_docs
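A side note on load(): each page is fetched twice, once by requests.get during URL discovery and once more by WebBaseLoader here. If the installed langchain_community version lets WebBaseLoader accept a list of URLs (recent releases do), a single batched call is a possible alternative. A sketch under that assumption, not what the PR implements:

class BatchedWebBaseLoader(CustomWebBaseLoader):
    def load(self):
        # Hypothetical variant: pass all discovered URLs to one WebBaseLoader.
        return WebBaseLoader(list(self.urls)).load()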


# Define a function to load documents based on file extension
def load_document(file_path):
    _, file_extension = os.path.splitext(file_path)
    if file_extension in [".txt", ".csv", ".tsv"]:
        loader = TextLoader(file_path)
    elif file_extension == ".pdf":
        loader = PyPDFLoader(file_path)
    elif file_extension in [".html", ".htm"]:
        loader = WebBaseLoader(file_path)
    else:
        raise ValueError(f"Unsupported file format: {file_extension}")

    return loader.load()
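If .csv/.tsv files should be parsed row by row instead of as plain text, the branch could use langchain_community's CSVLoader. A hedged sketch; CSVLoader and the wrapper name are not part of this PR:

from langchain_community.document_loaders import CSVLoader  # not imported above

def load_document_structured(file_path):
    # Variant of load_document that treats .csv files as structured rows.
    _, file_extension = os.path.splitext(file_path)
    if file_extension == ".csv":
        return CSVLoader(file_path).load()
    return load_document(file_path)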


def save_data(
    group_data, output_feature, output_group, features, features_to_normalize