diff --git a/projects/extension/sql/idempotent/008-embedding.sql b/projects/extension/sql/idempotent/008-embedding.sql
index e03e21ed5..90703e2e6 100644
--- a/projects/extension/sql/idempotent/008-embedding.sql
+++ b/projects/extension/sql/idempotent/008-embedding.sql
@@ -6,8 +6,12 @@ create or replace function ai.embedding_openai
 , dimensions pg_catalog.int4
 , chat_user pg_catalog.text default null
 , api_key_name pg_catalog.text default 'OPENAI_API_KEY'
+, use_batch_api pg_catalog.bool default false
+, embedding_batch_schema pg_catalog.name default null
+, embedding_batch_table pg_catalog.name default null
+, embedding_batch_chunks_table pg_catalog.name default null
 ) returns pg_catalog.jsonb
 as $func$
     select json_object
     ( 'implementation': 'openai'
     , 'config_type': 'embedding'
@@ -15,6 +19,10 @@ as $func$
     , 'dimensions': dimensions
     , 'user': chat_user
     , 'api_key_name': api_key_name
+    , 'use_batch_api': use_batch_api
+    , 'embedding_batch_schema': embedding_batch_schema
+    , 'embedding_batch_table': embedding_batch_table
+    , 'embedding_batch_chunks_table': embedding_batch_chunks_table
     absent on null
     )
 $func$ language sql immutable security invoker
@@ -81,6 +89,9 @@ as $func$
 declare
     _config_type pg_catalog.text;
     _implementation pg_catalog.text;
+    _embedding_batch_schema pg_catalog.text;
+    _embedding_batch_table pg_catalog.text;
+    _embedding_batch_chunks_table pg_catalog.text;
 begin
     if pg_catalog.jsonb_typeof(config) operator(pg_catalog.!=) 'object' then
         raise exception 'embedding config is not a jsonb object';
@@ -93,6 +104,27 @@ begin
     _implementation = config operator(pg_catalog.->>) 'implementation';
     case _implementation
         when 'openai' then
+            -- the batch table names are optional here: ai.create_vectorizer fills in
+            -- per-vectorizer defaults. explicitly requested names must not collide
+            -- with existing objects
+            _embedding_batch_schema = coalesce(config operator(pg_catalog.->>) 'embedding_batch_schema', 'ai');
+            _embedding_batch_table = config operator(pg_catalog.->>) 'embedding_batch_table';
+            _embedding_batch_chunks_table = config operator(pg_catalog.->>) 'embedding_batch_chunks_table';
+
+            -- make sure embedding batch table name is available
+            if _embedding_batch_table is not null then
+                if pg_catalog.to_regclass(pg_catalog.format('%I.%I', _embedding_batch_schema, _embedding_batch_table)) is not null then
+                    raise exception 'an object named %.% already exists. specify an alternate embedding_batch_table explicitly', _embedding_batch_schema, _embedding_batch_table;
+                end if;
+            end if;
+
+            -- make sure embedding batch chunks table name is available
+            if _embedding_batch_chunks_table is not null then
+                if pg_catalog.to_regclass(pg_catalog.format('%I.%I', _embedding_batch_schema, _embedding_batch_chunks_table)) is not null then
+                    raise exception 'an object named %.% already exists. specify an alternate embedding_batch_chunks_table explicitly', _embedding_batch_schema, _embedding_batch_chunks_table;
+                end if;
+            end if;
+
             -- ok
         when 'ollama' then
             -- ok
diff --git a/projects/extension/sql/idempotent/013-vectorizer-api.sql b/projects/extension/sql/idempotent/013-vectorizer-api.sql
index 66fff6286..e81415abb 100644
--- a/projects/extension/sql/idempotent/013-vectorizer-api.sql
+++ b/projects/extension/sql/idempotent/013-vectorizer-api.sql
@@ -1,5 +1,3 @@
-
-
 -------------------------------------------------------------------------------
 -- execute_vectorizer
 create or replace function ai.execute_vectorizer(vectorizer_id pg_catalog.int4) returns void
@@ -44,6 +42,10 @@ declare
     _vectorizer_id pg_catalog.int4;
     _sql pg_catalog.text;
     _job_id pg_catalog.int8;
+    _implementation pg_catalog.text;
+    _embedding_batch_schema pg_catalog.name;
+    _embedding_batch_table pg_catalog.name;
+    _embedding_batch_chunks_table pg_catalog.name;
 begin
     -- make sure all the roles listed in grant_to exist
     if grant_to is not null then
@@ -225,6 +227,33 @@ begin
         scheduling = pg_catalog.jsonb_insert(scheduling, array['job_id'], pg_catalog.to_jsonb(_job_id));
     end if;
 
+    -- resolve the batch table names (defaulting to per-vectorizer names), create
+    -- the tables, and store the resolved names in the embedding config so that
+    -- the worker can find them
+    _implementation = embedding operator(pg_catalog.->>) 'implementation';
+    if _implementation operator(pg_catalog.=) 'openai' then
+        _embedding_batch_schema = coalesce(embedding operator(pg_catalog.->>) 'embedding_batch_schema', 'ai');
+        _embedding_batch_table = coalesce
+        ( embedding operator(pg_catalog.->>) 'embedding_batch_table'
+        , pg_catalog.concat('_vectorizer_embedding_batches_', _vectorizer_id)
+        );
+        _embedding_batch_chunks_table = coalesce
+        ( embedding operator(pg_catalog.->>) 'embedding_batch_chunks_table'
+        , pg_catalog.concat('_vectorizer_embedding_batch_chunks_', _vectorizer_id)
+        );
+        embedding = embedding operator(pg_catalog.||) pg_catalog.jsonb_build_object
+        ( 'embedding_batch_schema', _embedding_batch_schema
+        , 'embedding_batch_table', _embedding_batch_table
+        , 'embedding_batch_chunks_table', _embedding_batch_chunks_table
+        );
+        perform ai._vectorizer_create_embedding_batches_table
+        ( _embedding_batch_schema
+        , _embedding_batch_table
+        , _embedding_batch_chunks_table
+        , grant_to
+        );
+    end if;
+
     insert into ai.vectorizer
     ( id
     , source_schema
diff --git a/projects/extension/sql/idempotent/016-openai-batch-api.sql b/projects/extension/sql/idempotent/016-openai-batch-api.sql
new file mode 100644
index 000000000..863863bb4
--- /dev/null
+++ b/projects/extension/sql/idempotent/016-openai-batch-api.sql
@@ -0,0 +1,99 @@
+-------------------------------------------------------------------------------
+-- _vectorizer_create_embedding_batches_table
+create or replace function ai._vectorizer_create_embedding_batches_table
+( embedding_batch_schema pg_catalog.name
+, embedding_batch_table pg_catalog.name
+, embedding_batch_chunks_table pg_catalog.name
+, grant_to pg_catalog.name[]
+) returns void as
+$func$
+declare
+    _sql pg_catalog.text;
+begin
+    -- create the batches table
+    select pg_catalog.format
+    ( $sql$create table %I.%I(
+    external_batch_id VARCHAR(255) PRIMARY KEY,
+    input_file_id VARCHAR(255) NOT NULL,
+    output_file_id VARCHAR(255),
+    status VARCHAR(255) NOT NULL,
+    errors JSONB,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    expires_at TIMESTAMPTZ,
+    completed_at TIMESTAMPTZ,
+    failed_at TIMESTAMPTZ,
+    next_attempt_after TIMESTAMPTZ,
+    total_attempts BIGINT NOT NULL DEFAULT 0
+)$sql$
+    , embedding_batch_schema
+    , embedding_batch_table
+    ) into strict _sql
+    ;
+    execute _sql;
+
+    -- create the index used to look up batches by status
+    select pg_catalog.format
+    ( $sql$create index on %I.%I (status)$sql$
+    , embedding_batch_schema, embedding_batch_table
+    ) into strict _sql
+    ;
+    execute _sql;
+
+    -- create the batch chunks table; the same chunk id may appear in several batches
+    select pg_catalog.format
+    ( $sql$create table %I.%I(
+    id TEXT NOT NULL,
+    embedding_batch_id VARCHAR(255) NOT NULL REFERENCES %I.%I (external_batch_id) ON DELETE CASCADE,
+    chunk TEXT,
+    PRIMARY KEY (embedding_batch_id, id)
+)$sql$
+    , embedding_batch_schema
+    , embedding_batch_chunks_table
+    , embedding_batch_schema
+    , embedding_batch_table
+    ) into strict _sql
+    ;
+    execute _sql;
+
+    if grant_to is not null then
+        -- grant usage on the batch schema to grant_to roles
+        select pg_catalog.format
+        ( $sql$grant usage on schema %I to %s$sql$
+        , embedding_batch_schema
+        , (
+            select pg_catalog.string_agg(pg_catalog.quote_ident(x), ', ')
+            from pg_catalog.unnest(grant_to) x
+          )
+        ) into strict _sql;
+        execute _sql;
+
+        -- grant select, insert, update, delete on batches table to grant_to roles
+        select pg_catalog.format
+        ( $sql$grant select, insert, update, delete on %I.%I to %s$sql$
+        , embedding_batch_schema
+        , embedding_batch_table
+        , (
+            select pg_catalog.string_agg(pg_catalog.quote_ident(x), ', ')
+            from pg_catalog.unnest(grant_to) x
+          )
+        ) into strict _sql;
+        execute _sql;
+
+        -- grant select, insert, update, delete on batch chunks table to grant_to roles
+        select pg_catalog.format
+        ( $sql$grant select, insert, update, delete on %I.%I to %s$sql$
+        , embedding_batch_schema
+        , embedding_batch_chunks_table
+        , (
+            select pg_catalog.string_agg(pg_catalog.quote_ident(x), ', ')
+            from pg_catalog.unnest(grant_to) x
+          )
+        ) into strict _sql;
+        execute _sql;
+    end if;
+end;
+$func$
+language plpgsql volatile security invoker
+set search_path to pg_catalog, pg_temp
+;
+
diff --git a/projects/pgai/pgai/vectorizer/embedders/openai.py b/projects/pgai/pgai/vectorizer/embedders/openai.py
index 7c4e9aa5c..cb5958fe4 100644
--- a/projects/pgai/pgai/vectorizer/embedders/openai.py
+++ b/projects/pgai/pgai/vectorizer/embedders/openai.py
@@ -1,4 +1,6 @@
+import json
 import re
 from collections.abc import Sequence
+from datetime import datetime, timezone
 from functools import cached_property
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal
@@ -8,6 +10,7 @@
 from openai import resources
+from psycopg import AsyncConnection
 from pydantic import BaseModel
 from typing_extensions import override
 
 from ..embeddings import (
     ApiKeyMixin,
@@ -21,6 +24,11 @@
     Usage,
     logger,
 )
+
+if TYPE_CHECKING:
+    # annotations only: ..vectorizer imports this module while it is loading,
+    # so importing it here at runtime would be circular
+    from ..vectorizer import AsyncBatch
 
 TOKEN_CONTEXT_LENGTH_ERROR = "chunk exceeds model context length"
 
@@ -39,12 +47,22 @@ class OpenAI(ApiKeyMixin, BaseModel, Embedder):
         model (str): The name of the OpenAI model used for embeddings.
         dimensions (int | None): Optional dimensions for the embeddings.
         user (str | None): Optional user identifier for OpenAI API usage.
+        use_batch_api (bool): Whether to use the OpenAI Batch API.
+        embedding_batch_schema (str | None): The schema of the batch tables.
+        embedding_batch_table (str | None): The table where the embedding
+            batches are stored.
+        embedding_batch_chunks_table (str | None): The table where the chunks
+            of each embedding batch are stored.
     """
 
     implementation: Literal["openai"]
     model: str
     dimensions: int | None = None
     user: str | None = None
+    use_batch_api: bool = False
+    embedding_batch_schema: str | None = None
+    embedding_batch_table: str | None = None
+    embedding_batch_chunks_table: str | None = None
 
     @cached_property
     def _openai_dimensions(self) -> int | openai.NotGiven:
@@ -58,9 +76,13 @@ def _openai_dimensions(self) -> int | openai.NotGiven:
     def _openai_user(self) -> str | openai.NotGiven:
         return self.user if self.user is not None else openai.NOT_GIVEN
 
+    @cached_property
+    def _client(self) -> openai.AsyncOpenAI:
+        return openai.AsyncOpenAI(api_key=self._api_key, max_retries=3)
+
     @cached_property
     def _embedder(self) -> resources.AsyncEmbeddings:
-        return openai.AsyncOpenAI(api_key=self._api_key, max_retries=3).embeddings
+        return self._client.embeddings
 
     @override
     def _max_chunks_per_batch(self) -> int:
@@ -129,6 +151,72 @@ async def embed(
             model_token_length, encoded_documents
         )
 
+    async def create_and_submit_embedding_batch(
+        self,
+        documents: list[dict[str, Any]],
+    ) -> "AsyncBatch":
+        """
+        Submits the given chunks to OpenAI's Batch API for embedding, see
+        https://platform.openai.com/docs/guides/batch
+
+        Args:
+            documents (list[dict[str, Any]]): One dict per chunk with the keys
+                "unique_full_chunk_id" (sent as the request's custom_id) and
+                "chunk" (the text to embed).
+
+        Returns:
+            AsyncBatch: The submitted batch as reported by OpenAI.
+        """
+        # deferred import: ..vectorizer imports this module while it is loading
+        from ..vectorizer import AsyncBatch
+
+        # forward the configured optional parameters, otherwise the vectors
+        # would not have the configured dimensions
+        optional_params = {
+            "dimensions": self._openai_dimensions,
+            "user": self._openai_user,
+        }
+        body_options = {
+            name: value
+            for name, value in optional_params.items()
+            if not isinstance(value, openai.NotGiven)
+        }
+
+        requests = [
+            json.dumps(
+                {
+                    "custom_id": document["unique_full_chunk_id"],
+                    "method": "POST",
+                    "url": "/v1/embeddings",
+                    "body": {
+                        "model": self.model,
+                        "input": document["chunk"],
+                        **body_options,
+                    },
+                }
+            )
+            for document in documents
+        ]
+        payload = ("\n".join(requests) + "\n").encode("utf-8")
+
+        # upload from memory so that no temporary file is left behind on disk
+        batch_input_file = await self._client.files.create(
+            file=("embedding_batch.jsonl", payload),
+            purpose="batch",
+        )
+        openai_batch = await self._client.batches.create(
+            input_file_id=batch_input_file.id,
+            endpoint="/v1/embeddings",
+            completion_window="24h",
+        )
+
+        return AsyncBatch(
+            external_batch_id=openai_batch.id,
+            input_file_id=openai_batch.input_file_id,
+            status=openai_batch.status,
+            expires_at=_from_unix_timestamp(openai_batch.expires_at),
+        )
+
     async def _filter_by_length_and_embed(
         self, model_token_length: int, encoded_documents: list[list[int]]
     ) -> Sequence[EmbeddingVector | ChunkEmbeddingError]:
@@ -200,3 +288,95 @@ async def _encode(self, documents: list[str]) -> list[list[int]]:
     @cached_property
     def _encoder(self) -> tiktoken.Encoding:
         return tiktoken.encoding_for_model(self.model)
+
+    def is_api_async(self) -> bool:
+        """Returns True when embeddings are created through the Batch API."""
+        return self.use_batch_api
+
+    async def fetch_async_embedding_status(
+        self, batch: "AsyncBatch"
+    ) -> "AsyncBatch":
+        """
+        Refreshes the given batch with the state OpenAI currently reports.
+
+        Args:
+            batch (AsyncBatch): The batch as stored in the batches table.
+
+        Returns:
+            AsyncBatch: The same object with updated status fields.
+        """
+        openai_batch = await self._client.batches.retrieve(batch.external_batch_id)
+
+        batch.status = openai_batch.status
+        batch.output_file_id = openai_batch.output_file_id
+        batch.completed_at = _from_unix_timestamp(openai_batch.completed_at)
+        batch.failed_at = _from_unix_timestamp(openai_batch.failed_at)
+        batch.errors = (
+            openai_batch.errors.model_dump()
+            if openai_batch.errors is not None
+            else None
+        )
+
+        return batch
+
+    async def process_async_embedding(
+        self,
+        conn: AsyncConnection,
+        batch: "AsyncBatch",
+    ) -> dict[str, list[float]]:
+        """
+        Downloads the results of a completed OpenAI batch.
+
+        Persisting the embeddings is left to the caller, which owns the
+        database queries.
+
+        Args:
+            conn (AsyncConnection): The database connection (not needed by
+                this embedder).
+            batch (AsyncBatch): The batch as stored in the batches table.
+
+        Returns:
+            dict[str, list[float]]: The embedding of every successfully
+                embedded chunk, keyed by the chunk's custom_id.
+        """
+        openai_batch = await self._client.batches.retrieve(batch.external_batch_id)
+        if openai_batch.output_file_id is None:
+            return {}
+        batch_file = await self._client.files.content(openai_batch.output_file_id)
+
+        embeddings: dict[str, list[float]] = {}
+        for line in batch_file.text.split("\n"):
+            if not line.strip():
+                continue
+            result = json.loads(line)
+            response = result.get("response")
+            if "custom_id" not in result or not response:
+                continue
+            if response.get("status_code") != 200:
+                # non-fatal: the other chunks of the batch are still returned
+                logger.warning(
+                    "embedding request in batch failed",
+                    custom_id=result["custom_id"],
+                    error=result.get("error"),
+                )
+                continue
+            embeddings[result["custom_id"]] = response["body"]["data"][0]["embedding"]
+
+        return embeddings
+
+    async def finalize_async_embedding(
+        self,
+        batch: "AsyncBatch",
+    ) -> None:
+        """Deletes the batch's input and output files from OpenAI."""
+        openai_batch = await self._client.batches.retrieve(batch.external_batch_id)
+        await self._client.files.delete(openai_batch.input_file_id)
+        if openai_batch.output_file_id is not None:
+            await self._client.files.delete(openai_batch.output_file_id)
+
+
+def _from_unix_timestamp(timestamp: int | None) -> datetime | None:
+    """Converts a unix timestamp in seconds into an aware UTC datetime."""
+    if timestamp is None:
+        return None
+    return datetime.fromtimestamp(timestamp, timezone.utc)
diff --git a/projects/pgai/pgai/vectorizer/embeddings.py b/projects/pgai/pgai/vectorizer/embeddings.py
index a40f91f02..96d722d3e 100644
--- a/projects/pgai/pgai/vectorizer/embeddings.py
+++ b/projects/pgai/pgai/vectorizer/embeddings.py
@@ -3,11 +3,16 @@
 from abc import ABC, abstractmethod
 from collections.abc import Awaitable, Callable, Sequence
 from dataclasses import dataclass
-from typing import Generic, TypeAlias, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar
 
 import structlog
 from ddtrace import tracer
+from psycopg import AsyncConnection
+
+if TYPE_CHECKING:
+    # annotations only: a runtime import of .vectorizer would be circular
+    from .vectorizer import AsyncBatch
 
 logger = structlog.get_logger()
 
 
@@ -164,6 +169,68 @@ async def setup(self) -> None:  # noqa: B027 empty on purpose
         Setup the embedder
         """
 
+    def is_api_async(self) -> bool:
+        """
+        Whether this embedder creates embeddings asynchronously through a
+        batch API. Embedders that do must override the batch methods below.
+        """
+        return False
+
+    async def create_and_submit_embedding_batch(
+        self,
+        documents: list[dict[str, Any]],
+    ) -> "AsyncBatch":
+        """
+        Receives the chunks to embed and submits them as one batch to the
+        external service.
+
+        Args:
+            documents (list[dict[str, Any]]): One dict per chunk with the keys
+                "unique_full_chunk_id" and "chunk".
+
+        Returns:
+            AsyncBatch: The batch as created by the external service.
+        """
+        raise NotImplementedError
+
+    async def fetch_async_embedding_status(
+        self, batch: "AsyncBatch"
+    ) -> "AsyncBatch":
+        """
+        Receives a row from the embedding batches table and checks whether
+        the external service has finished processing it.
+
+        If it is ready, the status of the returned batch is "completed".
+        """
+        raise NotImplementedError
+
+    async def process_async_embedding(
+        self,
+        conn: AsyncConnection,
+        batch: "AsyncBatch",
+    ) -> dict[str, list[float]]:
+        """
+        Loads the embeddings created for a completed batch.
+
+        Args:
+            conn (AsyncConnection): The database connection.
+            batch (AsyncBatch): The batch as retrieved from the database.
+
+        Returns:
+            dict[str, list[float]]: The embeddings keyed by the
+                "unique_full_chunk_id" the chunk was submitted with.
+        """
+        raise NotImplementedError
+
+    async def finalize_async_embedding(  # noqa: B027 empty on purpose
+        self,
+        batch: "AsyncBatch",
+    ) -> None:
+        """
+        Called after a batch was stored; allows cleaning up any files kept
+        by the external service.
+        """
+
 
 class ApiKeyMixin:
     """
diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py
index d57c249f5..1fc4e5c2d 100644
--- a/projects/pgai/pgai/vectorizer/vectorizer.py
+++ b/projects/pgai/pgai/vectorizer/vectorizer.py
@@ -3,6 +3,7 @@
 import threading
 import time
 from collections.abc import Callable
+from datetime import datetime, timedelta, timezone
 from functools import cached_property
 from itertools import repeat
 from typing import Any, TypeAlias
@@ -83,6 +84,45 @@ class Config:
     formatting: PythonTemplate | ChunkValue = Field(..., discriminator="implementation")
 
 
+# how long a batch is leased to one worker before it may be polled again
+ASYNC_BATCH_POLL_INTERVAL = timedelta(minutes=5)
+
+
+@dataclass
+class AsyncBatch:
+    """
+    Represents a record in the embedding batches table.
+
+    Attributes:
+        external_batch_id (str): Primary key; the ID assigned by the external
+            service.
+        input_file_id (str): The ID of the uploaded input file.
+        status (str): The current status of the batch.
+        output_file_id (str | None): The ID of the result file, once available.
+        errors (dict | None): Error details, stored as JSONB.
+        created_at (datetime): When the record was created (defaults to now).
+        expires_at (datetime | None): When the batch expires.
+        completed_at (datetime | None): When the batch processing completed.
+        failed_at (datetime | None): When the batch processing failed.
+        next_attempt_after (datetime | None): When the batch may be polled next.
+        total_attempts (int): How many times the batch has been polled.
+    """
+
+    external_batch_id: str
+    input_file_id: str
+    status: str
+    output_file_id: str | None = None
+    errors: dict[str, Any] | None = None
+    created_at: datetime = Field(
+        default_factory=lambda: datetime.now(timezone.utc)
+    )
+    expires_at: datetime | None = None
+    completed_at: datetime | None = None
+    failed_at: datetime | None = None
+    next_attempt_after: datetime | None = None
+    total_attempts: int = 0
+
+
 @dataclass
 class Vectorizer:
     """
@@ -319,6 +359,120 @@ def insert_errors_query(self) -> sql.Composed:
             self.errors_table_ident,
         )
 
+    @cached_property
+    def embedding_batch_table_ident(self) -> sql.Identifier:
+        """Schema-qualified identifier of the embedding batches table."""
+        embedding = self.vectorizer.config.embedding
+        return sql.Identifier(
+            embedding.embedding_batch_schema, embedding.embedding_batch_table
+        )
+
+    @cached_property
+    def embedding_batch_chunks_table_ident(self) -> sql.Identifier:
+        """Schema-qualified identifier of the embedding batch chunks table."""
+        embedding = self.vectorizer.config.embedding
+        return sql.Identifier(
+            embedding.embedding_batch_schema,
+            embedding.embedding_batch_chunks_table,
+        )
+
+    @cached_property
+    def fetch_batches_to_process_query(self) -> sql.Composed:
+        """
+        Leases the batch that is due for polling: increments its attempt
+        counter, sets next_attempt_after to the given timestamp so that other
+        workers skip it in the meantime, and returns the whole row.
+        """
+        return sql.SQL(
+            """
+            WITH locked_rows AS (
+                SELECT external_batch_id
+                FROM {batch_table}
+                WHERE next_attempt_after IS NULL OR next_attempt_after < NOW()
+                ORDER BY created_at DESC
+                LIMIT 1
+                FOR UPDATE SKIP LOCKED
+            )
+            UPDATE
+                {batch_table} batches
+            SET
+                total_attempts = batches.total_attempts + 1,
+                next_attempt_after = %s
+            FROM
+                locked_rows l
+            WHERE
+                l.external_batch_id = batches.external_batch_id
+            RETURNING
+                batches.external_batch_id,
+                batches.input_file_id,
+                batches.output_file_id,
+                batches.status,
+                batches.errors,
+                batches.created_at,
+                batches.expires_at,
+                batches.completed_at,
+                batches.failed_at,
+                batches.next_attempt_after,
+                batches.total_attempts
+            """
+        ).format(batch_table=self.embedding_batch_table_ident)
+
+    @cached_property
+    def update_batch_embedding_query(self) -> sql.Composed:
+        return sql.SQL(
+            """
+            UPDATE {} SET
+                status = %s,
+                output_file_id = %s,
+                completed_at = %s,
+                failed_at = %s,
+                errors = %s
+            WHERE external_batch_id = %s
+            """
+        ).format(self.embedding_batch_table_ident)
+
+    @cached_property
+    def delete_batch_embedding_from_queue_query(self) -> sql.Composed:
+        return sql.SQL(
+            "DELETE FROM {} WHERE external_batch_id = %s"
+        ).format(self.embedding_batch_table_ident)
+
+    @cached_property
+    def update_batch_embedding_status_query(self) -> sql.Composed:
+        return sql.SQL(
+            "UPDATE {} SET status = %s WHERE external_batch_id = %s"
+        ).format(self.embedding_batch_table_ident)
+
+    @cached_property
+    def fetch_chunks_for_batch_id_query(self) -> sql.Composed:
+        return sql.SQL(
+            "SELECT id, chunk FROM {} WHERE embedding_batch_id = %s"
+        ).format(self.embedding_batch_chunks_table_ident)
+
+    @cached_property
+    def insert_batch_embedding_query(self) -> sql.Composed:
+        return sql.SQL(
+            """
+            INSERT INTO {} (
+                external_batch_id,
+                input_file_id,
+                output_file_id,
+                status,
+                errors,
+                expires_at
+            ) VALUES (%s, %s, %s, %s, %s, %s)
+            """
+        ).format(self.embedding_batch_table_ident)
+
+    @cached_property
+    def insert_batch_embedding_chunks_query(self) -> sql.Composed:
+        return sql.SQL(
+            """
+            INSERT INTO {} (id, embedding_batch_id, chunk)
+            VALUES (%s, %s, %s)
+            """
+        ).format(self.embedding_batch_chunks_table_ident)
+
     def _pks_placeholders_tuples(self, items_count: int) -> sql.Composed:
         """Generates a comma separated list of tuples with placeholders for the
         primary key fields of the source table.
@@ -465,6 +619,10 @@ async def run(self) -> int:
             await register_vector_async(conn)
             await self.vectorizer.config.embedding.setup()
+            if self.vectorizer.config.embedding.is_api_async():
+                # batch APIs are polled once per run instead of in a loop
+                return await self._process_async_embeddings(conn)
+
             while True:
                 if not self._continue_processing(loops, res):
                     return res
                 items_processed = await self._do_batch(conn)
@@ -473,6 +631,216 @@ async def run(self) -> int:
                 res += items_processed
                 loops += 1
 
+    async def _process_async_embeddings(self, conn: AsyncConnection) -> int:
+        """
+        Runs one round of batch processing: stores the results of finished
+        batches, then submits the queued items as a new batch.
+
+        Returns:
+            int: The number of embeddings stored plus the number of items
+                submitted.
+        """
+        stored = await self.check_and_store_async_batches(conn)
+        submitted = await self.create_async_batches(conn)
+        return stored + submitted
+
+    async def check_and_store_async_batches(self, conn: AsyncConnection) -> int:
+        """
+        Checks if batches submitted with create_async_batches completed and
+        stores their embeddings when they are.
+
+        This function is only called when is_api_async returns true.
+
+        Args:
+            conn (AsyncConnection): The asynchronous database connection.
+
+        Returns:
+            int: The number of embeddings stored.
+        """
+        # local import; psycopg is already a dependency of this module
+        from psycopg.rows import dict_row
+
+        embedder = self.vectorizer.config.embedding
+        stored = 0
+
+        # lease the due batch in its own transaction; next_attempt_after keeps
+        # other workers away from it while it is being processed
+        lease_until = datetime.now(timezone.utc) + ASYNC_BATCH_POLL_INTERVAL
+        async with conn.transaction(), conn.cursor(row_factory=dict_row) as cursor:
+            await cursor.execute(
+                self.queries.fetch_batches_to_process_query, (lease_until,)
+            )
+            batch_rows = await cursor.fetchall()
+
+        for batch_row in batch_rows:
+            batch = AsyncBatch(**batch_row)
+
+            # a batch that is already "processed" was stored by an earlier run
+            # that failed during cleanup; it only needs to be finalized
+            if batch.status != "processed":
+                batch = await embedder.fetch_async_embedding_status(batch)
+                async with conn.transaction():
+                    await conn.execute(
+                        self.queries.update_batch_embedding_query,
+                        (
+                            batch.status,
+                            batch.output_file_id,
+                            batch.completed_at,
+                            batch.failed_at,
+                            Jsonb(batch.errors) if batch.errors is not None else None,
+                            batch.external_batch_id,
+                        ),
+                    )
+
+            # NOTE(review): batches that end up failed/expired/cancelled are
+            # only recorded; their items are not re-queued yet
+
+            # the batch has been processed successfully by the external api,
+            # so the results can be collected and stored in the database
+            if batch.status == "completed":
+                async with conn.transaction():
+                    stored += await self._store_batch_embeddings(conn, batch)
+                    batch.status = "processed"
+                    await conn.execute(
+                        self.queries.update_batch_embedding_status_query,
+                        (batch.status, batch.external_batch_id),
+                    )
+
+            if batch.status == "processed":
+                await embedder.finalize_async_embedding(batch)
+                async with conn.transaction():
+                    await conn.execute(
+                        self.queries.delete_batch_embedding_from_queue_query,
+                        (batch.external_batch_id,),
+                    )
+
+        return stored
+
+    async def _store_batch_embeddings(
+        self, conn: AsyncConnection, batch: AsyncBatch
+    ) -> int:
+        """
+        Writes the embeddings of a completed batch to the database, replacing
+        any existing embeddings of the affected source rows.
+
+        Returns:
+            int: The number of embeddings written.
+        """
+        from psycopg.rows import dict_row
+
+        embeddings = await self.vectorizer.config.embedding.process_async_embedding(
+            conn, batch
+        )
+        async with conn.cursor(row_factory=dict_row) as cursor:
+            await cursor.execute(
+                self.queries.fetch_chunks_for_batch_id_query,
+                (batch.external_batch_id,),
+            )
+            chunks = {row["id"]: row["chunk"] for row in await cursor.fetchall()}
+
+        items: list[SourceRow] = []
+        records: list[EmbeddingRecord] = []
+        for custom_id, embedding in embeddings.items():
+            if custom_id not in chunks:
+                continue
+            # custom_id is "pk names:::pk values:::chunk seq", see
+            # _generate_embedding_batch
+            # NOTE(review): the pk values come back as strings here; confirm
+            # that _delete_embeddings/_copy_embeddings accept them for
+            # non-text primary keys
+            pk_names, pk_values, chunk_seq = custom_id.split(":::")
+            pk = pk_values.split(",")
+            items.append(dict(zip(pk_names.split(","), pk, strict=True)))
+            records.append(
+                pk + [int(chunk_seq), chunks[custom_id], np.array(embedding)]
+            )
+
+        if not records:
+            return 0
+        await self._delete_embeddings(conn, items)
+        await self._copy_embeddings(conn, records)
+        return len(records)
+
+    async def create_async_batches(self, conn: AsyncConnection) -> int:
+        """
+        Submits chunks for async embedding processing.
+        This allows processing very large amounts of data faster than with
+        the embeddings api, because batch apis usually have vastly higher
+        rate limits.
+
+        This function is only called when is_api_async returns true.
+
+        Args:
+            conn (AsyncConnection): The asynchronous database connection.
+
+        Returns:
+            int: The number of source rows pulled from the queue.
+        """
+        try:
+            # one transaction, so that (presumably) the queue changes made by
+            # _fetch_work are rolled back if submitting or recording fails
+            async with conn.transaction():
+                items = await self._fetch_work(conn)
+
+                await logger.adebug(
+                    f"Items pulled from queue for batch embedding: {len(items)}"
+                )
+
+                # Filter out items that were deleted from the source table.
+                # We use the first primary key column, since they can only
+                # be null if the LEFT JOIN didn't find a match.
+                items = [
+                    i
+                    for i in items
+                    if i[self.vectorizer.source_pk[0].attname] is not None
+                ]
+
+                if len(items) == 0:
+                    return 0
+
+                created_batch, documents = await self._generate_embedding_batch(
+                    items
+                )
+
+                await conn.execute(
+                    self.queries.insert_batch_embedding_query,
+                    (
+                        created_batch.external_batch_id,
+                        created_batch.input_file_id,
+                        created_batch.output_file_id,
+                        created_batch.status,
+                        Jsonb(created_batch.errors)
+                        if created_batch.errors is not None
+                        else None,
+                        created_batch.expires_at,
+                    ),
+                )
+                async with conn.cursor() as cursor:
+                    await cursor.executemany(
+                        self.queries.insert_batch_embedding_chunks_query,
+                        [
+                            (
+                                doc["unique_full_chunk_id"],
+                                created_batch.external_batch_id,
+                                doc["chunk"],
+                            )
+                            for doc in documents
+                        ],
+                    )
+
+                return len(items)
+        except Exception as e:
+            async with conn.transaction():
+                await self._insert_vectorizer_error(
+                    conn,
+                    (
+                        self.vectorizer.id,
+                        VECTORIZER_FAILED,
+                        Jsonb({"error_reason": str(e)}),
+                    ),
+                )
+            raise
+
     @tracer.wrap()
     async def _do_batch(self, conn: AsyncConnection) -> int:
         """
@@ -732,6 +1100,44 @@ async def _generate_embeddings(
                 records.append(record + [np.array(embedding)])
         return records, errors
 
+    async def _generate_embedding_batch(
+        self, items: list[SourceRow]
+    ) -> tuple[AsyncBatch, list[dict[str, Any]]]:
+        """
+        Chunks and formats the given items and submits all chunks as one
+        batch to the embedding provider.
+
+        Returns:
+            tuple[AsyncBatch, list[dict[str, Any]]]: The submitted batch and
+                the chunk documents it contains.
+
+        Raises:
+            EmbeddingProviderError: If submitting the batch fails.
+        """
+        pk_names = ",".join(self.queries.pk_attnames)
+        documents: list[dict[str, Any]] = []
+        for item in items:
+            pk_values = ",".join(map(str, self._get_item_pk_values(item)))
+            chunks = self.vectorizer.config.chunking.into_chunks(item)
+            for chunk_id, chunk in enumerate(chunks):
+                documents.append(
+                    {
+                        "unique_full_chunk_id": ":::".join(
+                            (pk_names, pk_values, str(chunk_id))
+                        ),
+                        "chunk": self.vectorizer.config.formatting.format(
+                            chunk, item
+                        ),
+                    }
+                )
+
+        try:
+            embedder = self.vectorizer.config.embedding
+            batch = await embedder.create_and_submit_embedding_batch(documents)
+        except Exception as e:
+            raise EmbeddingProviderError() from e
+        return batch, documents
+
     def _vectorizer_error_record(
         self, record: EmbeddingRecord, chunk_error: ChunkEmbeddingError
     ) -> VectorizerErrorRecord: