Skip to content

Commit

Permalink
add demojize with emoji package (#935)
Browse files Browse the repository at this point in the history
Co-authored-by: Bwook (Byoungwook) Kim <[email protected]>
  • Loading branch information
rjwharry and bwook00 authored Nov 10, 2024
1 parent 4cb8835 commit c5fcddc
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 10 deletions.
10 changes: 5 additions & 5 deletions autorag/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pandas as pd

from autorag.utils.util import normalize_unicode
from autorag.utils.util import preprocess_text


def validate_qa_dataset(df: pd.DataFrame):
Expand Down Expand Up @@ -60,9 +60,9 @@ def cast_generation_gt(gt):
), "query must be string type."
df["retrieval_gt"] = df["retrieval_gt"].apply(cast_retrieval_gt)
df["generation_gt"] = df["generation_gt"].apply(cast_generation_gt)
df["query"] = df["query"].apply(normalize_unicode)
df["query"] = df["query"].apply(preprocess_text)
df["generation_gt"] = df["generation_gt"].apply(
lambda x: list(map(normalize_unicode, x))
lambda x: list(map(preprocess_text, x))
)
return df

Expand Down Expand Up @@ -104,13 +104,13 @@ def make_prev_next_id_metadata(x, id_type: str):
lambda x: make_prev_next_id_metadata(x, "next_id")
)

df["contents"] = df["contents"].apply(normalize_unicode)
df["contents"] = df["contents"].apply(preprocess_text)

def normalize_unicode_metadata(metadata: dict):
result = {}
for key, value in metadata.items():
if isinstance(value, str):
result[key] = normalize_unicode(value)
result[key] = preprocess_text(value)
else:
result[key] = value
return result
Expand Down
13 changes: 11 additions & 2 deletions autorag/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import List, Callable, Dict, Optional, Any, Collection, Iterable

from asyncio import AbstractEventLoop
import emoji
import numpy as np
import pandas as pd
import tiktoken
Expand Down Expand Up @@ -468,6 +469,14 @@ def find_node_summary_files(trial_dir: str) -> List[str]:
return filtered_files


def preprocess_text(text: str) -> str:
	"""Canonicalize *text* for comparison: emoji become ``:shortcode:``
	tokens, then the result is Unicode NFC-normalized."""
	without_emoji = demojize(text)
	return normalize_unicode(without_emoji)


def demojize(text: str) -> str:
	"""Return *text* with every emoji replaced by its ``:shortcode:`` name
	(thin wrapper around :func:`emoji.demojize`)."""
	converted = emoji.demojize(text)
	return converted


def normalize_unicode(text: str) -> str:
	"""Return the NFC (canonically composed) form of *text*, so that
	visually identical strings compare equal regardless of how they were
	originally encoded."""
	composed = unicodedata.normalize("NFC", text)
	return composed

Expand Down Expand Up @@ -703,10 +712,10 @@ def decode_multiple_json_from_bytes(byte_data: bytes) -> list:
Decode multiple JSON objects from bytes received from SSE server.
Args:
byte_data: Bytes containing one or more JSON objects
byte_data: Bytes containing one or more JSON objects
Returns:
List of decoded JSON objects
List of decoded JSON objects
"""
# Decode bytes to string
try:
Expand Down
8 changes: 5 additions & 3 deletions autorag/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,17 @@ def validate(self, yaml_path: str, qa_cnt: int = 5, random_state: int = 42):
# Determine the sample size and log a warning if qa_cnt is larger than available records
available_records = len(self.qa_data)
safe_sample_size = min(qa_cnt, available_records) # 먼저 safe_sample_size 계산

if safe_sample_size < qa_cnt:
logger.warning(
f"Minimal Requested sample size ({qa_cnt}) is larger than available records ({available_records}). "
f"Sampling will be limited to {safe_sample_size} records. "
)

# safe sample QA data
sample_qa_df = self.qa_data.sample(n=safe_sample_size, random_state=random_state)
sample_qa_df = self.qa_data.sample(
n=safe_sample_size, random_state=random_state
)
sample_qa_df.reset_index(drop=True, inplace=True)

# get doc_id
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ voyageai # for voyageai reranker
mixedbread-ai # for mixedbread-ai reranker
llama-index-llms-bedrock
scikit-learn
emoji

### Vector DB ###
pymilvus # for using milvus vectordb
Expand Down
25 changes: 25 additions & 0 deletions tests/autorag/utils/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
find_trial_dir,
find_node_summary_files,
normalize_unicode,
demojize,
preprocess_text,
dict_to_markdown,
dict_to_markdown_table,
convert_inputs_to_list,
Expand Down Expand Up @@ -425,6 +427,12 @@ def test_find_node_summary_files():
assert all(os.path.basename(path) == "summary.csv" for path in node_summary_paths)


def test_demojize():
	"""demojize replaces the emoji with its ``:shortcode:`` and leaves the
	surrounding (Korean) text untouched."""
	# Renamed from ``str`` — the original shadowed the builtin type,
	# which breaks any later ``str(...)`` call in the same scope.
	text = "👍엄지엄지척"
	new_text = demojize(text)
	assert new_text == ":thumbs_up:엄지엄지척"


def test_normalize_unicode():
str1 = "전국보행자전용도로표준데이터"
str2 = "전국보행자전용도로표준데이터"
Expand All @@ -440,6 +448,23 @@ def test_normalize_unicode():
assert new_str1 == new_str2


def test_preprocess():
	"""preprocess_text must map two visually identical strings that differ
	only in Unicode normalization form onto one canonical result."""
	# str1 is the composed (NFC) form: 2 emoji + 14 precomposed Hangul
	# syllables -> length 16, as asserted below.
	str1 = (
		"👍전국보행자전용도로표준데이터👍"  # ":thumbs_up:" is added on both sides + 22
	)
	# str2 looks the same but is presumably the decomposed (NFD) form —
	# each syllable split into jamo, giving length 36 (TODO confirm).
	str2 = "👍전국보행자전용도로표준데이터👍"
	assert len(str1) == 16
	assert len(str2) == 36
	# Before preprocessing the two encodings are distinct objects.
	assert str1 != str2

	new_str1 = preprocess_text(str1)
	new_str2 = preprocess_text(str2)

	# After demojize (👍 -> ":thumbs_up:", 11 chars each side) plus NFC
	# normalization, both inputs collapse to the same 36-char string.
	assert len(new_str1) == 36
	assert len(new_str2) == 36
	assert new_str1 == new_str2


def test_dict_to_markdown():
data = {
"Title": "Sample Document",
Expand Down

0 comments on commit c5fcddc

Please sign in to comment.