Merge pull request #27 from epinzur/crag3
more crag fixes
epinzur authored Jun 25, 2024
2 parents 83c8bc0 + 368288e commit f1a8b18
Showing 4 changed files with 63 additions and 43 deletions.
42 changes: 31 additions & 11 deletions ragulate/datasets/base_dataset.py
@@ -58,23 +58,37 @@ def get_source_file_paths(self) -> List[str]:
     def get_queries_and_golden_set(self) -> Tuple[List[str], List[Dict[str, str]]]:
         """gets a list of queries and golden_truth answers for a dataset"""
 
-    async def _download_file(self, session, url, temp_file_path):
-        async with session.get(url) as response:
-            file_size = int(response.headers.get('Content-Length', 0))
+    async def _download_file(
+        self, session: aiohttp.ClientSession, url: str, temp_file_path: str
+    ) -> None:
+        timeout = aiohttp.ClientTimeout(total=6000)
+        async with session.get(url, timeout=timeout) as response:
+            file_size = int(response.headers.get("Content-Length", 0))
             chunk_size = 1024
-            with tqdm(total=file_size, unit='B', unit_scale=True, desc=f'Downloading {url.split("/")[-1]}') as progress_bar:
-                async with aiofiles.open(temp_file_path, 'wb') as temp_file:
+            with tqdm(
+                total=file_size,
+                unit="B",
+                unit_scale=True,
+                desc=f'Downloading {url.split("/")[-1]}',
+            ) as progress_bar:
+                async with aiofiles.open(temp_file_path, "wb") as temp_file:
                     async for chunk in response.content.iter_chunked(chunk_size):
                         await temp_file.write(chunk)
                         progress_bar.update(len(chunk))
 
-    async def _decompress_file(self, temp_file_path, output_file_path):
+    async def _decompress_file(
+        self, temp_file_path: str, output_file_path: str
+    ) -> None:
         makedirs(path.dirname(output_file_path), exist_ok=True)
-        with open(temp_file_path, 'rb') as temp_file:
+        with open(temp_file_path, "rb") as temp_file:
             decompressed_size = 0
-            with bz2.BZ2File(temp_file, 'rb') as bz2_file:
-                async with aiofiles.open(output_file_path, 'wb') as output_file:
-                    with tqdm(unit='B', unit_scale=True, desc=f'Decompressing {output_file_path}') as progress_bar:
+            with bz2.BZ2File(temp_file, "rb") as bz2_file:
+                async with aiofiles.open(output_file_path, "wb") as output_file:
+                    with tqdm(
+                        unit="B",
+                        unit_scale=True,
+                        desc=f"Decompressing {output_file_path}",
+                    ) as progress_bar:
                         while True:
                             chunk = bz2_file.read(1024)
                             if not chunk:
@@ -83,7 +97,13 @@ async def _decompress_file(self, temp_file_path, output_file_path):
                             decompressed_size += len(chunk)
                             progress_bar.update(len(chunk))
 
-    async def _download_and_decompress(self, url, output_file_path):
+    async def _download_and_decompress(
+        self, url: str, output_file_path: str, force: bool
+    ) -> None:
+        if not force and path.exists(output_file_path):
+            print(f"File {output_file_path} already exists. Skipping download.")
+            return
+
         async with aiohttp.ClientSession() as session:
             with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                 temp_file_path = temp_file.name
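
For orientation, a minimal standalone sketch of the flow these helpers now implement: the skip-if-exists guard enabled by the new force parameter, a chunked streaming download with a tqdm progress bar, then bz2 decompression. The function name, URL, and paths are illustrative; only the aiohttp/aiofiles/tqdm usage shown in the diff is assumed.

import asyncio
import bz2
from os import makedirs, path

import aiofiles
import aiohttp
from tqdm import tqdm


async def fetch_and_unpack(url: str, out_path: str, force: bool = False) -> None:
    # Same skip-if-exists behavior the new `force` flag enables.
    if not force and path.exists(out_path):
        print(f"File {out_path} already exists. Skipping download.")
        return

    tmp_path = out_path + ".bz2.part"
    timeout = aiohttp.ClientTimeout(total=6000)
    async with aiohttp.ClientSession() as session:
        async with session.get(url, timeout=timeout) as response:
            total = int(response.headers.get("Content-Length", 0))
            with tqdm(total=total, unit="B", unit_scale=True, desc=url) as bar:
                async with aiofiles.open(tmp_path, "wb") as tmp:
                    async for chunk in response.content.iter_chunked(1024):
                        await tmp.write(chunk)
                        bar.update(len(chunk))

    # Decompress the temporary .bz2 file into the final output path.
    makedirs(path.dirname(out_path) or ".", exist_ok=True)
    with bz2.BZ2File(tmp_path, "rb") as src:
        async with aiofiles.open(out_path, "wb") as dst:
            while True:
                chunk = src.read(1024)
                if not chunk:
                    break
                await dst.write(chunk)


# Example (illustrative URL):
# asyncio.run(fetch_and_unpack("https://example.com/questions.jsonl.bz2", "questions.jsonl"))
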
45 changes: 24 additions & 21 deletions ragulate/datasets/crag_dataset.py
@@ -39,7 +39,10 @@ def download_dataset(self) -> None:
                 path.join(self.storage_path(), "parsed_documents.jsonl"),
                 path.join(self.storage_path(), "questions.jsonl"),
             ]
-            tasks = [self._download_and_decompress(url, output_file) for url, output_file in zip(urls, output_files)]
+            tasks = [
+                self._download_and_decompress(url, output_file)
+                for url, output_file in zip(urls, output_files)
+            ]
             asyncio.run(asyncio.gather(*tasks))
         else:
             raise NotImplementedError(f"Crag download not supported for {self.name}")
@@ -52,27 +55,27 @@ def get_queries_and_golden_set(self) -> Tuple[List[str], List[Dict[str, str]]]:
         queries: List[str] = []
         golden_set: List[Dict[str, str]] = []
 
-        if len(self.subsets) == 0:
-            subset_kinds = self._subset_kinds
-        else:
-            subset_kinds = []
-            for subset in self.subsets:
-                if subset not in self._subset_kinds:
-                    raise ValueError(f"Subset: {subset} doesn't exist in dataset {self.name}. Choices are {self._subset_kinds}")
-                subset_kinds.append(subset)
+        for subset in self.subsets:
+            if subset not in self._subset_kinds:
+                raise ValueError(
+                    f"Subset: {subset} doesn't exist in dataset {self.name}. Choices are {self._subset_kinds}"
+                )
 
-        for subset in subset_kinds:
-            json_path = path.join(
-                self.storage_path(), f"{subset}.jsonl"
-            )
-            with open(json_path, "r") as f:
-                for line in f:
-                    data = json.loads(line.strip())
-                    query = data.get("query")
-                    answer = data.get("answer")
-                    if query is not None:
-                        queries.append(query)
-                        golden_set.append({"query": query, "response": answer})
+        json_path = path.join(self.storage_path(), f"questions.jsonl")
+        with open(json_path, "r") as f:
+            for line in f:
+                data = json.loads(line.strip())
+                kind = data.get("question_type")
+
+                if len(self.subsets) > 0 and kind not in self.subsets:
+                    continue
+
+                query = data.get("query")
+                answer = data.get("answer")
+                if query is not None and answer is not None:
+                    queries.append(query)
+                    golden_set.append({"query": query, "response": answer})
+
+        print(f"found {len(queries)} for subsets: {self.subsets}")
 
         return queries, golden_set
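
The rewritten method reads a single questions.jsonl file and filters rows by question_type, instead of opening one file per subset as before. A standalone sketch of that filtering logic; the field names and file layout follow the diff, while the function name, path, and example subsets are illustrative:

import json
from typing import Dict, List, Tuple


def load_crag_queries(
    json_path: str, subsets: List[str]
) -> Tuple[List[str], List[Dict[str, str]]]:
    queries: List[str] = []
    golden_set: List[Dict[str, str]] = []
    with open(json_path, "r") as f:
        for line in f:
            data = json.loads(line.strip())
            # An empty subset list keeps every question_type.
            if subsets and data.get("question_type") not in subsets:
                continue
            query = data.get("query")
            answer = data.get("answer")
            if query is not None and answer is not None:
                queries.append(query)
                golden_set.append({"query": query, "response": answer})
    return queries, golden_set


# Example (illustrative path and subsets):
# queries, golden_set = load_crag_queries("questions.jsonl", ["simple", "comparison"])
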
13 changes: 5 additions & 8 deletions ragulate/datasets/llama_dataset.py
@@ -23,7 +23,9 @@ def __init__(
     ):
         super().__init__(dataset_name=dataset_name, root_storage_path=root_storage_path)
         self._llama_datasets_lfs_url: str = LLAMA_DATASETS_LFS_URL
-        self._llama_datasets_source_files_tree_url: str = LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL
+        self._llama_datasets_source_files_tree_url: str = (
+            LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL
+        )
 
     def sub_storage_path(self) -> str:
         return "llama"
@@ -47,7 +49,6 @@ def download_by_name(name):
                 load_documents=False,
             )
 
-
         # to conform with naming scheme at LlamaHub
         name = self.name
         try:
@@ -65,16 +66,12 @@ def download_by_name(name):
 
     def get_source_file_paths(self) -> List[str]:
         """gets a list of source file paths for for a dataset"""
-        source_path = path.join(
-            self._get_dataset_path(), "source_files"
-        )
+        source_path = path.join(self._get_dataset_path(), "source_files")
         return self.list_files_at_path(path=source_path)
 
     def get_queries_and_golden_set(self) -> Tuple[List[str], List[Dict[str, str]]]:
         """gets a list of queries and golden_truth answers for a dataset"""
-        json_path = path.join(
-            self._get_dataset_path(), "rag_dataset.json"
-        )
+        json_path = path.join(self._get_dataset_path(), "rag_dataset.json")
         with open(json_path, "r") as f:
             examples = json.load(f)["examples"]
             queries = [e["query"] for e in examples]
6 changes: 3 additions & 3 deletions ragulate/datasets/utils.py
@@ -6,7 +6,7 @@
 from .llama_dataset import LlamaDataset
 
 
-def find_dataset(name:str) -> BaseDataset:
+def find_dataset(name: str) -> BaseDataset:
     root_path = "datasets"
     name = name.lower()
     for kind in os.listdir(root_path):
@@ -21,12 +21,12 @@ def find_dataset(name:str) -> BaseDataset:
     """ searches for a downloaded dataset with this name. if found, returns it."""
     return get_dataset(name, "llama")
 
-def get_dataset(name:str, kind:str) -> BaseDataset:
+
+def get_dataset(name: str, kind: str) -> BaseDataset:
     kind = kind.lower()
     if kind == "llama":
         return LlamaDataset(dataset_name=name)
     elif kind == "crag":
         return CragDataset(dataset_name=name)
 
     raise NotImplementedError("only llama and crag datasets are currently supported")
-
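
A brief usage sketch for these dispatch helpers; the import path and dataset name are assumptions for illustration:

from ragulate.datasets.utils import find_dataset, get_dataset  # assumed import path

dataset = get_dataset(name="crag_task_3", kind="crag")  # explicit kind
# find_dataset() scans the local "datasets" directory instead, inferring the kind:
# dataset = find_dataset(name="crag_task_3")
queries, golden_set = dataset.get_queries_and_golden_set()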