Skip to content

Commit

Permalink
feat: add_data에서 dict도 허용
Browse files Browse the repository at this point in the history
* test code도 추가
  • Loading branch information
monologg committed Jan 15, 2024
1 parent a3ee220 commit bf90a48
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 7 deletions.
15 changes: 8 additions & 7 deletions ko_lm_dataformat/archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,17 @@ def set_chunk_name(self):

def add_data(
self,
data: Union[str, List[str]],
data: Union[str, List, dict],
meta: Optional[Dict] = None,
split_sent: bool = False,
clean_sent: bool = False,
):
"""
Args:
data (Union[str, List[str]]):
data (Union[str, List, dict]):
- Simple text
- List of text (multiple sentences)
- List of text (multiple sentences), List of dict
- Json style data
meta (Dict, optional): metadata. Defaults to None.
split_sent (bool): Whether to split text into sentences
clean_sent (bool): Whether to clean text (NFC, remove control char etc.)
Expand All @@ -73,7 +74,7 @@ def add_data(
meta = {}
if split_sent:
assert self.sentence_splitter
assert type(data) is not list # Shouldn't be List[str]
assert type(data) is str
data = self.sentence_splitter.split(data, clean_sent=clean_sent)

if clean_sent and type(data) is str:
Expand Down Expand Up @@ -133,10 +134,10 @@ def __init__(self, out_dir: str, sentence_splitter: Optional[SentenceSplitterBas

self.sentence_splitter = sentence_splitter

def add_data(self, data: Union[str, List[str]], split_sent: bool = False, clean_sent: bool = False):
def add_data(self, data: Union[str, List, dict], split_sent: bool = False, clean_sent: bool = False):
if split_sent:
assert self.sentence_splitter
assert type(data) is str # Shouldn't be List[str]
assert type(data) is str
data = self.sentence_splitter.split(data, clean_sent=clean_sent)

self.data.append(data)
Expand Down Expand Up @@ -191,7 +192,7 @@ def __init__(self, out_dir: str, sentence_splitter: Optional[SentenceSplitterBas

self.sentence_splitter = sentence_splitter

def add_data(self, data: Union[str, List[str]], split_sent: bool = False, clean_sent: bool = False):
def add_data(self, data: Union[str, List, dict], split_sent: bool = False, clean_sent: bool = False):
if split_sent:
assert self.sentence_splitter
assert type(data) is str # Shouldn't be List[str]
Expand Down
4 changes: 4 additions & 0 deletions tests/assets/sample.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"sentence_a": "안녕하세요",
"sentence_b": "반가워요"
}
23 changes: 23 additions & 0 deletions tests/test_archive.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import shutil

import pytest
Expand All @@ -16,6 +17,28 @@ def test_kowiki_archive():
shutil.rmtree(TMP_DIR_NAME)


def test_archive_json_data():
remove_tmp_dir()
archive = kldf.Archive(TMP_DIR_NAME)
with open(get_tests_dir(append_path="assets/sample.json"), "r", encoding="utf-8") as f:
data = json.load(f)
archive.add_data(data)

shutil.rmtree(TMP_DIR_NAME)


def test_archive_json_data_is_same():
remove_tmp_dir()
archive = kldf.Archive(TMP_DIR_NAME)
with open(get_tests_dir(append_path="assets/sample.json"), "r", encoding="utf-8") as f:
orig_data = json.load(f)
archive.add_data(orig_data)
archive.commit()
reader = kldf.Reader(TMP_DIR_NAME)
kldf_data = list(reader.stream_data())
assert kldf_data[0] == orig_data


def test_kor_str_is_same():
remove_tmp_dir()
archive = kldf.Archive(TMP_DIR_NAME)
Expand Down

0 comments on commit bf90a48

Please sign in to comment.