
Commit 0807902

fixing types

1 parent 211c4ed commit 0807902

File tree: 2 files changed, +82 −40 lines changed

haystack/components/preprocessors/recursive_splitter.py

Lines changed: 49 additions & 16 deletions
@@ -51,8 +51,9 @@ class RecursiveDocumentSplitter:
     >]
     """  # noqa: E501

-    def __init__(  # pylint: disable=too-many-positional-arguments
+    def __init__(
         self,
+        *,
         split_length: int = 200,
         split_overlap: int = 0,
         split_unit: Literal["word", "char"] = "word",
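The bare * makes every constructor argument after self keyword-only, which is what lets the pylint suppression for too many positional arguments be dropped. A minimal sketch of the caller-side effect, assuming a Haystack build that includes this commit:

from haystack.components.preprocessors import RecursiveDocumentSplitter

# Keyword arguments still work as before:
splitter = RecursiveDocumentSplitter(split_length=100, split_overlap=10, split_unit="char")

# Positional calls now fail, e.g. RecursiveDocumentSplitter(100, 10) raises:
# TypeError: takes 1 positional argument but 3 were given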
@@ -71,6 +72,8 @@ def __init__(  # pylint: disable=too-many-positional-arguments
             text will be split into sentences using a custom sentence tokenizer based on NLTK.
             See: haystack.components.preprocessors.sentence_tokenizer.SentenceSplitter.
             If no separators are provided, the default separators ["\n\n", "sentence", "\n", " "] are used.
+        :param sentence_splitter_params: Optional parameters to pass to the sentence tokenizer.
+            See: haystack.components.preprocessors.sentence_tokenizer.SentenceSplitter for more information.

         :raises ValueError: If the overlap is greater than or equal to the chunk size or if the overlap is negative, or
             if any separator is not a string.
@@ -81,9 +84,20 @@ def __init__(  # pylint: disable=too-many-positional-arguments
         self.separators = separators if separators else ["\n\n", "sentence", "\n", " "]  # default separators
         self.sentence_tokenizer_params = sentence_splitter_params
         self._check_params()
+        self.nltk_tokenizer = None
         if "sentence" in self.separators:
-            sentence_splitter_params = sentence_splitter_params or {"keep_white_spaces": True}
-            self.nltk_tokenizer = self._get_custom_sentence_tokenizer(sentence_splitter_params)
+            self.warm_up(sentence_splitter_params)
+
+    def warm_up(self, sentence_splitter_params):
+        """
+        Warm up the sentence tokenizer.
+
+        :param sentence_splitter_params: Optional parameters to pass to the sentence tokenizer.
+        :returns:
+            None; sets `self.nltk_tokenizer` to a SentenceSplitter instance.
+        """
+        sentence_splitter_params = sentence_splitter_params or {"keep_white_spaces": True}
+        self.nltk_tokenizer = self._get_custom_sentence_tokenizer(sentence_splitter_params)

     def _check_params(self):
         if self.split_length < 1:
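Tokenizer construction now lives in warm_up: a splitter whose separators include "sentence" is warmed up eagerly from __init__, while any other configuration leaves self.nltk_tokenizer as None until warm_up is called. A short sketch of both paths, assuming this commit's API and locally available NLTK data:

from haystack.components.preprocessors import RecursiveDocumentSplitter

# "sentence" among the separators triggers warm_up() inside __init__:
eager = RecursiveDocumentSplitter(split_unit="word", separators=["sentence", " "])
assert eager.nltk_tokenizer is not None

# Otherwise the tokenizer stays unset until warm_up() is called explicitly:
lazy = RecursiveDocumentSplitter(split_unit="word", separators=["\n\n", "\n"])
assert lazy.nltk_tokenizer is None
lazy.warm_up({"keep_white_spaces": True})
assert lazy.nltk_tokenizer is not None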
@@ -162,7 +176,8 @@ def _chunk_text(self, text: str) -> List[str]:

         for curr_separator in self.separators:  # type: ignore  # the caller already checked that separators is not None
             if curr_separator == "sentence":
-                sentence_with_spans = self.nltk_tokenizer.split_sentences(text)
+                # correct SentenceSplitter initialization is checked at the initialization of the component
+                sentence_with_spans = self.nltk_tokenizer.split_sentences(text)  # type: ignore
                 splits = [sentence["sentence"] for sentence in sentence_with_spans]
             else:
                 escaped_separator = re.escape(curr_separator)
@@ -221,19 +236,37 @@ def _chunk_text(self, text: str) -> List[str]:
         if chunks:
             return chunks

-        # if no separator worked, fall back to character- or word-level chunking
-        # ToDo: refactor into a function making use of split_unit parameter that can be easily tested in isolation
+        # if no separator worked, fall back to word- or character-level chunking
         if self.split_units == "word":
-            return [
-                " ".join(text.split()[i : i + self.split_length])
-                for i in range(0, self._chunk_length(text), self.split_length - self.split_overlap)
-            ]
+            return self.fall_back_to_word_level_chunking(text)

-        if self.split_units == "char":
-            return [
-                text[i : i + self.split_length]
-                for i in range(0, self._chunk_length(text), self.split_length - self.split_overlap)
-            ]
+        return self.fall_back_to_char_level_chunking(text)
+
+    def fall_back_to_word_level_chunking(self, text: str) -> List[str]:
+        """
+        Fall back to word-level chunking if no separator works.
+
+        :param text: The text to be split into chunks.
+        :returns:
+            A list of text chunks.
+        """
+        return [
+            " ".join(text.split()[i : i + self.split_length])
+            for i in range(0, self._chunk_length(text), self.split_length - self.split_overlap)
+        ]
+
+    def fall_back_to_char_level_chunking(self, text: str) -> List[str]:
+        """
+        Fall back to character-level chunking if no separator works.
+
+        :param text: The text to be split into chunks.
+        :returns:
+            A list of text chunks.
+        """
+        return [
+            text[i : i + self.split_length]
+            for i in range(0, self._chunk_length(text), self.split_length - self.split_overlap)
+        ]

     def _add_overlap_info(self, curr_pos: int, new_doc: Document, new_docs: List[Document]) -> None:
         prev_doc = new_docs[-1]
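Both extracted helpers walk the text in strides of split_length - split_overlap, which _check_params keeps strictly positive (the overlap must be smaller than the chunk size). A standalone rendering of the word-level comprehension, with a hypothetical function name but the same stride logic, makes the windowing concrete:

from typing import List

def word_level_chunks(text: str, split_length: int, split_overlap: int) -> List[str]:
    # Mirrors fall_back_to_word_level_chunking: fixed-size word windows
    # advancing by split_length - split_overlap at each step.
    words = text.split()
    step = split_length - split_overlap
    return [" ".join(words[i : i + split_length]) for i in range(0, len(words), step)]

# With split_length=3 and split_overlap=1, consecutive chunks share one word:
assert word_level_chunks("one two three four five", 3, 1) == ["one two three", "three four five", "five"]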
@@ -244,7 +277,7 @@ def _add_overlap_info(self, curr_pos: int, new_doc: Document, new_docs: List[Document]) -> None:
             {
                 "doc_id": prev_doc.id,
                 "range": (
-                    self._chunk_length(prev_doc.content) - overlap_length,
+                    self._chunk_length(prev_doc.content) - overlap_length,  # type: ignore
                     self._chunk_length(prev_doc.content),  # type: ignore
                 ),
             }

test/components/preprocessors/test_recursive_splitter.py

Lines changed: 33 additions & 24 deletions
@@ -33,22 +33,22 @@ def test_init_with_negative_split_length():

 def test_apply_overlap_no_overlap():
     # Test the case where there is no overlap between chunks
-    splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=0, separators=["."])
+    splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=0, separators=["."], split_unit="char")
     chunks = ["chunk1", "chunk2", "chunk3"]
     result = splitter._apply_overlap(chunks)
     assert result == ["chunk1", "chunk2", "chunk3"]


 def test_apply_overlap_with_overlap():
     # Test the case where there is overlap between chunks
-    splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=4, separators=["."])
+    splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=4, separators=["."], split_unit="char")
     chunks = ["chunk1", "chunk2", "chunk3"]
     result = splitter._apply_overlap(chunks)
     assert result == ["chunk1", "unk1chunk2", "unk2chunk3"]


 def test_apply_overlap_with_overlap_capturing_completely_previous_chunk(caplog):
-    splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=6, separators=["."])
+    splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=6, separators=["."], split_unit="char")
     chunks = ["chunk1", "chunk2", "chunk3", "chunk4"]
     _ = splitter._apply_overlap(chunks)
     assert (
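The expected values encode the overlap arithmetic: with split_overlap=4 in char units, the last four characters of the previous chunk are prepended to the next one, so "chunk1"[-4:] == "unk1" yields "unk1chunk2". The slice logic in isolation:

prev, overlap = "chunk1", 4
assert prev[-overlap:] + "chunk2" == "unk1chunk2"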
@@ -59,7 +59,7 @@ def test_apply_overlap_with_overlap_capturing_completely_previous_chunk(caplog):

 def test_apply_overlap_single_chunk():
     # Test the case where there is only one chunk
-    splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=3, separators=["."])
+    splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=3, separators=["."], split_unit="char")
     chunks = ["chunk1"]
     result = splitter._apply_overlap(chunks)
     assert result == ["chunk1"]
@@ -74,7 +74,7 @@ def test_chunk_text_smaller_than_chunk_size():


 def test_chunk_text_by_period():
-    splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=0, separators=["."])
+    splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=0, separators=["."], split_unit="char")
     text = "This is a test. Another sentence. And one more."
     chunks = splitter._chunk_text(text)
     assert len(chunks) == 3
@@ -84,7 +84,7 @@ def test_chunk_text_by_period():


 def test_run_multiple_new_lines():
-    splitter = RecursiveDocumentSplitter(split_length=20, separators=["\n\n", "\n"])
+    splitter = RecursiveDocumentSplitter(split_length=20, separators=["\n\n", "\n"], split_unit="char")
     text = "This is a test.\n\n\nAnother test.\n\n\n\nFinal test."
     doc = Document(content=text)
     chunks = splitter.run([doc])["documents"]
@@ -110,6 +110,7 @@ def test_run_using_custom_sentence_tokenizer():
     splitter = RecursiveDocumentSplitter(
         split_length=400,
         split_overlap=0,
+        split_unit="char",
         separators=["\n\n", "\n", "sentence", " "],
         sentence_splitter_params={"language": "en", "use_split_rules": True, "keep_white_spaces": False},
     )
@@ -134,8 +135,8 @@ def test_run_using_custom_sentence_tokenizer():
     )  # noqa: E501


-def test_run_split_by_dot_count_page_breaks() -> None:
-    document_splitter = RecursiveDocumentSplitter(separators=["."], split_length=30, split_overlap=0)
+def test_run_split_by_dot_count_page_breaks_split_unit_char() -> None:
+    document_splitter = RecursiveDocumentSplitter(separators=["."], split_length=30, split_overlap=0, split_unit="char")

     text = (
         "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f"
@@ -181,8 +182,8 @@ def test_run_split_by_dot_count_page_breaks() -> None:
     assert documents[6].meta["split_idx_start"] == text.index(documents[6].content)


-def test_run_split_by_word_count_page_breaks():
-    splitter = RecursiveDocumentSplitter(split_length=18, split_overlap=0, separators=["w"])
+def test_run_split_by_word_count_page_breaks_split_unit_char():
+    splitter = RecursiveDocumentSplitter(split_length=18, split_overlap=0, separators=["w"], split_unit="char")
     text = "This is some text. \f This text is on another page. \f This is the last pag3."
     doc = Document(content=text)
     doc_chunks = splitter.run([doc])
@@ -216,7 +217,9 @@ def test_run_split_by_word_count_page_breaks():


 def test_run_split_by_page_break_count_page_breaks() -> None:
-    document_splitter = RecursiveDocumentSplitter(separators=["\f"], split_length=50, split_overlap=0)
+    document_splitter = RecursiveDocumentSplitter(
+        separators=["\f"], split_length=50, split_overlap=0, split_unit="char"
+    )

     text = (
         "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f"
@@ -247,8 +250,10 @@ def test_run_split_by_page_break_count_page_breaks() -> None:
     assert chunks_docs[3].meta["split_idx_start"] == text.index(chunks_docs[3].content)


-def test_run_split_by_new_line_count_page_breaks() -> None:
-    document_splitter = RecursiveDocumentSplitter(separators=["\n"], split_length=21, split_overlap=0)
+def test_run_split_by_new_line_count_page_breaks_split_unit_char() -> None:
+    document_splitter = RecursiveDocumentSplitter(
+        separators=["\n"], split_length=21, split_overlap=0, split_unit="char"
+    )

     text = (
         "Sentence on page 1.\nAnother on page 1.\n\f"
@@ -298,8 +303,10 @@ def test_run_split_by_new_line_count_page_breaks() -> None:
     assert chunks_docs[6].meta["split_idx_start"] == text.index(chunks_docs[6].content)


-def test_run_split_by_sentence_count_page_breaks() -> None:
-    document_splitter = RecursiveDocumentSplitter(separators=["sentence"], split_length=28, split_overlap=0)
+def test_run_split_by_sentence_count_page_breaks_split_unit_char() -> None:
+    document_splitter = RecursiveDocumentSplitter(
+        separators=["sentence"], split_length=28, split_overlap=0, split_unit="char"
+    )

     text = (
         "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f"
@@ -347,7 +354,7 @@ def test_run_split_by_sentence_count_page_breaks() -> None:


 def test_run_split_document_with_overlap_character_unit():
-    splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=11, separators=[".", " "])
+    splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=11, separators=[".", " "], split_unit="char")
     text = """A simple sentence1. A bright sentence2. A clever sentence3. A joyful sentence4"""

     doc = Document(content=text)
@@ -384,18 +391,18 @@ def test_run_split_document_with_overlap_character_unit():


 def test_run_separator_exists_but_split_length_too_small_fall_back_to_character_chunking():
-    splitter = RecursiveDocumentSplitter(separators=[" "], split_length=2)
+    splitter = RecursiveDocumentSplitter(separators=[" "], split_length=2, split_unit="char")
     doc = Document(content="This is some text. This is some more text.")
     result = splitter.run(documents=[doc])
     assert len(result["documents"]) == 21
     for doc in result["documents"]:
         assert len(doc.content) == 2


-def test_run_fallback_to_character_chunking():
+def test_run_fallback_to_character_chunking_by_default_length_too_short():
     text = "abczdefzghizjkl"
     separators = ["\n\n", "\n", "z"]
-    splitter = RecursiveDocumentSplitter(split_length=2, separators=separators)
+    splitter = RecursiveDocumentSplitter(split_length=2, separators=separators, split_unit="char")
     doc = Document(content=text)
     chunks = splitter.run([doc])["documents"]
     for chunk in chunks:
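The expectation of 21 two-character documents is plain arithmetic once the splitter falls back to character-level chunking: the input is 42 characters long, and split_length=2 with no overlap divides it evenly. A quick check of that count, assuming the fallback covers the text exactly:

text = "This is some text. This is some more text."
assert len(text) == 42
assert len(text) // 2 == 21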
@@ -404,7 +411,7 @@ def test_run_fallback_to_character_chunking():

 def test_run_custom_sentence_tokenizer_document_and_overlap_char_unit():
     """Test that RecursiveDocumentSplitter works correctly with custom sentence tokenizer and overlap"""
-    splitter = RecursiveDocumentSplitter(split_length=25, split_overlap=5, separators=["sentence"])
+    splitter = RecursiveDocumentSplitter(split_length=25, split_overlap=5, separators=["sentence"], split_unit="char")
     text = "This is sentence one. This is sentence two. This is sentence three."

     doc = Document(content=text)
@@ -485,6 +492,10 @@ def test_run_split_by_word_count_page_breaks_word_unit():
     doc_chunks = splitter.run([doc])
     doc_chunks = doc_chunks["documents"]

+    for doc in doc_chunks:
+        print(doc.content)
+        print(doc.meta)
+
     assert len(doc_chunks) == 4
     assert doc_chunks[0].content == "This is some text."
     assert doc_chunks[0].meta["page_number"] == 1
@@ -546,9 +557,7 @@ def test_run_split_by_page_break_count_page_breaks_word_unit() -> None:


 def test_run_split_by_new_line_count_page_breaks_word_unit() -> None:
-    document_splitter = RecursiveDocumentSplitter(
-        separators=["\n"], split_length=21, split_overlap=0, split_unit="word"
-    )
+    document_splitter = RecursiveDocumentSplitter(separators=["\n"], split_length=4, split_overlap=0, split_unit="word")

     text = (
         "Sentence on page 1.\nAnother on page 1.\n\f"
@@ -600,7 +609,7 @@ def test_run_split_by_new_line_count_page_breaks_word_unit() -> None:

 def test_run_split_by_sentence_count_page_breaks_word_unit() -> None:
     document_splitter = RecursiveDocumentSplitter(
-        separators=["sentence"], split_length=28, split_overlap=0, split_unit="word"
+        separators=["sentence"], split_length=7, split_overlap=0, split_unit="word"
     )

     text = (
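The reduced split_length values in the word-unit tests track the unit change: limits tuned as character counts (21, 28) become word counts (4, 7). A small word-unit run, sketched under the assumption of this commit's API:

from haystack import Document
from haystack.components.preprocessors import RecursiveDocumentSplitter

splitter = RecursiveDocumentSplitter(separators=["\n"], split_length=4, split_overlap=0, split_unit="word")
doc = Document(content="Sentence on page 1.\nAnother on page 1.")
for chunk in splitter.run([doc])["documents"]:
    # Each chunk is capped at four whitespace-delimited words.
    assert len(chunk.content.split()) <= 4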
