Skip to content

Commit

Permalink
Add support for halfvec and bit vector types with PGVector ANN, closes
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Dec 10, 2024
1 parent d39b85e commit 5e56dad
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 43 deletions.
1 change: 1 addition & 0 deletions docs/embeddings/configuration/ann.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ pgvector:
schema: database schema to store vectors - defaults to being
determined by the database
table: database table to store vectors - defaults to `vectors`
precision: vector float precision (half or full) - defaults to `full`
efconstruction: ef_construction param (int) - defaults to 200
m: M param for init_index (int) - defaults to 16
```
Expand Down
2 changes: 1 addition & 1 deletion docs/embeddings/configuration/vectors.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ Enables truncation of vectors to this dimensionality. This is only useful for mo
quantize: int|boolean
```

Enables scalar vector quantization at the specified precision. Supports 1-bit through 8-bit quantization. Scalar quantization transforms continuous floating point values to discrete unsigned integers. Only the `faiss`, `numpy` and `torch` ANN backends support storing these vectors.
Enables scalar vector quantization at the specified precision. Supports 1-bit through 8-bit quantization. Scalar quantization transforms continuous floating point values to discrete unsigned integers. The `faiss`, `pgvector`, `numpy` and `torch` ANN backends support storing these vectors.

This parameter supports booleans for backwards compatibility. When set to true/false, this flag sets [faiss.quantize](../ann/#faiss).

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
extras["ann"] = [
"annoy>=1.16.3",
"hnswlib>=0.5.0",
"pgvector>=0.2.5",
"pgvector>=0.3.0",
"sqlalchemy>=2.0.20",
"sqlite-vec>=0.1.1",
]
Expand Down
141 changes: 117 additions & 24 deletions src/python/txtai/ann/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

import os

import numpy as np

# Conditional import
try:
from pgvector.sqlalchemy import Vector
from pgvector.sqlalchemy import BIT, HALFVEC, VECTOR

from sqlalchemy import create_engine, delete, func, text, Column, Index, Integer, MetaData, StaticPool, Table
from sqlalchemy.orm import Session
Expand All @@ -33,6 +35,10 @@ def __init__(self, config):
# Database connection
self.engine, self.database, self.connection, self.table = None, None, None, None

# Scalar quantization
quantize = self.config.get("quantize")
self.qbits = quantize if quantize and isinstance(quantize, int) and not isinstance(quantize, bool) else None

def load(self, path):
# Initialize tables
self.initialize()
Expand All @@ -41,14 +47,18 @@ def index(self, embeddings):
# Initialize tables
self.initialize(recreate=True)

self.database.execute(self.table.insert(), [{"indexid": x, "embedding": row} for x, row in enumerate(embeddings)])
# Prepare embeddings and insert rows
self.database.execute(self.table.insert(), [{"indexid": x, "embedding": self.prepare(row)} for x, row in enumerate(embeddings)])

# Add id offset and index build metadata
self.config["offset"] = embeddings.shape[0]
self.metadata(self.settings())

def append(self, embeddings):
self.database.execute(self.table.insert(), [{"indexid": x + self.config["offset"], "embedding": row} for x, row in enumerate(embeddings)])
# Prepare embeddings and insert rows
self.database.execute(
self.table.insert(), [{"indexid": x + self.config["offset"], "embedding": self.prepare(row)} for x, row in enumerate(embeddings)]
)

# Update id offset and index metadata
self.config["offset"] += embeddings.shape[0]
Expand All @@ -61,14 +71,10 @@ def search(self, queries, limit):
results = []
for query in queries:
# Run query
query = (
self.database.query(self.table.c["indexid"], self.table.c["embedding"].max_inner_product(query).label("score"))
.order_by("score")
.limit(limit)
)
query = self.database.query(self.table.c["indexid"], self.query(query)).order_by("score").limit(limit)

# pgvector returns negative inner product since Postgres only supports ASC order index scans on operators
results.append([(indexid, -score) for indexid, score in query])
# Calculate and collect scores
results.append([(indexid, self.score(score)) for indexid, score in query])

return results

Expand Down Expand Up @@ -101,32 +107,25 @@ def initialize(self, recreate=False):
# Connect to database
self.connect()

# Set default schema, if necessary
schema = self.setting("schema")
if schema:
with self.engine.begin():
self.sqldialect(CreateSchema(schema, if_not_exists=True))

self.sqldialect(text("SET search_path TO :schema,public"), {"schema": schema})
# Set the database schema
self.schema()

# Table name
table = self.setting("table", "vectors")

# Get embedding column and index settings
column, index = self.column()

# Create vectors table
self.table = Table(
table,
MetaData(),
Column("indexid", Integer, primary_key=True, autoincrement=False),
Column("embedding", Vector(self.config["dimensions"])),
)
self.table = Table(table, MetaData(), Column("indexid", Integer, primary_key=True, autoincrement=False), Column("embedding", column))

# Create ANN index - inner product is equal to cosine similarity on normalized vectors
index = Index(
f"{table}-index",
self.table.c["embedding"],
postgresql_using="hnsw",
postgresql_with=self.settings(),
postgresql_ops={"embedding": "vector_ip_ops"},
postgresql_ops={"embedding": index},
)

# Drop and recreate table
Expand Down Expand Up @@ -157,6 +156,19 @@ def connect(self):
# Initialize pgvector extension
self.sqldialect(text("CREATE EXTENSION IF NOT EXISTS vector"))

def schema(self):
    """
    Applies the configured database schema, when one is set.
    """

    # Read configured schema name - nothing to do when unset
    name = self.setting("schema")
    if not name:
        return

    # Create the schema if it doesn't already exist
    with self.engine.begin():
        self.sqldialect(CreateSchema(name, if_not_exists=True))

    # Prepend schema to the session search path
    self.sqldialect(text("SET search_path TO :schema,public"), {"schema": name})

def settings(self):
"""
Returns settings for this index.
Expand All @@ -178,3 +190,84 @@ def sqldialect(self, sql, parameters=None):

args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (text("SELECT 1"),)
self.database.execute(*args)

def column(self):
    """
    Gets embedding column and index definitions for the current settings.

    Returns:
        embedding column definition, index definition
    """

    dimensions = self.config["dimensions"]

    # Quantized vectors are stored as bit strings using hamming distance
    if self.qbits:
        return BIT(dimensions * 8), "bit_hamming_ops"

    # 16-bit half precision float vectors
    if self.setting("precision") == "half":
        return HALFVEC(dimensions), "halfvec_ip_ops"

    # Default: 32-bit full precision float vectors
    return VECTOR(dimensions), "vector_ip_ops"

def prepare(self, data):
    """
    Prepares data for the embeddings column. Bit vectors are encoded as "0"/"1"
    strings, float vectors pass through unmodified.

    Args:
        data: input data

    Returns:
        data ready for the embeddings column
    """

    # Float vectors are stored as-is
    if not self.qbits:
        return data

    # Expand packed bytes into individual bits and render as a bit string
    return "".join("1" if bit else "0" for bit in np.unpackbits(data))

def query(self, query):
    """
    Builds a query statement for the input query embeddings. Hamming distance is used
    for bit vectors, max inner product for float vectors.

    Args:
        query: input query

    Returns:
        query statement
    """

    # Transform query embeddings into the stored column format
    embedding = self.prepare(query)
    column = self.table.c["embedding"]

    # Select the distance operator matching the column type
    distance = column.hamming_distance(embedding) if self.qbits else column.max_inner_product(embedding)

    return distance.label("score")

def score(self, score):
    """
    Converts a raw database score into an index score. Bit vectors return a normalized
    hamming score (1.0 - (hamming distance / total number of bits)), float vectors
    return the negated inner product.

    Args:
        score: input score

    Returns:
        index score
    """

    if self.qbits:
        # Normalize hamming distance by total number of bits, invert so higher is better
        # and clamp to the [0.0, 1.0] range
        bits = self.config["dimensions"] * 8
        return min(max(0.0, 1.0 - score / bits), 1.0)

    # pgvector returns negative inner product since Postgres only supports ASC order index scans on operators
    return -score
43 changes: 26 additions & 17 deletions test/python/testann.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,30 +171,39 @@ def testPGVector(self, query):
# Mock database query
query.return_value = [(x, -1.0) for x in range(data.shape[0])]

configs = [
("full", {"dimensions": 240}, {}, data),
("half", {"dimensions": 240}, {"precision": "half"}, data),
("binary", {"quantize": 1, "dimensions": 240 * 8}, {}, data.astype(np.uint8)),
]

# Create ANN
path = os.path.join(tempfile.gettempdir(), "pgvector.sqlite")
ann = ANNFactory.create({"backend": "pgvector", "pgvector": {"url": f"sqlite:///{path}", "schema": "txtai"}, "dimensions": 240})
for name, config, pgvector, data in configs:
path = os.path.join(tempfile.gettempdir(), f"pgvector.{name}.sqlite")
ann = ANNFactory.create(
{**{"backend": "pgvector", "pgvector": {**{"url": f"sqlite:///{path}", "schema": "txtai"}, **pgvector}}, **config}
)

# Test indexing
ann.index(data)
ann.append(data)
# Test indexing
ann.index(data)
ann.append(data)

# Validate search results
self.assertEqual(ann.search(data, 1), [[(0, 1.0)]])
# Validate search results
self.assertEqual(ann.search(data, 1), [[(0, 1.0)]])

# Validate save/load/delete
ann.save(None)
ann.load(None)
# Validate save/load/delete
ann.save(None)
ann.load(None)

# Validate count
self.assertEqual(ann.count(), 2)
# Validate count
self.assertEqual(ann.count(), 2)

# Test delete
ann.delete([0])
self.assertEqual(ann.count(), 1)
# Test delete
ann.delete([0])
self.assertEqual(ann.count(), 1)

# Close ANN
ann.close()
# Close ANN
ann.close()

@unittest.skipIf(platform.system() == "Darwin", "SQLite extensions not supported on macOS")
def testSQLite(self):
Expand Down

0 comments on commit 5e56dad

Please sign in to comment.