diff --git a/ko_lm_dataformat/archive.py b/ko_lm_dataformat/archive.py index 4bbe6ad..eb28d05 100644 --- a/ko_lm_dataformat/archive.py +++ b/ko_lm_dataformat/archive.py @@ -55,16 +55,17 @@ def set_chunk_name(self): def add_data( self, - data: Union[str, List[str]], + data: Union[str, List, dict], meta: Optional[Dict] = None, split_sent: bool = False, clean_sent: bool = False, ): """ Args: - data (Union[str, List[str]]): + data (Union[str, List, dict]): - Simple text - - List of text (multiple sentences) + - List of text (multiple sentences), List of dict + - Json style data meta (Dict, optional): metadata. Defaults to None. split_sent (bool): Whether to split text into sentences clean_sent (bool): Whether to clean text (NFC, remove control char etc.) @@ -73,7 +74,7 @@ def add_data( meta = {} if split_sent: assert self.sentence_splitter - assert type(data) is not list # Shouldn't be List[str] + assert type(data) is str data = self.sentence_splitter.split(data, clean_sent=clean_sent) if clean_sent and type(data) is str: @@ -133,10 +134,10 @@ def __init__(self, out_dir: str, sentence_splitter: Optional[SentenceSplitterBas self.sentence_splitter = sentence_splitter - def add_data(self, data: Union[str, List[str]], split_sent: bool = False, clean_sent: bool = False): + def add_data(self, data: Union[str, List, dict], split_sent: bool = False, clean_sent: bool = False): if split_sent: assert self.sentence_splitter - assert type(data) is str # Shouldn't be List[str] + assert type(data) is str data = self.sentence_splitter.split(data, clean_sent=clean_sent) self.data.append(data) @@ -191,7 +192,7 @@ def __init__(self, out_dir: str, sentence_splitter: Optional[SentenceSplitterBas self.sentence_splitter = sentence_splitter - def add_data(self, data: Union[str, List[str]], split_sent: bool = False, clean_sent: bool = False): + def add_data(self, data: Union[str, List, dict], split_sent: bool = False, clean_sent: bool = False): if split_sent: assert self.sentence_splitter assert type(data) is str # Shouldn't be List[str] diff --git a/tests/assets/sample.json b/tests/assets/sample.json new file mode 100644 index 0000000..08e617d --- /dev/null +++ b/tests/assets/sample.json @@ -0,0 +1,4 @@ +{ + "sentence_a": "안녕하세요", + "sentence_b": "반가워요" +} diff --git a/tests/test_archive.py b/tests/test_archive.py index 908ac05..073776d 100644 --- a/tests/test_archive.py +++ b/tests/test_archive.py @@ -1,3 +1,4 @@ +import json import shutil import pytest @@ -16,6 +17,28 @@ def test_kowiki_archive(): shutil.rmtree(TMP_DIR_NAME) +def test_archive_json_data(): + remove_tmp_dir() + archive = kldf.Archive(TMP_DIR_NAME) + with open(get_tests_dir(append_path="assets/sample.json"), "r", encoding="utf-8") as f: + data = json.load(f) + archive.add_data(data) + + shutil.rmtree(TMP_DIR_NAME) + + +def test_archive_json_data_is_same(): + remove_tmp_dir() + archive = kldf.Archive(TMP_DIR_NAME) + with open(get_tests_dir(append_path="assets/sample.json"), "r", encoding="utf-8") as f: + orig_data = json.load(f) + archive.add_data(orig_data) + archive.commit() + reader = kldf.Reader(TMP_DIR_NAME) + kldf_data = list(reader.stream_data()) + assert kldf_data[0] == orig_data + + def test_kor_str_is_same(): remove_tmp_dir() archive = kldf.Archive(TMP_DIR_NAME)