feat: super minimal implementation
Askir committed Jan 30, 2025
1 parent 9dffab1 commit 9aa5c7d
Showing 11 changed files with 153 additions and 26 deletions.
34 changes: 18 additions & 16 deletions projects/extension/sql/idempotent/005-chunking.sql
@@ -2,7 +2,7 @@
 -------------------------------------------------------------------------------
 -- chunking_character_text_splitter
 create or replace function ai.chunking_character_text_splitter
-( chunk_column pg_catalog.name
+( chunk_column pg_catalog.name default ''
 , chunk_size pg_catalog.int4 default 800
 , chunk_overlap pg_catalog.int4 default 400
 , separator pg_catalog.text default E'\n\n'
@@ -26,7 +26,7 @@ set search_path to pg_catalog, pg_temp
 -------------------------------------------------------------------------------
 -- chunking_recursive_character_text_splitter
 create or replace function ai.chunking_recursive_character_text_splitter
-( chunk_column pg_catalog.name
+( chunk_column pg_catalog.name default ''
 , chunk_size pg_catalog.int4 default 800
 , chunk_overlap pg_catalog.int4 default 400
 , separators pg_catalog.text[] default array[E'\n\n', E'\n', '.', '?', '!', ' ', '']
@@ -76,20 +76,22 @@ begin
     end if;

     _chunk_column = config operator(pg_catalog.->>) 'chunk_column';

-    select count(*) operator(pg_catalog.>) 0 into strict _found
-    from pg_catalog.pg_class k
-    inner join pg_catalog.pg_namespace n on (k.relnamespace operator(pg_catalog.=) n.oid)
-    inner join pg_catalog.pg_attribute a on (k.oid operator(pg_catalog.=) a.attrelid)
-    inner join pg_catalog.pg_type y on (a.atttypid operator(pg_catalog.=) y.oid)
-    where n.nspname operator(pg_catalog.=) source_schema
-    and k.relname operator(pg_catalog.=) source_table
-    and a.attnum operator(pg_catalog.>) 0
-    and a.attname operator(pg_catalog.=) _chunk_column
-    and y.typname in ('text', 'varchar', 'char', 'bpchar')
-    ;
-    if not _found then
-        raise exception 'chunk column in config does not exist in the table: %', _chunk_column;
-    end if;
+    if _chunk_column operator(pg_catalog.!=) '' then
+        select count(*) operator(pg_catalog.>) 0 into strict _found
+        from pg_catalog.pg_class k
+        inner join pg_catalog.pg_namespace n on (k.relnamespace operator(pg_catalog.=) n.oid)
+        inner join pg_catalog.pg_attribute a on (k.oid operator(pg_catalog.=) a.attrelid)
+        inner join pg_catalog.pg_type y on (a.atttypid operator(pg_catalog.=) y.oid)
+        where n.nspname operator(pg_catalog.=) source_schema
+        and k.relname operator(pg_catalog.=) source_table
+        and a.attnum operator(pg_catalog.>) 0
+        and a.attname operator(pg_catalog.=) _chunk_column
+        and y.typname in ('text', 'varchar', 'char', 'bpchar')
+        ;
+        if not _found then
+            raise exception 'chunk column in config does not exist in the table: %', _chunk_column;
+        end if;
+    end if;
 end
 $func$ language plpgsql stable security invoker
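One consequence of the new default: the column check above now runs only when a chunk column is actually named, so a chunking config can be created without one (useful when a loader later replaces the column value with parsed text). A minimal, hypothetical call:

select ai.chunking_character_text_splitter();  -- chunk_column now defaults to ''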
4 changes: 4 additions & 0 deletions projects/extension/sql/idempotent/013-vectorizer-api.sql
@@ -23,6 +23,8 @@ create or replace function ai.create_vectorizer
 , formatting pg_catalog.jsonb default ai.formatting_python_template()
 , scheduling pg_catalog.jsonb default ai.scheduling_default()
 , processing pg_catalog.jsonb default ai.processing_default()
+, loader pg_catalog.jsonb default null
+, parser pg_catalog.jsonb default null
 , target_schema pg_catalog.name default null
 , target_table pg_catalog.name default null
 , view_schema pg_catalog.name default null
@@ -259,6 +261,8 @@ begin
 , 'formatting', formatting
 , 'scheduling', scheduling
 , 'processing', processing
+, 'loader', loader
+, 'parser', parser
 )
 );

14 changes: 14 additions & 0 deletions projects/extension/sql/idempotent/018-loader.sql
@@ -0,0 +1,14 @@
+-------------------------------------------------------------------------------
+-- loader_file_loader
+create or replace function ai.loader_file_loader
+( url_column pg_catalog.text
+) returns pg_catalog.jsonb
+as $func$
+select json_object
+( 'implementation': 'file_loader'
+, 'config_type': 'loader'
+, 'url_column': url_column
+)
+$func$ language sql immutable security invoker
+set search_path to pg_catalog, pg_temp
+;
11 changes: 11 additions & 0 deletions projects/extension/sql/idempotent/019-parser.sql
@@ -0,0 +1,11 @@
+-------------------------------------------------------------------------------
+-- parser_pymupdf
+create or replace function ai.parser_pymupdf() returns pg_catalog.jsonb
+as $func$
+select json_object
+( 'implementation': 'pymupdf'
+, 'config_type': 'parser'
+)
+$func$ language sql immutable security invoker
+set search_path to pg_catalog, pg_temp
+;
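Taken together, the two new config functions plug into ai.create_vectorizer through the new loader and parser parameters. A hypothetical end-to-end sketch (the documents table, its url column, and the embedding settings are illustrative assumptions, not part of this commit):

select ai.create_vectorizer
( 'public.documents'::regclass
, embedding => ai.embedding_openai('text-embedding-3-small', 768)
, chunking => ai.chunking_character_text_splitter('url')
, loader => ai.loader_file_loader(url_column => 'url')
, parser => ai.parser_pymupdf()
);

The worker then fetches each row's url, parses the file to markdown, and chunks the markdown in place of the original column value.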
8 changes: 8 additions & 0 deletions projects/pgai/pgai/vectorizer/loader.py
@@ -0,0 +1,8 @@
+from typing import Literal
+
+from pydantic import BaseModel
+
+
+class pgaiFileLoader(BaseModel):
+    implementation: Literal["file_loader"]
+    url_column: str
7 changes: 7 additions & 0 deletions projects/pgai/pgai/vectorizer/parser.py
@@ -0,0 +1,7 @@
+from typing import Literal
+
+from pydantic import BaseModel
+
+
+class PyMuPDFParser(BaseModel):
+    implementation: Literal["pymupdf"]
26 changes: 26 additions & 0 deletions projects/pgai/pgai/vectorizer/vectorizer.py
@@ -1,4 +1,5 @@
 import asyncio
+import io
 import os
 import threading
 import time
@@ -7,10 +8,14 @@
 from itertools import repeat
 from typing import Any, TypeAlias

+import fitz
 import numpy as np
 import psycopg
+import pymupdf4llm
+import smart_open
 import structlog
 from ddtrace import tracer
+from filetype import filetype
 from pgvector.psycopg import register_vector_async  # type: ignore
 from psycopg import AsyncConnection, sql
 from psycopg.rows import dict_row
@@ -25,6 +30,8 @@
 from .embedders import LiteLLM, Ollama, OpenAI, VoyageAI
 from .embeddings import ChunkEmbeddingError
 from .formatting import ChunkValue, PythonTemplate
+from .loader import pgaiFileLoader
+from .parser import PyMuPDFParser
 from .processing import ProcessingDefault

 logger = structlog.get_logger()
@@ -81,6 +88,8 @@ class Config:
         LangChainCharacterTextSplitter | LangChainRecursiveCharacterTextSplitter
     ) = Field(..., discriminator="implementation")
     formatting: PythonTemplate | ChunkValue = Field(..., discriminator="implementation")
+    parser: PyMuPDFParser | None = None
+    loader: pgaiFileLoader | None = None


 @dataclass
@@ -708,6 +717,23 @@ async def _generate_embeddings(
         documents: list[str] = []
         for item in items:
             pk = self._get_item_pk_values(item)
+            if self.vectorizer.config.loader is not None:
+                loader = self.vectorizer.config.loader
+                url = item[loader.url_column]
+                with smart_open.open(url, "rb") as file:
+                    content = file.read()
+                file_like = io.BytesIO(content)
+                kind = filetype.guess(file_like)
+
+                if kind is None:
+                    raise ValueError("Could not determine file type")
+
+                file_like.seek(0)  # Reset buffer position
+                with fitz.open(stream=file_like, filetype="pdf") as pdf_document:
+                    # Convert to markdown using pymupdf4llm
+                    md_text = pymupdf4llm.to_markdown(pdf_document)
+                item.update({loader.url_column: md_text})
+
             chunks = self.vectorizer.config.chunking.into_chunks(item)
             for chunk_id, chunk in enumerate(chunks, 0):
                 formatted = self.vectorizer.config.formatting.format(chunk, item)
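Since Config now carries optional loader and parser fields, the jsonb emitted by the new SQL config functions maps directly onto the new Pydantic models. A small illustrative sketch of that correspondence (constructed by hand here, not code from the commit):

from pgai.vectorizer.loader import pgaiFileLoader
from pgai.vectorizer.parser import PyMuPDFParser

# mirrors ai.loader_file_loader(url_column => 'url')
loader = pgaiFileLoader(implementation="file_loader", url_column="url")

# mirrors ai.parser_pymupdf()
parser = PyMuPDFParser(implementation="pymupdf")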
4 changes: 4 additions & 0 deletions projects/pgai/pyproject.toml
@@ -28,6 +28,10 @@ dependencies = [
"litellm>=1.58.2,<1.59.0",
"google-cloud-aiplatform>=1.78.0,<2.0",
"boto3>=1.35.0,<2.0",
"pymupdf>=1.25.2",
"filetype>=1.2.0",
"smart-open>=7.1.0",
"pymupdf4llm>=0.0.17",
]
classifiers = [
"License :: OSI Approved :: PostgreSQL License",
4 changes: 2 additions & 2 deletions projects/pgai/tests/vectorizer/cli/conftest.py
@@ -98,9 +98,9 @@ def configure_vectorizer(
 formatting => ai.{formatting},
 processing => ai.processing_default(batch_size => {batch_size},
 concurrency => {concurrency}
-{loader}
-{parser}
 )
+{loader}
+{parser}
 )
 """)  # type: ignore
 vectorizer_id: int = int(cur.fetchone()["create_vectorizer"])  # type: ignore
11 changes: 3 additions & 8 deletions projects/pgai/tests/vectorizer/cli/test_vectorizer_document.py
@@ -14,11 +14,7 @@

 s3_base = os.environ["S3_BASE_URL"]
 docs = [
-    "sample_pdf.pdf",
-    "sample_with_table.pdf",
-    "stop recommending clean code.pdf",
-    "stop_recommending_clean_code.pdf",
-    "stop_recommending_clean_code.txt",
+    "r1cover.pdf",
 ]


@@ -49,7 +45,7 @@ def configure_document_vectorizer(
     number_of_rows: int = 1,
     concurrency: int = 1,
     batch_size: int = 1,
-    chunking: str = "chunking_character_text_splitter()",
+    chunking: str = "chunking_character_text_splitter('url')",
     formatting: str = "formatting_python_template('$chunk')",
 ) -> int:
     """Creates and configures a vectorizer for testing"""
@@ -66,7 +62,7 @@
         chunking=chunking,
         formatting=formatting,
         embedding=embedding,
-        loader="ai.loader_s3(url_column => 'url')",  # this could also just be ai.file_loader() ?
+        loader="ai.loader_file_loader(url_column => 'url')",
         parser="ai.parser_pymupdf()",
     )

@@ -76,7 +72,6 @@ def test_simple_document_embedding(
 ):
     """Test that a document is successfully embedded"""
     connection = cli_db[1]
-    setup_documents_table(connection, 1)
     vectorizer_id = configure_document_vectorizer(cli_db[1])

     run_vectorizer_worker(cli_db_url, vectorizer_id)
56 changes: 56 additions & 0 deletions projects/pgai/uv.lock

Some generated files are not rendered by default.
