Skip to content

Commit

Permalink
Show logging at module level (#395)
Browse files Browse the repository at this point in the history
* add logging at generator_node

* add logging at the start of each module

* delete model after using UPR model

* add tqdm to rerankers for viewing progress easily

---------

Co-authored-by: jeffrey <[email protected]>
  • Loading branch information
vkehfdl1 and jeffrey authored Apr 29, 2024
1 parent 612cb0d commit 38e44c5
Showing 14 changed files with 37 additions and 9 deletions.
4 changes: 4 additions & 0 deletions autorag/nodes/generator/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import logging
from pathlib import Path
from typing import Union, Tuple, List

@@ -7,6 +8,8 @@
from autorag import generator_models
from autorag.utils import result_to_dataframe

logger = logging.getLogger("AutoRAG")


def generator_node(func):
@functools.wraps(func)
@@ -28,6 +31,7 @@ def wrapper(
:return: Pandas dataframe that contains generated texts, generated tokens, and generated log probs.
Each column is "generated_texts", "generated_tokens", and "generated_log_probs".
"""
logger.info(f"Running generator node - {func.__name__} module...")
assert 'prompts' in previous_result.columns, "previous_result must contain prompts column."
prompts = previous_result['prompts'].tolist()
if func.__name__ == 'llama_index_llm':
1 change: 1 addition & 0 deletions autorag/nodes/passageaugmenter/base.py
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@ def wrapper(
project_dir: Union[str, Path],
previous_result: pd.DataFrame,
*args, **kwargs) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
logger.info(f"Running passage augmenter node - {func.__name__} module...")
validate_qa_dataset(previous_result)
data_dir = os.path.join(project_dir, "data")

4 changes: 4 additions & 0 deletions autorag/nodes/passagecompressor/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import logging
from pathlib import Path
from typing import List, Union, Dict

@@ -8,6 +9,8 @@
from autorag import generator_models
from autorag.utils import result_to_dataframe

logger = logging.getLogger("AutoRAG")


def passage_compressor_node(func):
@functools.wraps(func)
@@ -16,6 +19,7 @@ def wrapper(
project_dir: Union[str, Path],
previous_result: pd.DataFrame,
*args, **kwargs) -> List[List[str]]:
logger.info(f"Running passage compressor node - {func.__name__} module...")
assert all([column in previous_result.columns for column in
['query', 'retrieved_contents', 'retrieved_ids', 'retrieve_scores']]), \
"previous_result must have retrieved_contents, retrieved_ids, and retrieve_scores columns."
4 changes: 4 additions & 0 deletions autorag/nodes/passagefilter/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import logging
import os
from pathlib import Path
from typing import Union, Tuple, List
@@ -7,6 +8,8 @@

from autorag.utils import result_to_dataframe, validate_qa_dataset, fetch_contents

logger = logging.getLogger("AutoRAG")


# same with passage filter from now
def passage_filter_node(func):
@@ -16,6 +19,7 @@ def wrapper(
project_dir: Union[str, Path],
previous_result: pd.DataFrame,
*args, **kwargs) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
logger.info(f"Running passage filter node - {func.__name__} module...")
validate_qa_dataset(previous_result)

# find queries columns
1 change: 1 addition & 0 deletions autorag/nodes/passagereranker/base.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ def wrapper(
project_dir: Union[str, Path],
previous_result: pd.DataFrame,
*args, **kwargs) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
logger.info(f"Running passage reranker node - {func.__name__} module...")
validate_qa_dataset(previous_result)

# find queries columns
3 changes: 2 additions & 1 deletion autorag/nodes/passagereranker/colbert.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from autorag.nodes.passagereranker.base import passage_reranker_node
@@ -68,7 +69,7 @@ def get_colbert_embedding_batch(input_strings: List[str],

input_batches = slice_tokenizer_result(encoding, batch_size)
result_embedding = []
for encoding in input_batches:
for encoding in tqdm(input_batches):
result_embedding.append(model(**encoding).last_hidden_state)
total_tensor = torch.cat(result_embedding, dim=0) # shape [batch_size, token_length, embedding_dim]
tensor_results = list(total_tensor.chunk(total_tensor.size()[0]))
3 changes: 2 additions & 1 deletion autorag/nodes/passagereranker/flag_embedding.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
import pandas as pd
import torch
from FlagEmbedding import FlagReranker
from tqdm import tqdm

from autorag.nodes.passagereranker.base import passage_reranker_node
from autorag.utils.util import make_batch, sort_by_scores, flatten_apply, select_top_k
@@ -54,7 +55,7 @@ def flag_embedding_reranker(queries: List[str], contents_list: List[List[str]],
def flag_embedding_run_model(input_texts, model, batch_size: int):
batch_input_texts = make_batch(input_texts, batch_size)
results = []
for batch_texts in batch_input_texts:
for batch_texts in tqdm(batch_input_texts):
with torch.no_grad():
pred_scores = model.compute_score(sentence_pairs=batch_texts)
if batch_size == 1:
3 changes: 2 additions & 1 deletion autorag/nodes/passagereranker/monot5.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@

import pandas as pd
import torch
from tqdm import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration

from autorag.nodes.passagereranker.base import passage_reranker_node
@@ -92,7 +93,7 @@ def monot5(queries: List[str], contents_list: List[List[str]],
def monot5_run_model(input_texts, model, batch_size: int, tokenizer, device, token_false_id, token_true_id):
batch_input_texts = make_batch(input_texts, batch_size)
results = []
for batch_texts in batch_input_texts:
for batch_texts in tqdm(batch_input_texts):
flattened_batch_texts = list(chain.from_iterable(batch_texts))
input_encodings = tokenizer(flattened_batch_texts, padding=True, truncation=True, max_length=512,
return_tensors='pt').to(
3 changes: 2 additions & 1 deletion autorag/nodes/passagereranker/sentence_transformer.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
import pandas as pd
import torch
from sentence_transformers import CrossEncoder
from tqdm import tqdm

from autorag.nodes.passagereranker.base import passage_reranker_node
from autorag.utils.util import flatten_apply, make_batch, select_top_k, sort_by_scores
@@ -54,7 +55,7 @@ def sentence_transformer_reranker(queries: List[str], contents_list: List[List[s
def sentence_transformer_run_model(input_texts, model, batch_size: int):
batch_input_texts = make_batch(input_texts, batch_size)
results = []
for batch_texts in batch_input_texts:
for batch_texts in tqdm(batch_input_texts):
with torch.no_grad():
pred_scores = model.predict(sentences=batch_texts, apply_softmax=True)
results.extend(pred_scores.tolist())
3 changes: 2 additions & 1 deletion autorag/nodes/passagereranker/tart/tart.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm

from autorag.nodes.passagereranker.base import passage_reranker_node
from autorag.nodes.passagereranker.tart.modeling_enc_t5 import EncT5ForSequenceClassification
@@ -68,7 +69,7 @@ def tart_run_model(input_texts, contents_list, model, batch_size: int, tokenizer
batch_input_texts = make_batch(flattened_texts, batch_size)
batch_contents_list = make_batch(flattened_contents, batch_size)
results = []
for batch_texts, batch_contents in zip(batch_input_texts, batch_contents_list):
for batch_texts, batch_contents in tqdm(zip(batch_input_texts, batch_contents_list)):
feature = tokenizer(batch_texts, batch_contents, padding=True, truncation=True,
return_tensors="pt").to(device)
with torch.no_grad():
11 changes: 7 additions & 4 deletions autorag/nodes/passagereranker/upr.py
Original file line number Diff line number Diff line change
@@ -53,17 +53,19 @@ def upr(queries: List[str], contents_list: List[List[str]],
})
ds = ray.data.from_pandas(df)

scorer = UPRScorer(suffix_prompt=suffix_prompt, prefix_prompt=prefix_prompt, use_bf16=use_bf16)

if torch.cuda.is_available():
score_batch = ds.map_batches(UPRScorer(suffix_prompt=suffix_prompt,
prefix_prompt=prefix_prompt, use_bf16=use_bf16), batch_size=1,
score_batch = ds.map_batches(scorer, batch_size=1,
concurrency=num_gpus, num_gpus=1)
else:
score_batch = ds.map_batches(UPRScorer(suffix_prompt=suffix_prompt,
prefix_prompt=prefix_prompt, use_bf16=use_bf16),
score_batch = ds.map_batches(scorer,
batch_size=1,
concurrency=min(len(df), os.cpu_count()), num_cpus=os.cpu_count())
scores = score_batch.to_pandas()['output'].tolist() # converted to a flatten list of scores

del scorer

explode_df = df.explode('contents')
explode_df['scores'] = scores
df['scores'] = explode_df.groupby(level=0, sort=False)['scores'].apply(list).tolist()
@@ -108,5 +110,6 @@ def __call__(self, batch: Dict[str, np.ndarray]):

def __del__(self):
del self.model
del self.tokenizer
if torch.cuda.is_available():
torch.cuda.empty_cache()
4 changes: 4 additions & 0 deletions autorag/nodes/promptmaker/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import functools
import logging
from pathlib import Path
from typing import List, Union

import pandas as pd

from autorag.utils import result_to_dataframe

logger = logging.getLogger("AutoRAG")


def prompt_maker_node(func):
@functools.wraps(func)
@@ -14,6 +17,7 @@ def wrapper(
project_dir: Union[str, Path],
previous_result: pd.DataFrame,
*args, **kwargs) -> List[str]:
logger.info(f"Running prompt maker node - {func.__name__} module...")
# get query and retrieved contents from previous_result
assert "query" in previous_result.columns, "previous_result must have query column."
assert "retrieved_contents" in previous_result.columns, "previous_result must have retrieved_contents column."
1 change: 1 addition & 0 deletions autorag/nodes/queryexpansion/base.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ def wrapper(
project_dir: Union[str, Path],
previous_result: pd.DataFrame,
*args, **kwargs) -> List[List[str]]:
logger.info(f"Running query expansion node - {func.__name__} module...")
validate_qa_dataset(previous_result)

# find queries columns
1 change: 1 addition & 0 deletions autorag/nodes/retrieval/base.py
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@ def wrapper(
project_dir: Union[str, Path],
previous_result: pd.DataFrame,
**kwargs) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
logger.info(f"Running retrieval node - {func.__name__} module...")
validate_qa_dataset(previous_result)
resources_dir = os.path.join(project_dir, "resources")
data_dir = os.path.join(project_dir, "data")

0 comments on commit 38e44c5

Please sign in to comment.