Commit

fixing types
davidsbatista committed Dec 19, 2024
1 parent 211c4ed commit 0807902
Showing 2 changed files with 82 additions and 40 deletions.
65 changes: 49 additions & 16 deletions haystack/components/preprocessors/recursive_splitter.py
@@ -51,8 +51,9 @@ class RecursiveDocumentSplitter:
>]
""" # noqa: E501

def __init__( # pylint: disable=too-many-positional-arguments
def __init__(
self,
*,
split_length: int = 200,
split_overlap: int = 0,
split_unit: Literal["word", "char"] = "word",
@@ -71,6 +72,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
text will be split into sentences using a custom sentence tokenizer based on NLTK.
See: haystack.components.preprocessors.sentence_tokenizer.SentenceSplitter.
If no separators are provided, the default separators ["\n\n", "sentence", "\n", " "] are used.
:param sentence_splitter_params: Optional parameters to pass to the sentence tokenizer.
See: haystack.components.preprocessors.sentence_tokenizer.SentenceSplitter for more information.
:raises ValueError: If the overlap is greater than or equal to the chunk size or if the overlap is negative, or
if any separator is not a string.
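For context (an annotation, not part of the diff): with the bare * added to the signature above, every parameter becomes keyword-only, so a positional call such as RecursiveDocumentSplitter(200, 0) now raises a TypeError. A minimal construction sketch with illustrative values:

splitter = RecursiveDocumentSplitter(split_length=200, split_overlap=0, split_unit="word", separators=["\n\n", "sentence", "\n", " "])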
@@ -81,9 +84,20 @@
self.separators = separators if separators else ["\n\n", "sentence", "\n", " "] # default separators
self.sentence_tokenizer_params = sentence_splitter_params
self._check_params()
self.nltk_tokenizer = None
if "sentence" in self.separators:
sentence_splitter_params = sentence_splitter_params or {"keep_white_spaces": True}
self.nltk_tokenizer = self._get_custom_sentence_tokenizer(sentence_splitter_params)
self.warm_up(sentence_splitter_params)

def warm_up(self, sentence_splitter_params):
"""
Warm up the sentence tokenizer.
Initializes the NLTK-based sentence tokenizer and stores it in self.nltk_tokenizer.
:param sentence_splitter_params: Optional parameters to pass to the sentence tokenizer.
"""
sentence_splitter_params = sentence_splitter_params or {"keep_white_spaces": True}
self.nltk_tokenizer = self._get_custom_sentence_tokenizer(sentence_splitter_params)
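# Usage sketch (annotation, not part of the diff): when "sentence" is among the
# separators, __init__ calls warm_up(None), so the default {"keep_white_spaces": True}
# is applied; calling warm_up again with explicit params, e.g.
# splitter.warm_up({"language": "en", "keep_white_spaces": True}), rebuilds the
# tokenizer (parameter values borrowed from the tests below).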

def _check_params(self):
if self.split_length < 1:
@@ -162,7 +176,8 @@ def _chunk_text(self, text: str) -> List[str]:

for curr_separator in self.separators: # type: ignore # the caller already checked that separators is not None
if curr_separator == "sentence":
sentence_with_spans = self.nltk_tokenizer.split_sentences(text)
# the SentenceSplitter is initialized in the component's __init__, so the tokenizer is not None here
sentence_with_spans = self.nltk_tokenizer.split_sentences(text) # type: ignore
splits = [sentence["sentence"] for sentence in sentence_with_spans]
else:
escaped_separator = re.escape(curr_separator)
@@ -221,19 +236,37 @@ def _chunk_text(self, text: str) -> List[str]:
if chunks:
return chunks

# if no separator worked, fall back to character- or word-level chunking
# ToDo: refactor into a function making use of split_unit parameter that can be easily tested in isolation
# if no separator worked, fall back to word- or character-level chunking
if self.split_units == "word":
return [
" ".join(text.split()[i : i + self.split_length])
for i in range(0, self._chunk_length(text), self.split_length - self.split_overlap)
]
return self.fall_back_to_word_level_chunking(text)

if self.split_units == "char":
return [
text[i : i + self.split_length]
for i in range(0, self._chunk_length(text), self.split_length - self.split_overlap)
]
return self.fall_back_to_char_level_chunking(text)

def fall_back_to_word_level_chunking(self, text: str) -> List[str]:
"""
Fall back to word-level chunking if no separator works.
:param text: The text to be split into chunks.
:returns:
A list of text chunks.
"""
return [
" ".join(text.split()[i : i + self.split_length])
for i in range(0, self._chunk_length(text), self.split_length - self.split_overlap)
]
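# Worked example (annotation, not part of the diff): with split_unit="word",
# split_length=3 and split_overlap=1, the window advances by 3 - 1 = 2 words:
# "a b c d e" -> ["a b c", "c d e", "e"]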

def fall_back_to_char_level_chunking(self, text: str) -> List[str]:
"""
Fall back to character-level chunking if no separator works.
:param text: The text to be split into chunks.
:returns:
A list of text chunks.
"""
return [
text[i : i + self.split_length]
for i in range(0, self._chunk_length(text), self.split_length - self.split_overlap)
]
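# Worked example (annotation, not part of the diff): with split_unit="char",
# split_length=4 and split_overlap=1, the window advances by 4 - 1 = 3 characters:
# "abcdefghij" -> ["abcd", "defg", "ghij", "j"]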

def _add_overlap_info(self, curr_pos: int, new_doc: Document, new_docs: List[Document]) -> None:
prev_doc = new_docs[-1]
@@ -244,7 +277,7 @@ def _add_overlap_info(self, curr_pos: int, new_doc: Document, new_docs: List[Document]) -> None:
{
"doc_id": prev_doc.id,
"range": (
self._chunk_length(prev_doc.content) - overlap_length,
self._chunk_length(prev_doc.content) - overlap_length, # type: ignore
self._chunk_length(prev_doc.content), # type: ignore
),
}
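# Worked example (annotation, not part of the diff): with split_unit="char",
# split_length=20 and split_overlap=11 (the values used in the overlap test below),
# a full-length previous chunk records range = (20 - 11, 20) = (9, 20)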
57 changes: 33 additions & 24 deletions test/components/preprocessors/test_recursive_splitter.py
@@ -33,22 +33,22 @@ def test_init_with_negative_split_length():

def test_apply_overlap_no_overlap():
# Test the case where there is no overlap between chunks
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=0, separators=["."])
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=0, separators=["."], split_unit="char")
chunks = ["chunk1", "chunk2", "chunk3"]
result = splitter._apply_overlap(chunks)
assert result == ["chunk1", "chunk2", "chunk3"]


def test_apply_overlap_with_overlap():
# Test the case where there is overlap between chunks
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=4, separators=["."])
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=4, separators=["."], split_unit="char")
chunks = ["chunk1", "chunk2", "chunk3"]
result = splitter._apply_overlap(chunks)
assert result == ["chunk1", "unk1chunk2", "unk2chunk3"]


def test_apply_overlap_with_overlap_capturing_completely_previous_chunk(caplog):
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=6, separators=["."])
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=6, separators=["."], split_unit="char")
chunks = ["chunk1", "chunk2", "chunk3", "chunk4"]
_ = splitter._apply_overlap(chunks)
assert (
@@ -59,7 +59,7 @@ def test_apply_overlap_with_overlap_capturing_completely_previous_chunk(caplog):

def test_apply_overlap_single_chunk():
# Test the case where there is only one chunk
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=3, separators=["."])
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=3, separators=["."], split_unit="char")
chunks = ["chunk1"]
result = splitter._apply_overlap(chunks)
assert result == ["chunk1"]
@@ -74,7 +74,7 @@ def test_chunk_text_smaller_than_chunk_size():


def test_chunk_text_by_period():
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=0, separators=["."])
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=0, separators=["."], split_unit="char")
text = "This is a test. Another sentence. And one more."
chunks = splitter._chunk_text(text)
assert len(chunks) == 3
@@ -84,7 +84,7 @@ def test_chunk_text_by_period():


def test_run_multiple_new_lines():
splitter = RecursiveDocumentSplitter(split_length=20, separators=["\n\n", "\n"])
splitter = RecursiveDocumentSplitter(split_length=20, separators=["\n\n", "\n"], split_unit="char")
text = "This is a test.\n\n\nAnother test.\n\n\n\nFinal test."
doc = Document(content=text)
chunks = splitter.run([doc])["documents"]
@@ -110,6 +110,7 @@ def test_run_using_custom_sentence_tokenizer():
splitter = RecursiveDocumentSplitter(
split_length=400,
split_overlap=0,
split_unit="char",
separators=["\n\n", "\n", "sentence", " "],
sentence_splitter_params={"language": "en", "use_split_rules": True, "keep_white_spaces": False},
)
@@ -134,8 +135,8 @@ def test_run_using_custom_sentence_tokenizer():
) # noqa: E501


def test_run_split_by_dot_count_page_breaks() -> None:
document_splitter = RecursiveDocumentSplitter(separators=["."], split_length=30, split_overlap=0)
def test_run_split_by_dot_count_page_breaks_split_unit_char() -> None:
document_splitter = RecursiveDocumentSplitter(separators=["."], split_length=30, split_overlap=0, split_unit="char")

text = (
"Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f"
@@ -181,8 +182,8 @@ def test_run_split_by_dot_count_page_breaks() -> None:
assert documents[6].meta["split_idx_start"] == text.index(documents[6].content)


def test_run_split_by_word_count_page_breaks():
splitter = RecursiveDocumentSplitter(split_length=18, split_overlap=0, separators=["w"])
def test_run_split_by_word_count_page_breaks_split_unit_char():
splitter = RecursiveDocumentSplitter(split_length=18, split_overlap=0, separators=["w"], split_unit="char")
text = "This is some text. \f This text is on another page. \f This is the last pag3."
doc = Document(content=text)
doc_chunks = splitter.run([doc])
@@ -216,7 +217,9 @@ def test_run_split_by_word_count_page_breaks():


def test_run_split_by_page_break_count_page_breaks() -> None:
document_splitter = RecursiveDocumentSplitter(separators=["\f"], split_length=50, split_overlap=0)
document_splitter = RecursiveDocumentSplitter(
separators=["\f"], split_length=50, split_overlap=0, split_unit="char"
)

text = (
"Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f"
@@ -247,8 +250,10 @@ def test_run_split_by_page_break_count_page_breaks() -> None:
assert chunks_docs[3].meta["split_idx_start"] == text.index(chunks_docs[3].content)


def test_run_split_by_new_line_count_page_breaks() -> None:
document_splitter = RecursiveDocumentSplitter(separators=["\n"], split_length=21, split_overlap=0)
def test_run_split_by_new_line_count_page_breaks_split_unit_char() -> None:
document_splitter = RecursiveDocumentSplitter(
separators=["\n"], split_length=21, split_overlap=0, split_unit="char"
)

text = (
"Sentence on page 1.\nAnother on page 1.\n\f"
@@ -298,8 +303,10 @@ def test_run_split_by_new_line_count_page_breaks() -> None:
assert chunks_docs[6].meta["split_idx_start"] == text.index(chunks_docs[6].content)


def test_run_split_by_sentence_count_page_breaks() -> None:
document_splitter = RecursiveDocumentSplitter(separators=["sentence"], split_length=28, split_overlap=0)
def test_run_split_by_sentence_count_page_breaks_split_unit_char() -> None:
document_splitter = RecursiveDocumentSplitter(
separators=["sentence"], split_length=28, split_overlap=0, split_unit="char"
)

text = (
"Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f"
@@ -347,7 +354,7 @@ def test_run_split_by_sentence_count_page_breaks() -> None:


def test_run_split_document_with_overlap_character_unit():
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=11, separators=[".", " "])
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=11, separators=[".", " "], split_unit="char")
text = """A simple sentence1. A bright sentence2. A clever sentence3. A joyful sentence4"""

doc = Document(content=text)
@@ -384,18 +391,18 @@ def test_run_split_document_with_overlap_character_unit():


def test_run_separator_exists_but_split_length_too_small_fall_back_to_character_chunking():
splitter = RecursiveDocumentSplitter(separators=[" "], split_length=2)
splitter = RecursiveDocumentSplitter(separators=[" "], split_length=2, split_unit="char")
doc = Document(content="This is some text. This is some more text.")
result = splitter.run(documents=[doc])
assert len(result["documents"]) == 21
for doc in result["documents"]:
assert len(doc.content) == 2


def test_run_fallback_to_character_chunking():
def test_run_fallback_to_character_chunking_by_default_length_too_short():
text = "abczdefzghizjkl"
separators = ["\n\n", "\n", "z"]
splitter = RecursiveDocumentSplitter(split_length=2, separators=separators)
splitter = RecursiveDocumentSplitter(split_length=2, separators=separators, split_unit="char")
doc = Document(content=text)
chunks = splitter.run([doc])["documents"]
for chunk in chunks:
@@ -404,7 +411,7 @@ def test_run_fallback_to_character_chunking():

def test_run_custom_sentence_tokenizer_document_and_overlap_char_unit():
"""Test that RecursiveDocumentSplitter works correctly with custom sentence tokenizer and overlap"""
splitter = RecursiveDocumentSplitter(split_length=25, split_overlap=5, separators=["sentence"])
splitter = RecursiveDocumentSplitter(split_length=25, split_overlap=5, separators=["sentence"], split_unit="char")
text = "This is sentence one. This is sentence two. This is sentence three."

doc = Document(content=text)
@@ -485,6 +492,10 @@ def test_run_split_by_word_count_page_breaks_word_unit():
doc_chunks = splitter.run([doc])
doc_chunks = doc_chunks["documents"]

assert len(doc_chunks) == 4
assert doc_chunks[0].content == "This is some text."
assert doc_chunks[0].meta["page_number"] == 1
@@ -546,9 +557,7 @@


def test_run_split_by_new_line_count_page_breaks_word_unit() -> None:
document_splitter = RecursiveDocumentSplitter(
separators=["\n"], split_length=21, split_overlap=0, split_unit="word"
)
document_splitter = RecursiveDocumentSplitter(separators=["\n"], split_length=4, split_overlap=0, split_unit="word")

text = (
"Sentence on page 1.\nAnother on page 1.\n\f"
@@ -600,7 +609,7 @@ def test_run_split_by_new_line_count_page_breaks_word_unit() -> None:

def test_run_split_by_sentence_count_page_breaks_word_unit() -> None:
document_splitter = RecursiveDocumentSplitter(
separators=["sentence"], split_length=28, split_overlap=0, split_unit="word"
separators=["sentence"], split_length=7, split_overlap=0, split_unit="word"
)

text = (
