Skip to content

Commit

Permalink
Merge pull request #132 from StampyAI/refactor-text-splitter
Browse files Browse the repository at this point in the history
text splitter minor refactor
  • Loading branch information
henri123lemoine authored Aug 16, 2023
2 parents fecc5b1 + 4663669 commit 09e19eb
Showing 1 changed file with 40 additions and 40 deletions.
80 changes: 40 additions & 40 deletions align_data/pinecone/text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@
from langchain.text_splitter import TextSplitter
from nltk.tokenize import sent_tokenize

# TODO: Work around sent_tokenize mis-splitting sentences, e.g. after an abbreviation inside parentheses.
# Example: 'The units could be anything (characters, words, sentences, etc.), depending on how you want to chunk your text.'
# is split into ['The units could be anything (characters, words, sentences, etc.', '), depending on how you want to chunk your text.']

StrToIntFunction = Callable[[str], int]
StrIntBoolToStrFunction = Callable[[str, int, bool], str]

def default_truncate_function(string: str, length: int, from_end: bool = False) -> str:
    """Truncate *string* to at most *length* characters.

    @param string: The text to truncate.
    @param length: The maximum number of characters to keep.
    @param from_end: If True, keep the last *length* characters; otherwise keep the first.
    @return: The truncated text; an empty string when *length* is not positive.
    """
    if length <= 0:
        # string[-0:] would return the whole string, so handle this case explicitly.
        return ""
    return string[-length:] if from_end else string[:length]
Expand All @@ -14,22 +23,22 @@ class ParagraphSentenceUnitTextSplitter(TextSplitter):
@param min_chunk_size: The minimum number of units in a chunk.
@param max_chunk_size: The maximum number of units in a chunk.
@param length_function: A function that returns the length of a string in units.
@param length_function: A function that returns the length of a string in units. Defaults to len().
@param truncate_function: A function that truncates a string to a given unit length.
"""

DEFAULT_MIN_CHUNK_SIZE = 900
DEFAULT_MAX_CHUNK_SIZE = 1100
DEFAULT_LENGTH_FUNCTION = lambda string: len(string)
DEFAULT_TRUNCATE_FUNCTION = default_truncate_function

DEFAULT_MIN_CHUNK_SIZE: int = 900
DEFAULT_MAX_CHUNK_SIZE: int = 1100
DEFAULT_LENGTH_FUNCTION: StrToIntFunction = len
DEFAULT_TRUNCATE_FUNCTION: StrIntBoolToStrFunction = default_truncate_function

def __init__(
self,
min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE,
max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE,
length_function: Callable[[str], int] = DEFAULT_LENGTH_FUNCTION,
truncate_function: Callable[[str, int], str] = DEFAULT_TRUNCATE_FUNCTION,
**kwargs: Any,
length_function: StrToIntFunction = DEFAULT_LENGTH_FUNCTION,
truncate_function: StrIntBoolToStrFunction = DEFAULT_TRUNCATE_FUNCTION,
**kwargs: Any
):
super().__init__(**kwargs)
self.min_chunk_size = min_chunk_size
Expand All @@ -39,8 +48,9 @@ def __init__(
self._truncate_function = truncate_function

def split_text(self, text: str) -> List[str]:
blocks = []
current_block = ""
"""Split text into chunks of length between min_chunk_size and max_chunk_size."""
blocks: List[str] = []
current_block: str = ""

paragraphs = text.split("\n\n")
for paragraph in paragraphs:
Expand All @@ -56,10 +66,9 @@ def split_text(self, text: str) -> List[str]:
continue

blocks = self._handle_remaining_text(current_block, blocks)

return [block.strip() for block in blocks]

def _handle_large_paragraph(self, current_block, blocks, paragraph):
def _handle_large_paragraph(self, current_block: str, blocks: List[str], paragraph: str) -> str:
# Undo adding the whole paragraph
offset = len(paragraph) + 2 # +2 accounts for "\n\n"
current_block = current_block[:-offset]
Expand All @@ -75,44 +84,35 @@ def _handle_large_paragraph(self, current_block, blocks, paragraph):
blocks.append(current_block)
current_block = ""
else:
current_block = self._truncate_large_block(
current_block, blocks, sentence
)

current_block = self._truncate_large_block(current_block, blocks)
return current_block

def _truncate_large_block(self, current_block: str, blocks: List[str]) -> str:
    """Cut max-size chunks off the front of current_block until the rest fits.

    @param current_block: Text whose length may exceed max_chunk_size.
    @param blocks: Finished chunks; each truncated chunk is appended to this list in place.
    @return: The leftover text, whose length is at most max_chunk_size units.
    @raise ValueError: If the truncate function returns an empty chunk, since the loop could never finish.
    """
    while self._length_function(current_block) > self.max_chunk_size:
        # Truncate current_block to max size, set remaining text as current_block
        truncated_block = self._truncate_function(current_block, self.max_chunk_size)
        if not truncated_block:
            # No progress is possible (e.g. max_chunk_size <= 0); fail instead of looping forever.
            raise ValueError(
                "truncate_function returned an empty chunk; cannot split text "
                f"with max_chunk_size={self.max_chunk_size}"
            )
        blocks.append(truncated_block)

        # Drop the emitted prefix and the whitespace that followed it.
        current_block = current_block[len(truncated_block):].lstrip()

    return current_block

def _handle_remaining_text(self, last_block: str, blocks: List[str]) -> List[str]:
    """Append the leftover text to blocks, padding it from the previous block if too short.

    @param last_block: Text left over after all full chunks were emitted; may be empty.
    @param blocks: Chunks produced so far; modified in place when leftover text is appended.
    @return: [last_block] when no chunks exist yet, otherwise blocks (with the leftover appended if non-empty).
    """
    if not blocks:  # no blocks were added
        return [last_block]

    if last_block:  # any leftover text
        missing_units = self.min_chunk_size - self._length_function(last_block)
        if missing_units > 0:
            # Add text from the end of the previous block if last_block is too short.
            # Text and length are passed positionally so custom truncate functions
            # need not use particular parameter names.
            part_prev_block = self._truncate_function(
                blocks[-1], missing_units, from_end=True
            )
            last_block = part_prev_block + last_block

        blocks.append(last_block)

    return blocks

0 comments on commit 09e19eb

Please sign in to comment.