Skip to content

Commit

Permalink
feat: chunk by separator, add enable_merge parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Aries-ckt committed Jan 3, 2024
1 parent f95ce78 commit 6963584
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
5 changes: 5 additions & 0 deletions dbgpt/rag/chunk_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class ChunkParameters(BaseModel):
default="\n",
description="chunk separator",
)
enable_merge: bool = Field(
default=None,
description="enable chunk merge by chunk_size.",
)


class ChunkManager:
Expand Down Expand Up @@ -134,4 +138,5 @@ def _select_text_splitter(
chunk_size=self._chunk_parameters.chunk_size,
chunk_overlap=self._chunk_parameters.chunk_overlap,
separator=self._chunk_parameters.separator,
enable_merge=self._chunk_parameters.enable_merge,
)
10 changes: 9 additions & 1 deletion dbgpt/rag/knowledge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,14 @@ class ChunkStrategy(Enum):
)
CHUNK_BY_SEPARATOR = (
SeparatorTextSplitter,
[{"param_name": "separator", "param_type": "string", "default_value": "\n"}],
[
{"param_name": "separator", "param_type": "string", "default_value": "\n"},
{
"param_name": "enable_merge",
"param_type": "boolean",
"default_value": False,
},
],
"separator",
"split document by separator",
)
Expand All @@ -80,6 +87,7 @@ def __init__(self, splitter_class, parameters, alias, description):
self.description = description

def match(self, *args, **kwargs):
    """Instantiate the class stored first in this member's value.

    Keyword arguments whose value is ``None`` are omitted from the call
    entirely rather than being forwarded as ``None``; positional arguments
    are passed through unchanged.

    Returns:
        Whatever calling ``self.value[0]`` with the remaining arguments
        returns.
    """
    supplied = {}
    for name, value in kwargs.items():
        # Only ``None`` is dropped; other falsy values (False, 0, "") are kept.
        if value is not None:
            supplied[name] = value
    splitter_class = self.value[0]
    return splitter_class(*args, **supplied)


Expand Down
5 changes: 4 additions & 1 deletion dbgpt/rag/text_splitter/text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,7 @@ class SeparatorTextSplitter(CharacterTextSplitter):

def __init__(self, separator: str = "\n", filters: list = [], **kwargs: Any):
    """Create a new TextSplitter.

    Args:
        separator: Stored on the instance as ``_separator``.
        filters: Stored on the instance as ``_filter``.
            NOTE(review): the default is a mutable list shared between
            instances; it must not be mutated in place.
        **kwargs: Forwarded to the parent constructor once the optional
            ``enable_merge`` entry has been removed. A missing or ``None``
            ``enable_merge`` is treated as ``False``.
    """
    # ``enable_merge`` is optional: callers may leave it out (for example
    # when None-valued keyword arguments are filtered out before the
    # splitter is built), so pop with a default instead of raising KeyError.
    self._merge = kwargs.pop("enable_merge", None) or False
    super().__init__(**kwargs)
    self._separator = separator
    self._filter = filters
Expand All @@ -696,7 +697,9 @@ def split_text(
splits = text.split(separator)
else:
splits = list(text)
return self._merge_splits(splits, separator, chunk_overlap=0, **kwargs)
if self._merge:
return self._merge_splits(splits, separator, chunk_overlap=0, **kwargs)
return list(filter(None, text.split(separator)))


class PageTextSplitter(TextSplitter):
Expand Down

0 comments on commit 6963584

Please sign in to comment.