
fix: rework retries logic so it's loader specific
both in the extension and in the vectorizer.
adolsalamanca committed Mar 5, 2025
1 parent a029829 commit c5ec403
Showing 16 changed files with 125 additions and 82 deletions.
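Taken together, these changes move the retry limit out of the worker's command line and into the loading configuration, so each loader carries its own limit. A minimal SQL sketch of the reworked API; the expected JSON values are taken from the tests further down:

-- default retry limit
select ai.loading_column('content');
-- {"implementation": "column", "config_type": "loading", "column_name": "content", "retries": 6}

-- per-loader override, positional or named
select ai.loading_uri('s3_uri', 3);
select ai.loading_uri('s3_uri', retries => 3);
-- {"implementation": "uri", "config_type": "loading", "column_name": "s3_uri", "retries": 3}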
4 changes: 2 additions & 2 deletions projects/extension/sql/idempotent/012-vectorizer-int.sql
@@ -403,8 +403,8 @@ begin
create table %I.%I
( %s
, queued_at pg_catalog.timestamptz not null default now()
, retries pg_catalog.int4 not null default 0
, retry_after pg_catalog.timestamptz
, loading_retries pg_catalog.int4 not null default 0
, loading_retry_after pg_catalog.timestamptz
)
$sql$
, queue_schema, queue_table
12 changes: 10 additions & 2 deletions projects/extension/sql/idempotent/013-loading.sql
@@ -1,13 +1,15 @@
-------------------------------------------------------------------------------
-- loading_column
create or replace function ai.loading_column
( column_name pg_catalog.name)
( column_name pg_catalog.name
, retries pg_catalog.int4 default 6)
returns pg_catalog.jsonb
as $func$
select json_object
( 'implementation': 'column'
, 'config_type': 'loading'
, 'column_name': column_name
, 'retries': retries
)
$func$ language sql immutable security invoker
set search_path to pg_catalog, pg_temp
@@ -16,13 +18,15 @@ set search_path to pg_catalog, pg_temp
-------------------------------------------------------------------------------
-- loading_uri
create or replace function ai.loading_uri
( column_name pg_catalog.name)
( column_name pg_catalog.name
, retries pg_catalog.int4 default 6)
returns pg_catalog.jsonb
as $func$
select json_object
( 'implementation': 'uri'
, 'config_type': 'loading'
, 'column_name': column_name
, 'retries': retries
)
$func$ language sql immutable security invoker
set search_path to pg_catalog, pg_temp
@@ -61,6 +65,10 @@ end if;
if _column_name is null then
raise exception 'invalid loading config, missing column_name';
end if;

if (config operator(pg_catalog.->>) 'retries')::int < 0 then
raise exception 'invalid loading config, retries must be a non-negative integer';
end if;

select y.typname into _column_type
from pg_catalog.pg_class k
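With the new guard in _validate_loading, a negative retry count is rejected as soon as the config is validated. A sketch mirroring the test added below:

-- raises: invalid loading config, retries must be a non-negative integer
select ai._validate_loading(ai.loading_column('body', -1), 'public', 'thing');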
@@ -21,7 +21,9 @@ BEGIN
'loading', json_object(
'implementation': 'column',
'config_type': 'loading',
'column_name': _chunk_column),
'column_name': _chunk_column,
'retries': 6
),
'parsing', json_object(
'implementation': 'auto',
'config_type': 'parsing'
@@ -11,8 +11,8 @@ begin

select pg_catalog.format
( $sql$alter table %I.%I
add column if not exists retries pg_catalog.int4 not null default 0
, add column if not exists retry_after pg_catalog.timestamptz default null$sql$
add column if not exists loading_retries pg_catalog.int4 not null default 0
, add column if not exists loading_retry_after pg_catalog.timestamptz default null$sql$
, _rec.queue_schema
, _rec.queue_table
) into strict _sql;
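For a queue table such as ai._vectorizer_q_1 (the name here is illustrative), the statement this migration builds comes out roughly as:

alter table ai._vectorizer_q_1
    add column if not exists loading_retries pg_catalog.int4 not null default 0
  , add column if not exists loading_retry_after pg_catalog.timestamptz default null;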
4 changes: 2 additions & 2 deletions projects/extension/tests/contents/output16.expected
@@ -42,8 +42,8 @@ CREATE EXTENSION
function ai.litellm_embed(text,text[],text,text,jsonb,boolean)
function ai.load_dataset_multi_txn(text,text,text,name,name,text,jsonb,integer,integer,integer,jsonb)
function ai.load_dataset(text,text,text,name,name,text,jsonb,integer,integer,jsonb)
function ai.loading_column(name)
function ai.loading_uri(name)
function ai.loading_column(name,integer)
function ai.loading_uri(name,integer)
function ai.ollama_chat_complete(text,jsonb,text,text,jsonb,jsonb,jsonb,boolean)
function ai.ollama_embed(text,text,text,text,jsonb,boolean)
function ai.ollama_generate(text,text,text,bytea[],text,jsonb,text,text,integer[],boolean)
4 changes: 2 additions & 2 deletions projects/extension/tests/contents/output17.expected
@@ -42,8 +42,8 @@ CREATE EXTENSION
function ai.litellm_embed(text,text[],text,text,jsonb,boolean)
function ai.load_dataset_multi_txn(text,text,text,name,name,text,jsonb,integer,integer,integer,jsonb)
function ai.load_dataset(text,text,text,name,name,text,jsonb,integer,integer,jsonb)
function ai.loading_column(name)
function ai.loading_uri(name)
function ai.loading_column(name,integer)
function ai.loading_uri(name,integer)
function ai.ollama_chat_complete(text,jsonb,text,text,jsonb,jsonb,jsonb,boolean)
function ai.ollama_embed(text,text,text,text,jsonb,boolean)
function ai.ollama_generate(text,text,text,bytea[],text,jsonb,text,text,integer[],boolean)
16 changes: 8 additions & 8 deletions projects/extension/tests/privileges/function.expected
@@ -264,14 +264,14 @@
p | bob | execute | no | ai | load_dataset_multi_txn(IN name text, IN config_name text, IN split text, IN schema_name name, IN table_name name, IN if_table_exists text, IN field_types jsonb, IN batch_size integer, IN max_batches integer, IN commit_every_n_batches integer, IN kwargs jsonb)
p | fred | execute | no | ai | load_dataset_multi_txn(IN name text, IN config_name text, IN split text, IN schema_name name, IN table_name name, IN if_table_exists text, IN field_types jsonb, IN batch_size integer, IN max_batches integer, IN commit_every_n_batches integer, IN kwargs jsonb)
p | jill | execute | YES | ai | load_dataset_multi_txn(IN name text, IN config_name text, IN split text, IN schema_name name, IN table_name name, IN if_table_exists text, IN field_types jsonb, IN batch_size integer, IN max_batches integer, IN commit_every_n_batches integer, IN kwargs jsonb)
f | alice | execute | YES | ai | loading_column(column_name name)
f | bob | execute | no | ai | loading_column(column_name name)
f | fred | execute | no | ai | loading_column(column_name name)
f | jill | execute | YES | ai | loading_column(column_name name)
f | alice | execute | YES | ai | loading_uri(column_name name)
f | bob | execute | no | ai | loading_uri(column_name name)
f | fred | execute | no | ai | loading_uri(column_name name)
f | jill | execute | YES | ai | loading_uri(column_name name)
f | alice | execute | YES | ai | loading_column(column_name name, retries integer)
f | bob | execute | no | ai | loading_column(column_name name, retries integer)
f | fred | execute | no | ai | loading_column(column_name name, retries integer)
f | jill | execute | YES | ai | loading_column(column_name name, retries integer)
f | alice | execute | YES | ai | loading_uri(column_name name, retries integer)
f | bob | execute | no | ai | loading_uri(column_name name, retries integer)
f | fred | execute | no | ai | loading_uri(column_name name, retries integer)
f | jill | execute | YES | ai | loading_uri(column_name name, retries integer)
f | alice | execute | YES | ai | ollama_chat_complete(model text, messages jsonb, host text, keep_alive text, chat_options jsonb, tools jsonb, response_format jsonb, "verbose" boolean)
f | bob | execute | no | ai | ollama_chat_complete(model text, messages jsonb, host text, keep_alive text, chat_options jsonb, tools jsonb, response_format jsonb, "verbose" boolean)
f | fred | execute | no | ai | ollama_chat_complete(model text, messages jsonb, host text, keep_alive text, chat_options jsonb, tools jsonb, response_format jsonb, "verbose" boolean)
27 changes: 27 additions & 0 deletions projects/extension/tests/vectorizer/test_loading.py
@@ -21,6 +21,16 @@ def test_loading_column():
"config_type": "loading",
"implementation": "column",
"column_name": "content",
"retries": 6,
},
),
(
"select ai.loading_column('content', 10)",
{
"config_type": "loading",
"implementation": "column",
"column_name": "content",
"retries": 10,
},
),
]
@@ -43,6 +53,16 @@ def test_loading_uri():
"config_type": "loading",
"implementation": "uri",
"column_name": "s3_uri",
"retries": 6,
},
),
(
"select ai.loading_uri('s3_uri', 3)",
{
"config_type": "loading",
"implementation": "uri",
"column_name": "s3_uri",
"retries": 3,
},
),
]
@@ -126,6 +146,13 @@ def test_validate_loading():
""",
"invalid config_type for loading config",
),
(
"""
select ai._validate_loading
( ai.loading_column('body', -1), 'public', 'thing' )
""",
"invalid loading config, retries must be a non-negative integer",
),
]
with psycopg.connect(db_url("test"), autocommit=True) as con:
with con.cursor() as cur:
16 changes: 8 additions & 8 deletions projects/extension/tests/vectorizer/test_vectorizer.py
@@ -21,6 +21,7 @@
"loading": {
"config_type": "loading",
"implementation": "column",
"retries": 6,
"column_name": "body"
},
"parsing": {
@@ -136,14 +137,13 @@


QUEUE_TABLE = """
Table "ai._vectorizer_q_1"
Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description
-------------+--------------------------+-----------+----------+---------+----------+-------------+--------------+-------------
title | text | | not null | | extended | | |
published | timestamp with time zone | | not null | | plain | | |
queued_at | timestamp with time zone | | not null | now() | plain | | |
retries | integer | | not null | 0 | plain | | |
retry_after | timestamp with time zone | | | | plain | | |
Table "ai._vectorizer_q_1"
Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description
---------------------+--------------------------+-----------+----------+---------+---------+-------------+--------------+-------------
id | integer | | not null | | plain | | |
queued_at | timestamp with time zone | | not null | now() | plain | | |
loading_retries | integer | | not null | 0 | plain | | |
loading_retry_after | timestamp with time zone | | | | plain | | |
Indexes:
"_vectorizer_q_1_title_published_idx" btree (title, published)
Access method: heap
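With the renamed columns, a worker only picks up rows whose backoff has elapsed. A simplified sketch of the fetch against this queue table, assuming a batch size of 10 (the full query appears in vectorizer.py below):

select id
from ai._vectorizer_q_1
where loading_retry_after is null or loading_retry_after < now()
limit 10
for update skip locked;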
7 changes: 7 additions & 0 deletions projects/pgai/.vscode/settings.json
@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
32 changes: 13 additions & 19 deletions projects/pgai/pgai/cli.py
@@ -136,21 +136,20 @@ def run_vectorizer(
db_url: str,
vectorizer: Vectorizer,
concurrency: int,
loading_retries: int,
features: Features,
) -> None:
async def run_workers(
db_url: str, vectorizer: Vectorizer, concurrency: int, loading_retries: int
db_url: str,
vectorizer: Vectorizer,
concurrency: int,
) -> list[int]:
tasks = [
asyncio.create_task(
Worker(db_url, vectorizer, features, loading_retries).run()
)
asyncio.create_task(Worker(db_url, vectorizer, features).run())
for _ in range(concurrency)
]
return await asyncio.gather(*tasks)

results = asyncio.run(run_workers(db_url, vectorizer, concurrency, loading_retries))
results = asyncio.run(run_workers(db_url, vectorizer, concurrency))
items = sum(results)
log.info("finished processing vectorizer", items=items, vectorizer_id=vectorizer.id)

@@ -233,7 +232,10 @@ def shutdown_handler(signum: int, _frame: Any):
type=TimeDurationParamType(),
default="5m",
show_default=True,
help="The interval, in duration string or integer (seconds), to wait before checking for new work after processing all available work in the queue.", # noqa
help="The interval, in duration string or integer (seconds), "
"to wait before checking for new work after processing "
"all available work in the queue.",
# noqa
)
@click.option(
"--once",
@@ -250,13 +252,6 @@ def shutdown_handler(signum: int, _frame: Any):
show_default=True,
help="Exit immediately when an error occurs.",
)
@click.option(
"--loading-retries",
type=click.INT,
default=6,
show_default=True,
help="Number of retries for loading processing.",
)
def vectorizer_worker(
db_url: str,
vectorizer_ids: Sequence[int],
@@ -265,7 +260,6 @@ def vectorizer_worker(
poll_interval: int,
once: bool,
exit_on_error: bool | None,
loading_retries: int,
) -> None:
# gracefully handle being asked to shut down
signal.signal(signal.SIGINT, shutdown_handler)
@@ -313,7 +307,9 @@ def vectorizer_worker(
)
if len(valid_vectorizer_ids) != len(vectorizer_ids):
log.error(
f"invalid vectorizers, wanted: {list(vectorizer_ids)}, got: {valid_vectorizer_ids}" # noqa: E501 (line too long)
f"invalid vectorizers, wanted: {list(vectorizer_ids)},"
f" got: {valid_vectorizer_ids}"
# noqa: E501 (line too long)
)
if exit_on_error:
sys.exit(1)
@@ -329,9 +325,7 @@
try:
vectorizer = get_vectorizer(db_url, vectorizer_id)
log.info("running vectorizer", vectorizer_id=vectorizer_id)
run_vectorizer(
db_url, vectorizer, concurrency, loading_retries, features
)
run_vectorizer(db_url, vectorizer, concurrency, features)
except (VectorizerNotFoundError, ApiKeyNotFoundError) as e:
log.error(
f"error getting vectorizer: {type(e).__name__}: {str(e)} "
2 changes: 2 additions & 0 deletions projects/pgai/pgai/vectorizer/loading.py
@@ -27,6 +27,7 @@ def guess_filetype(file_like: BytesIO, file_path: str | None = None) -> str | None:
class RowLoading(BaseModel):
implementation: Literal["column"]
column_name: str
retries: int = 6

def load(self, row: dict[str, str]) -> str | LoadedDocument:
content = row[self.column_name] or ""
@@ -42,6 +43,7 @@ def load(self, row: dict[str, str]) -> str | LoadedDocument:
class UriLoading(BaseModel):
implementation: Literal["uri"]
column_name: str
retries: int = 6

def load(self, row: dict[str, str]) -> LoadedDocument:
content = BytesIO(
18 changes: 8 additions & 10 deletions projects/pgai/pgai/vectorizer/vectorizer.py
@@ -234,7 +234,7 @@ def fetch_work_query(self) -> sql.Composed:
WITH selected_rows AS (
SELECT {pk_fields}
FROM {queue_table}
WHERE retry_after is null or retry_after < now()
WHERE loading_retry_after is null or loading_retry_after < now()
LIMIT %s
FOR UPDATE SKIP LOCKED
),
Expand Down Expand Up @@ -374,19 +374,19 @@ def requeue_or_remove_work_query(self) -> sql.Composed:
WHERE
({pk_fields}) IN ({pk_values})
AND
(retry_after is null or retry_after < now())
(loading_retry_after is null or loading_retry_after < now())
) AS source
ON {merge_predicates}
WHEN MATCHED
AND target.retries >= %(loading_retries)s THEN
AND target.loading_retries >= %(loading_retries)s THEN
DELETE
WHEN MATCHED THEN
UPDATE
SET
retries = target.retries + 1,
retry_after = now() +
(INTERVAL '3 minutes' * (target.retries + 1))
RETURNING target.retries < %(loading_retries)s AS is_retryable
loading_retries = target.loading_retries + 1,
loading_retry_after = now() +
(INTERVAL '3 minutes' * (target.loading_retries + 1))
RETURNING target.loading_retries < %(loading_retries)s AS is_retryable
""").format(
pk_fields=self.pk_fields_sql,
queue_table=sql.Identifier(
@@ -504,12 +504,10 @@ def __init__(
db_url: str,
vectorizer: Vectorizer,
features: Features,
loading_retries: int,
should_continue_processing_hook: None | Callable[[int, int], bool] = None,
):
self.db_url = db_url
self.vectorizer = vectorizer
self.loading_retries = loading_retries
self.queries = VectorizerQueryBuilder(vectorizer)
self._should_continue_processing_hook = should_continue_processing_hook or (
lambda _loops, _res: True
@@ -566,7 +564,7 @@ async def run(self) -> int:
except UriLoadingError as e:
async with conn.transaction():
is_retryable = await self._requeue_or_remove_work(
conn, self.loading_retries, e.pk_values
conn, self.vectorizer.config.loading.retries, e.pk_values
)
await self._insert_vectorizer_error(
conn,
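The reworked MERGE gives failed rows a linear backoff: each failure increments loading_retries and pushes loading_retry_after out by 3 minutes times the new count, and a row that fails again after exhausting the configured limit (default 6) is deleted, so a row gets at most 7 load attempts. The schedule can be tabulated directly:

-- backoff per failure for the default limit of 6 retries
select n as loading_retries,
       interval '3 minutes' * n as wait_before_next_attempt
from generate_series(1, 6) as n;
-- 1 -> 3 min, 2 -> 6 min, ... 6 -> 18 min; a 7th failure deletes the row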
4 changes: 2 additions & 2 deletions projects/pgai/tests/vectorizer/cli/test_vectorizer_core.py
@@ -343,7 +343,7 @@ def should_continue_processing_hook(_loops: int, _res: int) -> bool:
):
results = asyncio.run(
Worker(
cli_db_url, vectorizer, features, 6, should_continue_processing_hook
cli_db_url, vectorizer, features, should_continue_processing_hook
).run()
)
# Then it successfully exits after the first batch.
@@ -412,7 +412,7 @@ def test_disabled_vectorizer_is_backwards_compatible(

# When the vectorizer is executed.
with vcr_.use_cassette("test_disabled_vectorizer_is_backwards_compatible.yaml"):
results = asyncio.run(Worker(cli_db_url, vectorizer, features, 6).run())
results = asyncio.run(Worker(cli_db_url, vectorizer, features).run())

# Then the disable is ignored and the vectorizer successfully exits after
# processing the batches.