Skip to content

Commit

Permalink
fix: handle 'null' value in chunking 'chunk_column'
Browse files Browse the repository at this point in the history
The chunking configuration takes a 'chunk_column' parameter, which
determines which column of the source row is chunked.

This commit treats a null entry as being equivalent to the empty string,
and no embeddings are generated.
  • Loading branch information
JamesGuthrie committed Jan 7, 2025
1 parent 89039b2 commit ac8f340
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
6 changes: 4 additions & 2 deletions projects/pgai/pgai/vectorizer/chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def into_chunks(self, item: dict[str, Any]) -> list[str]:
Returns:
list[str]: A list of chunked strings.
"""
return self._chunker.split_text(item[self.chunk_column])
text = item[self.chunk_column] or ""
return self._chunker.split_text(text)


class LangChainRecursiveCharacterTextSplitter(BaseModel, Chunker):
Expand Down Expand Up @@ -126,4 +127,5 @@ def into_chunks(self, item: dict[str, Any]) -> list[str]:
Returns:
list[str]: A list of chunked strings.
"""
return self._chunker.split_text(item[self.chunk_column])
text = item[self.chunk_column] or ""
return self._chunker.split_text(text)
53 changes: 53 additions & 0 deletions projects/pgai/tests/vectorizer/test_vectorizer_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,3 +956,56 @@ def test_recursive_character_splitting(
assert sequences == list(
range(len(sequences))
), "Chunk sequences should be sequential starting from 0"


@pytest.mark.parametrize(
"test_params",
[
(
1,
1,
1,
"chunking_character_text_splitter('content')",
"formatting_python_template('$chunk')",
),
(
1,
1,
1,
"chunking_recursive_character_text_splitter('content')",
"formatting_python_template('$chunk')",
),
],
)
def test_vectorization_successful_with_null_contents(
cli_db: tuple[PostgresContainer, Connection],
cli_db_url: str,
configured_ollama_vectorizer_id: int,
test_params: tuple[int, int, int, str, str], # noqa: ARG001
):
_, conn = cli_db

with conn.cursor(row_factory=dict_row) as cur:
cur.execute("ALTER TABLE blog ALTER COLUMN content DROP NOT NULL;")
cur.execute("UPDATE blog SET content = null;")

result = CliRunner().invoke(
vectorizer_worker,
[
"--db-url",
cli_db_url,
"--once",
"--vectorizer-id",
str(configured_ollama_vectorizer_id),
],
catch_exceptions=False,
)

assert not result.exception
assert result.exit_code == 0

_, conn = cli_db

with conn.cursor(row_factory=dict_row) as cur:
cur.execute("SELECT count(*) as count FROM blog_embedding_store;")
assert cur.fetchone()["count"] == 0 # type: ignore

0 comments on commit ac8f340

Please sign in to comment.