Merge branch 'development' into feat-cohere
Udayk02 authored Jan 5, 2025

2 parents d44ffb3 + b2aa570 commit 2c7aef1
Showing 5 changed files with 71 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-test-push.yml
@@ -8,7 +8,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]

steps:
- uses: actions/checkout@v4
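
The matrix above drops Python 3.8, so the workflow now tests 3.9 through 3.13 only. A hypothetical runtime guard mirroring that floor (not part of the repository) could look like:

```python
# Hypothetical guard mirroring the CI matrix above: 3.8 is no longer tested,
# so downstream code can assume Python 3.9+.
import sys

if sys.version_info < (3, 9):
    raise RuntimeError("This code path assumes Python 3.9+, matching the CI matrix.")
```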
9 changes: 5 additions & 4 deletions pyproject.toml
@@ -38,13 +38,14 @@ dependencies = [
[project.urls]
Homepage = "https://github.com/bhavnicksm/chonkie"
Documentation = "https://docs.chonkie.ai"

[project.optional-dependencies]
model2vec = ["model2vec>=0.1.0", "numpy>=1.23.0, <2.2"]
model2vec = ["model2vec>=0.3.0", "numpy>=1.23.0, <2.2"]
st = ["sentence-transformers>=3.0.0", "numpy>=1.23.0, <2.2"]
openai = ["openai>=1.0.0", "numpy>=1.23.0, <2.2"]
semantic = ["model2vec>=0.1.0", "numpy>=1.23.0, <2.2"]
semantic = ["model2vec>=0.3.0", "numpy>=1.23.0, <2.2"]
cohere = ["cohere>=5.13.0", "numpy>=1.23.0, <2.2"]
all = ["sentence-transformers>=3.0.0", "numpy>=1.23.0, <2.2", "openai>=1.0.0", "model2vec>=0.1.0", "cohere>=5.13.0"]
all = ["sentence-transformers>=3.0.0", "numpy>=1.23.0, <2.2", "openai>=1.0.0", "model2vec>=0.3.0", "cohere>=5.13.0"]
dev = [
"pytest>=6.2.0",
"pytest-cov>=4.0.0",
@@ -64,4 +65,4 @@ packages = ["chonkie",
"chonkie.refinery"]

[tool.ruff]
select = ["F", "I", "D", "DOC"]
select = ["F", "I", "D", "DOC"]
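
The extras above raise the model2vec floor from 0.1.0 to 0.3.0 and add the cohere extra at >=5.13.0. A small sketch for checking an installed environment against those floors; the package names come straight from the extras, and nothing here is chonkie API:

```python
# Sketch: verify installed optional dependencies against the floors declared
# in the extras above. Purely illustrative; not part of chonkie.
from importlib.metadata import PackageNotFoundError, version

floors = {"model2vec": "0.3.0", "cohere": "5.13.0", "numpy": "1.23.0"}
for package, floor in floors.items():
    try:
        print(f"{package} {version(package)} (requires >= {floor})")
    except PackageNotFoundError:
        print(f"{package} not installed; install the matching extra, e.g. chonkie[semantic]")
```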
6 changes: 3 additions & 3 deletions src/chonkie/chunker/base.py
@@ -183,11 +183,11 @@ def _decode(self, tokens) -> str:
def _decode_batch(self, token_lists: List[List[int]]) -> List[str]:
"""Decode a batch of token lists using the backend tokenizer."""
if self._tokenizer_backend == "transformers":
return [self.tokenizer.decode(tokens) for tokens in token_lists]
return self.tokenizer.batch_decode(token_lists, skip_special_tokens=True)
elif self._tokenizer_backend == "tokenizers":
return [self.tokenizer.decode(tokens) for tokens in token_lists]
return self.tokenizer.decode_batch(token_lists)
elif self._tokenizer_backend == "tiktoken":
return [self.tokenizer.decode(tokens) for tokens in token_lists]
return self.tokenizer.decode_batch(token_lists)
elif self._tokenizer_backend == "callable":
raise NotImplementedError(
"Callable tokenizer backend does not support batch decoding."
87 changes: 45 additions & 42 deletions src/chonkie/chunker/token.py
@@ -52,17 +52,27 @@ def __init__(
def _create_chunks(
self,
chunk_texts: List[str],
token_counts: List[int],
decoded_text: str,
token_groups: List[List[int]],
token_counts: List[int]
) -> List[Chunk]:
"""Create chunks from a list of texts."""
# package everything as Chunk objects and send out the result
# Find the overlap lengths for index calculation
if self.chunk_overlap > 0:
# we get the overlap texts, which give us the start_index for the next chunk
# if the token group is smaller than the overlap, we just use the whole token group
overlap_texts = self._decode_batch([token_group[-self.chunk_overlap:]
if (len(token_group) > self.chunk_overlap)
else token_group
for token_group in token_groups])
overlap_lengths = [len(overlap_text) for overlap_text in overlap_texts]
else:
overlap_lengths = [0] * len(token_groups)

# Create the chunks
chunks = []
current_index = 0
for chunk_text, token_count in zip(chunk_texts, token_counts):
start_index = decoded_text.find(
chunk_text, current_index
) # Find needs to be run every single time because of unknown overlap length
for chunk_text, overlap_length, token_count in zip(chunk_texts, overlap_lengths, token_counts):
start_index = current_index
end_index = start_index + len(chunk_text)
chunks.append(
Chunk(
@@ -72,7 +82,8 @@ def _create_chunks(
token_count=token_count,
)
)
current_index = end_index
current_index = end_index - overlap_length

return chunks
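
The rewritten _create_chunks above no longer searches the decoded text with find(); it advances current_index by the chunk's character length minus the character length of its decoded overlap tokens. A standalone worked example of that arithmetic, using made-up strings rather than chonkie code:

```python
# Worked example of the index arithmetic: the next chunk starts at the
# previous chunk's end minus the character length of its decoded overlap.
chunk_texts = ["Hello world, how", " how are you today", " today my friend"]
overlap_texts = [" how", " today", ""]  # decoded last `chunk_overlap` tokens per group

current_index = 0
spans = []
for chunk_text, overlap_text in zip(chunk_texts, overlap_texts):
    start_index = current_index
    end_index = start_index + len(chunk_text)
    spans.append((start_index, end_index))
    current_index = end_index - len(overlap_text)

print(spans)  # [(0, 16), (12, 30), (24, 40)] against "Hello world, how are you today my friend"
```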

def chunk(self, text: str) -> List[Chunk]:
@@ -91,40 +102,24 @@ def chunk(self, text: str) -> List[Chunk]:
# Encode full text
text_tokens = self._encode(text)

# We decode the text because the tokenizer might result in a different output than text
decoded_text = self._decode(text_tokens)

# Calculate chunk positions
token_groups = [
text_tokens[
start_index : min(start_index + self.chunk_size, len(text_tokens))
]
for start_index in range(
0, len(text_tokens), self.chunk_size - self.chunk_overlap
)
]
token_counts = [
len(toks) for toks in token_groups
] # get the token counts; it's prolly chunk_size, but len doesn't take too long
token_groups = [text_tokens[start_index : min(start_index + self.chunk_size, len(text_tokens))]
for start_index in range(0, len(text_tokens), self.chunk_size - self.chunk_overlap)]
token_counts = [len(toks) for toks in token_groups]

chunk_texts = self._decode_batch(
token_groups
) # decrease the time by decoding in one go (?)
# decode the token groups into the chunk texts
chunk_texts = self._decode_batch(token_groups)

chunks = self._create_chunks(chunk_texts, token_counts, decoded_text)
# Create the chunks from the token groups and token counts
chunks = self._create_chunks(chunk_texts, token_groups, token_counts)

return chunks

def _chunk_generator(
self, tokens: List[int]
) -> Generator[Tuple[List[int], int, int], None, None]:
def _token_group_generator(self, tokens: List[int]) -> Generator[List[int], None, None]:
"""Generate chunks from a list of tokens."""
stride = self.chunk_size - self.chunk_overlap
for start in range(0, len(tokens), stride):
for start in range(0, len(tokens), self.chunk_size - self.chunk_overlap):
end = min(start + self.chunk_size, len(tokens))
yield tokens[start:end], start, end
if end == len(tokens):
break
yield tokens[start:end]
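
The renamed _token_group_generator yields only the token slices; start and end offsets are now derived later from the decoded texts. A toy illustration of the windowing (assuming the loop still stops once the final token has been covered):

```python
# Toy illustration: windows of chunk_size tokens that advance by
# chunk_size - chunk_overlap, stopping once the last token is covered.
tokens = list(range(10))          # stand-in token ids 0..9
chunk_size, chunk_overlap = 4, 1

groups = []
for start in range(0, len(tokens), chunk_size - chunk_overlap):
    end = min(start + chunk_size, len(tokens))
    groups.append(tokens[start:end])
    if end == len(tokens):
        break

print(groups)  # [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]
```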

def _process_batch(self,
chunks: List[Tuple[List[int], int, int]],
@@ -148,22 +143,28 @@ def _process_batch(self,

def _process_text_batch(self, texts: List[str]) -> List[List[Chunk]]:
"""Process a batch of texts."""
# encode the texts into tokens in a batch
tokens_list = self._encode_batch(texts)
decoded_texts = self._decode_batch(tokens_list)
result = []

for tokens, text in zip(tokens_list, decoded_texts):
for tokens in tokens_list:
if not tokens:
result.append([])
continue

chunks = []
chunk_batch = []
# get the token groups
token_groups = []
for token_group in self._token_group_generator(tokens):
token_groups.append(token_group)

# get the token counts
token_counts = [len(token_group) for token_group in token_groups]

for chunk_data in self._chunk_generator(tokens):
chunk_batch.append(chunk_data)
# decode the token groups into the chunk texts
chunk_texts = self._decode_batch(token_groups)

chunks.extend(self._process_batch(chunk_batch, text))
# create the chunks from the token groups and token counts
chunks = self._create_chunks(chunk_texts, token_groups, token_counts)
result.append(chunks)

return result
@@ -181,6 +182,7 @@ def chunk_batch(
List of lists of Chunk objects containing the chunked text and metadata
"""
# if batch_size is not None, we process the texts in mini-batches to avoid memory issues
if batch_size is not None:
chunks = []
for i in range(0, len(texts), batch_size):
@@ -193,6 +195,7 @@ def chunk_batch(
def __repr__(self) -> str:
"""Return a string representation of the TokenChunker."""
return (
f"TokenChunker(chunk_size={self.chunk_size}, "
f"TokenChunker(tokenizer={self.tokenizer}, "
f"chunk_size={self.chunk_size}, "
f"chunk_overlap={self.chunk_overlap})"
)
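
Putting the token.py changes together, a hedged usage sketch (assumes chonkie and tiktoken are installed, that TokenChunker is importable from the chonkie package, and that chunk_batch accepts the batch_size argument shown in the diff):

```python
# Usage sketch only; exact signatures may differ from the released package.
import tiktoken
from chonkie import TokenChunker

tokenizer = tiktoken.get_encoding("gpt2")
chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
print(repr(chunker))  # now includes the tokenizer, e.g. TokenChunker(tokenizer=<Encoding 'gpt2'>, ...)

texts = ["Some long document ..."] * 100
# Passing batch_size processes the texts in mini-batches to bound peak memory.
batches = chunker.chunk_batch(texts, batch_size=10)
print(len(batches), len(batches[0]))
```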
34 changes: 17 additions & 17 deletions tests/chunker/test_token_chunker.py
@@ -152,9 +152,9 @@ def test_token_chunker_initialization_tik(tiktokenizer):
assert chunker.chunk_overlap == 128


def test_token_chunker_chunking(tokenizer, sample_text):
def test_token_chunker_chunking(tiktokenizer, sample_text):
"""Test that the TokenChunker can chunk a sample text into tokens."""
chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
chunks = chunker.chunk(sample_text)

assert len(chunks) > 0
@@ -196,9 +196,9 @@ def test_token_chunker_chunking_tik(tiktokenizer, sample_text):
assert all([chunk.end_index is not None for chunk in chunks])


def test_token_chunker_empty_text(tokenizer):
def test_token_chunker_empty_text(tiktokenizer):
"""Test that the TokenChunker can handle empty text input."""
chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
chunks = chunker.chunk("")

assert len(chunks) == 0
@@ -246,9 +246,9 @@ def test_token_chunker_single_chunk_text(tokenizer):
assert chunks[0].text == "Hello, how are you?"


def test_token_chunker_batch_chunking(tokenizer, sample_batch):
def test_token_chunker_batch_chunking(tiktokenizer, sample_batch):
"""Test that the TokenChunker can chunk a batch of texts into tokens."""
chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
chunks = chunker.chunk_batch(sample_batch)

assert len(chunks) > 0
@@ -267,16 +267,16 @@ def test_token_chunker_batch_chunking(tokenizer, sample_batch):
)


def test_token_chunker_repr(tokenizer):
def test_token_chunker_repr(tiktokenizer):
"""Test that the TokenChunker has a string representation."""
chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)

assert repr(chunker) == "TokenChunker(chunk_size=512, chunk_overlap=128)"
assert repr(chunker) == "TokenChunker(tokenizer=<Encoding 'gpt2'>, chunk_size=512, chunk_overlap=128)"


def test_token_chunker_call(tokenizer, sample_text):
def test_token_chunker_call(tiktokenizer, sample_text):
"""Test that the TokenChunker can be called directly."""
chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
chunks = chunker(sample_text)

assert len(chunks) > 0
@@ -305,7 +305,7 @@ def verify_chunk_indices(chunks: List[Chunk], original_text: str):
)


def test_token_chunker_indices(sample_text):
def test_token_chunker_indices(tiktokenizer, sample_text):
"""Test that TokenChunker's indices correctly map to original text."""
tokenizer = Tokenizer.from_pretrained("gpt2")
chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
@@ -321,19 +321,19 @@ def test_token_chunker_indices_complex_md(sample_complex_markdown_text):
verify_chunk_indices(chunks, sample_complex_markdown_text)


def test_token_chunker_token_counts(tokenizer, sample_text):
def test_token_chunker_token_counts(tiktokenizer, sample_text):
"""Test that the TokenChunker correctly calculates token counts."""
chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
chunks = chunker.chunk(sample_text)
assert all([chunk.token_count > 0 for chunk in chunks]), "All chunks must have a positive token count"
assert all([chunk.token_count <= 512 for chunk in chunks]), "All chunks must have a token count less than or equal to 512"

token_counts = [len(tokenizer.encode(chunk.text)) for chunk in chunks]
token_counts = [len(tiktokenizer.encode(chunk.text)) for chunk in chunks]
assert all([chunk.token_count == token_count for chunk, token_count in zip(chunks, token_counts)]), "All chunks must have a token count equal to the length of the encoded text"

def test_token_chunker_indices_batch(tokenizer, sample_text):
def test_token_chunker_indices_batch(tiktokenizer, sample_text):
"""Test that TokenChunker's indices correctly map to original text."""
chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
chunks = chunker.chunk_batch([sample_text]*10)[-1]
verify_chunk_indices(chunks, sample_text)
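
The tests above now take a tiktokenizer fixture instead of the Hugging Face tokenizer one. Judging by the expected repr (<Encoding 'gpt2'>), a plausible sketch of that fixture; the real one lives in the test suite's conftest and may differ:

```python
# Plausible reconstruction of the fixture the updated tests rely on.
import pytest
import tiktoken


@pytest.fixture
def tiktokenizer():
    """Return the GPT-2 tiktoken encoding (repr: <Encoding 'gpt2'>)."""
    return tiktoken.get_encoding("gpt2")
```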
