90 fix test weighted search fail in ailab db #102

Draft: wants to merge 6 commits into main
12 changes: 12 additions & 0 deletions .editorconfig
@@ -0,0 +1,12 @@
root = true

# Unix-style newlines with a newline ending every file
[*]
charset = utf-8
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true

# Exclude binary files
[*.svg]
insert_final_newline = false
48 changes: 17 additions & 31 deletions DEVELOPER.md
@@ -4,38 +4,24 @@

### Run latest schema locally

1. Set up `.env` environment variables (a sample `.env` follows these steps):
    * **LOUIS_DSN:** Data Source Name (DSN) used for configuring a database connection in Louis's system. It should follow this pattern, replacing each variable with your own values:
`LOUIS_DSN=postgresql://PGUSER:PGPASSWORD@DB_SERVER_CONTAINER_NAME/PGBASE`

    * **PGBASE:** The base directory where PostgreSQL-related files or resources are stored or accessed. It can be the name of your folder followed by test (ex: louis-test).
    * **PGUSER:** The username or role used to authenticate and access the PostgreSQL database.
    * **USER:** The username required for validation and access; it can be the same as PGUSER.
    * **PGHOST:** The hostname or IP address of the server where the PostgreSQL database is hosted. To run locally, it should be `localhost`.
    * **PGPASSWORD:** The password used to authenticate the user when connecting to the PostgreSQL database.
    * **POSTGRES_PASSWORD:** The database password used for authentication when connecting to the PostgreSQL database.
    * **PGDATA:** Path to the directory where PostgreSQL data files are stored. If it is not set, a default location is selected automatically.
    * **OPENAI_API_KEY:** The API key required for authentication when making requests to the OpenAI API. It can be found [here](https://portal.azure.com/#home).
    * **OPENAI_ENDPOINT:** The URL used to call Azure OpenAI endpoints. It can be found in the same place as the OPENAI_API_KEY.
    * **OPENAI_API_ENGINE:** The name of the model deployment you want to use (ex: ailab-gpt-35-turbo).
    * **LOUIS_SCHEMA:** The Louis schema within the database (ex: louis_v005).
    * **DB_SERVER_CONTAINER_NAME:** The name of your database server container (ex: louis-db-server).
    * **AILAB_SCHEMA_VERSION:** The version of the schema you want to use.
1. Run database locally (see bin/postgres.sh)
1. Restore latest schema dump
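
For reference, a complete `.env` might look like the sketch below. Every value is an illustrative placeholder modelled on the examples above (the AILAB_SCHEMA_VERSION value in particular is an assumption); substitute your own credentials.

```
# hypothetical example values -- replace with your own
PGUSER=louis
PGPASSWORD=supersecret
PGBASE=louis-test
USER=louis
PGHOST=localhost
POSTGRES_PASSWORD=supersecret
DB_SERVER_CONTAINER_NAME=louis-db-server
LOUIS_DSN=postgresql://louis:supersecret@louis-db-server/louis-test
LOUIS_SCHEMA=louis_v005
AILAB_SCHEMA_VERSION=louis_v005
OPENAI_API_KEY=<your-azure-openai-key>
OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com/
OPENAI_API_ENGINE=ailab-gpt-35-turbo
```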

### Before every change

1 change: 1 addition & 0 deletions sql/2024-03-20-html_content_to_chunk_chunk_id_idx.sql
@@ -0,0 +1 @@
create index html_content_to_chunk_chunk_id_idx on html_content_to_chunk(chunk_id);
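
This index supports the lookup of content rows by chunk id used when joining chunks back to their HTML documents. A quick sketch of the access path it enables (the uuid literal is a placeholder, and chunk_id is assumed to be a uuid column):

-- hypothetical check that the planner now uses an index scan
explain select md5hash_uuid
from html_content_to_chunk
where chunk_id = '00000000-0000-0000-0000-000000000000';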
11 changes: 11 additions & 0 deletions sql/2024-03-20-md5hash-to-uuid-with-indexes.sql
@@ -0,0 +1,11 @@
-- add uuid shadow columns alongside the existing text md5hash columns
ALTER TABLE html_content ADD md5hash_uuid uuid;
ALTER TABLE html_content_to_chunk ADD md5hash_uuid uuid;
ALTER TABLE crawl ADD md5hash_uuid uuid;

-- backfill by casting the existing md5 hex text to uuid
update html_content set md5hash_uuid = md5hash::uuid;
update html_content_to_chunk set md5hash_uuid = md5hash::uuid;
update crawl set md5hash_uuid = md5hash::uuid;

-- index the new uuid columns used by the joins in the search function
create index html_content_to_chunk_md5hash_uuid_idx on html_content_to_chunk(md5hash_uuid);
create index html_content_md5hash_uuid_idx on html_content(md5hash_uuid);
create index crawl_md5hash_uuid_idx on crawl(md5hash_uuid);
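
After the backfill, a quick sanity check might confirm that every md5hash converted cleanly (a sketch; it assumes md5hash is populated on all rows):

-- hypothetical post-migration check: all three counts should be zero
select count(*) from html_content where md5hash_uuid is null;
select count(*) from html_content_to_chunk where md5hash_uuid is null;
select count(*) from crawl where md5hash_uuid is null;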
130 changes: 130 additions & 0 deletions sql/2024-03-20-weighted-search-updated.sql
@@ -0,0 +1,130 @@
DROP FUNCTION if exists search(text,vector,double precision,integer,jsonb);
CREATE OR REPLACE FUNCTION search(
query text,
query_embedding vector,
match_threshold double precision,
match_count integer,
weights_parameter JSONB
)
RETURNS JSONB
AS
$BODY$
DECLARE
similarity_search boolean;
query_id uuid;
query_result JSONB;
BEGIN

-- keep only the score types that carry a positive weight
create temp table weights (
score_type score_type,
weight double precision
) on commit drop;

INSERT INTO weights
SELECT key::score_type, value::float
FROM jsonb_each_text(weights_parameter)
WHERE value::float > 0.0;

-- check if there is a query embedding, if not skip similarity search
if query_embedding is null then
similarity_search := false;
else
similarity_search := true;
-- cache the query and embedding; an identical earlier query may already hold a result
insert into query(query, embedding) values (query, query_embedding)
on conflict do nothing
returning id, result into query_id, query_result;

if query_result is not null then
return query_result;
end if;
end if;

if similarity_search then
-- over-fetch the nearest chunks by embedding distance (10x match_count) for reranking
create temp table selected_chunks on commit drop as
select
ada_002.id as embedding_id,
chunk.id as chunk_id,
crawl.id as crawl_id,
score,
score_type,
ada_002.embedding <=> query_embedding as closeness
from ada_002
inner join "token" on ada_002.token_id = "token".id
inner join "chunk" on "token".chunk_id = chunk.id
inner join html_content_to_chunk hctc on hctc.chunk_id = chunk.id
inner join html_content hc on hctc.md5hash_uuid = hc.md5hash_uuid
inner join crawl on hc.md5hash_uuid = crawl.md5hash_uuid
inner join "score" on crawl.id = score.entity_id
order by ada_002.embedding <=> query_embedding
limit match_count*10;


-- keep each document's closest chunk; convert cosine distance to similarity (1 - distance)
create temp table similarity on commit drop as
with chunk_closest_by_id as (
select c.crawl_id, min(c.closeness) as closeness
from selected_chunks c
group by c.crawl_id
)
select c.crawl_id as crawl_id, c.chunk_id as chunk_id, (1-c.closeness) as score
from selected_chunks c
inner join chunk_closest_by_id cc on c.crawl_id = cc.crawl_id and c.closeness = cc.closeness
order by score desc
limit match_count;

create temp table measurements on commit drop as
select s.crawl_id as id, s.chunk_id, 'similarity'::score_type as score_type, s.score as score
from similarity s;

insert into measurements
select entity_id as id, chunk_id, s.score_type, s.score
from score s
inner join similarity sim on s.entity_id = sim.crawl_id
where s.entity_id in (select m.id from measurements m);
else
-- take the longest chunk per document based on the number of tokens_count
-- grouped by url to avoid duplicates

create or replace temp view measurements as
select entity_id as id, chunk_id, s.score_type, s.score
from default_chunk d
inner join score s on s.entity_id = d.id
where s.score_type in (select w.score_type from weights w);
end if;

-- final ranking: average the weighted scores per document and keep the top match_count
with matches as (
select m.id, m.chunk_id,
avg(m.score * w.weight) as score,
jsonb_object_agg(m.score_type, m.score) as scores
from measurements m
inner join weights w on m.score_type = w.score_type
group by m.id, m.chunk_id
order by score desc
limit match_count
)
select json_agg(r) as search from (
select
query_id,
c.id,
c.url,
c.title,
ck.title as subtitle,
ck.text_content as content,
c.last_updated,
m.score,
m.scores
from matches m
inner join crawl c on m.id = c.id
inner join chunk ck on m.chunk_id = ck.id
order by m.score desc
limit match_count
) r into query_result;

if query_id is not null then
update query set result = query_result where id = query_id;
end if;

return query_result;

END;
$BODY$
LANGUAGE plpgsql;
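
As a usage sketch, the function can be called the way the updated test does. Passing a null embedding exercises the non-similarity branch (avoiding the need for a full embedding vector literal here); note that match_threshold is declared but never referenced in the body above. The weight keys mirror tests/test_db_api.py and are assumed to be valid score_type values:

-- hypothetical call; a null embedding skips the similarity search
select search(
    'what is weighted search?',    -- query text
    null::vector,                  -- query_embedding
    0.5,                           -- match_threshold (unused in the body)
    10,                            -- match_count
    '{"recency": 1.0, "traffic": 1.0, "current": 0.5}'::jsonb
);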
8 changes: 5 additions & 3 deletions tests/test_db_api.py
@@ -24,14 +24,16 @@ def test_weighted_search(self):
weights = json.dumps(
{'similarity': 1.0, 'typicality': 0.2, 'recency': 1.0, 'traffic': 1.0, 'current': 0.5})
self.cursor.execute(
"SELECT * FROM search(%s, %s::vector, %s::float, %s::integer, %s::jsonb)", (
"SELECT * FROM searchmodified(%s, %s::vector, %s::float, %s::integer, %s::jsonb)", (
query, embeddings, test.MATCH_THRESHOLD, test.MATCH_COUNT, weights))
results = self.cursor.fetchall()
- result = results[0]['search']
+ result = results[0]['searchmodified']
for res in result:
print(res['title'])
self.assertEqual(
result[0]['title'],
"Dr. Harpreet S. Kochhar - Canadian Food Inspection Agency")

query_id = result[0]['query_id']
self.cursor.execute("SELECT * FROM query where id = %s::uuid", (query_id,))
result = self.cursor.fetchall()