From 5a9945499fbdc5d5974dbcf9ac0c76a55acc07bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Blanc?= Date: Sat, 2 Dec 2023 18:03:31 +0100 Subject: [PATCH] pgvector store --- .../pgvector/PgVectorEmbeddingStore.java | 301 ++++++++++++++++++ 1 file changed, 301 insertions(+) diff --git a/pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/PgVectorEmbeddingStore.java b/pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/PgVectorEmbeddingStore.java index e69de29bb..60cf84ee9 100644 --- a/pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/PgVectorEmbeddingStore.java +++ b/pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/PgVectorEmbeddingStore.java @@ -0,0 +1,301 @@ +package io.quarkiverse.langchain4j.pgvector; + +import static dev.langchain4j.internal.Utils.*; +import static dev.langchain4j.internal.ValidationUtils.*; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toList; + +import java.sql.*; +import java.util.*; + +import org.postgresql.util.PSQLException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.pgvector.PGvector; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingStore; +import io.agroal.api.AgroalDataSource; +import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; +import io.quarkus.logging.Log; + +/** + * PGVector EmbeddingStore Implementation + *

+ * Only cosine similarity is used. + * Only ivfflat index is used. + */ +public class PgVectorEmbeddingStore implements EmbeddingStore { + + ObjectMapper objectMapper = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER; + private static final Logger log = LoggerFactory.getLogger(PgVectorEmbeddingStore.class); + private static final TypeReference> typeReference = new TypeReference>() { + }; + private final AgroalDataSource datasource; + private final String table; + private Statement statement; + + /** + * All args constructor for PgVectorEmbeddingStore Class + * + * @param datasource , the datasource object + * @param table The database table + * @param dimension The vector dimension + * @param useIndex Should use IVFFlat index + * @param indexListSize The IVFFlat number of lists + * @param createTable Should create table automatically + * @param dropTableFirst Should drop table first, usually for testing + */ + public PgVectorEmbeddingStore( + AgroalDataSource datasource, + String table, + Integer dimension, + Boolean useIndex, + Integer indexListSize, + Boolean createTable, + Boolean dropTableFirst) { + this.datasource = datasource; + this.table = ensureNotBlank(table, "table"); + + useIndex = getOrDefault(useIndex, false); + createTable = getOrDefault(createTable, true); + dropTableFirst = getOrDefault(dropTableFirst, false); + try (Connection connection = setupConnection()) { + if (dropTableFirst) { + statement = connection.createStatement(); + statement.executeUpdate(String.format("DROP TABLE IF EXISTS %s", table)); + statement.close(); + } + + if (createTable) { + statement = connection.createStatement(); + statement.executeUpdate(String.format( + "CREATE TABLE IF NOT EXISTS %s (" + + "embedding_id UUID PRIMARY KEY, " + + "embedding vector(%s), " + + "text TEXT NULL, " + + "metadata JSON NULL" + + ")", + table, ensureGreaterThanZero(dimension, "dimension"))); + statement.close(); + } + + if (useIndex) { + statement = connection.createStatement(); + statement.executeUpdate(String.format( + "CREATE INDEX IF NOT EXISTS ON %s " + + "USING ivfflat (embedding vector_cosine_ops) " + + "WITH (lists = %s)", + table, ensureGreaterThanZero(indexListSize, "indexListSize"))); + statement.close(); + } + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + private Connection setupConnection() throws SQLException { + Connection connection = datasource.getConnection(); + try { + statement = connection.createStatement(); + statement.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector"); + statement.close(); + } catch (PSQLException exception) { + if (exception.getMessage().contains("could not open extension control file")) { + Log.error( + "The PostgreSQL server does not seem to support pgvector." + + "If using containers/devservices we suggest to use quarkus.datasource.devservices.image-name=ankane/pgvector:v0.5.1"); + } else { + throw exception; + } + } + + PGvector.addVectorType(connection); + return connection; + } + + public void deleteAll() throws SQLException { + try (Connection connection = setupConnection()) { + statement = connection.createStatement(); + statement.executeUpdate(String.format("TRUNCATE TABLE %s", table)); + statement.close(); + } + } + + /** + * Adds a given embedding to the store. + * + * @param embedding The embedding to be added to the store. + * @return The auto-generated ID associated with the added embedding. + */ + @Override + public String add(Embedding embedding) { + String id = randomUUID(); + addInternal(id, embedding, null); + return id; + } + + /** + * Adds a given embedding to the store. + * + * @param id The unique identifier for the embedding to be added. + * @param embedding The embedding to be added to the store. + */ + @Override + public void add(String id, Embedding embedding) { + addInternal(id, embedding, null); + } + + /** + * Adds a given embedding and the corresponding content that has been embedded to the store. + * + * @param embedding The embedding to be added to the store. + * @param textSegment Original content that was embedded. + * @return The auto-generated ID associated with the added embedding. + */ + @Override + public String add(Embedding embedding, TextSegment textSegment) { + String id = randomUUID(); + addInternal(id, embedding, textSegment); + return id; + } + + /** + * Adds multiple embeddings to the store. + * + * @param embeddings A list of embeddings to be added to the store. + * @return A list of auto-generated IDs associated with the added embeddings. + */ + @Override + public List addAll(List embeddings) { + List ids = embeddings.stream().map(ignored -> randomUUID()).collect(toList()); + addAllInternal(ids, embeddings, null); + return ids; + } + + /** + * Adds multiple embeddings and their corresponding contents that have been embedded to the store. + * + * @param embeddings A list of embeddings to be added to the store. + * @param embedded A list of original contents that were embedded. + * @return A list of auto-generated IDs associated with the added embeddings. + */ + @Override + public List addAll(List embeddings, List embedded) { + List ids = embeddings.stream().map(ignored -> randomUUID()).collect(toList()); + addAllInternal(ids, embeddings, embedded); + return ids; + } + + /** + * Finds the most relevant (closest in space) embeddings to the provided reference embedding. + * + * @param referenceEmbedding The embedding used as a reference. Returned embeddings should be relevant (closest) to this + * one. + * @param maxResults The maximum number of embeddings to be returned. + * @param minScore The minimum relevance score, ranging from 0 to 1 (inclusive). + * Only embeddings with a score of this value or higher will be returned. + * @return A list of embedding matches. + * Each embedding match includes a relevance score (derivative of cosine distance), + * ranging from 0 (not relevant) to 1 (highly relevant). + */ + @Override + public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { + List> result = new ArrayList<>(); + try (Connection connection = setupConnection()) { + String referenceVector = Arrays.toString(referenceEmbedding.vector()); + String query = String.format( + "WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, metadata FROM %s) SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;", + referenceVector, table, minScore, maxResults); + PreparedStatement selectStmt = connection.prepareStatement(query); + + ResultSet resultSet = selectStmt.executeQuery(); + while (resultSet.next()) { + double score = resultSet.getDouble("score"); + String embeddingId = resultSet.getString("embedding_id"); + + PGvector vector = (PGvector) resultSet.getObject("embedding"); + Embedding embedding = new Embedding(vector.toArray()); + + String text = resultSet.getString("text"); + TextSegment textSegment = null; + if (isNotNullOrBlank(text)) { + String metadataJson = Optional.ofNullable(resultSet.getString("metadata")).orElse("{}"); + Map metadataMap = objectMapper.readValue(metadataJson, typeReference); + Metadata metadata = new Metadata(new HashMap<>(metadataMap)); + textSegment = TextSegment.from(text, metadata); + } + result.add(new EmbeddingMatch<>(score, embeddingId, embedding, textSegment)); + } + selectStmt.close(); + resultSet.close(); + } catch (SQLException e) { + throw new RuntimeException(e); + } catch (JsonMappingException e) { + throw new RuntimeException(e); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + + return result; + } + + private void addInternal(String id, Embedding embedding, TextSegment embedded) { + addAllInternal( + singletonList(id), + singletonList(embedding), + embedded == null ? null : singletonList(embedded)); + } + + private void addAllInternal( + List ids, List embeddings, List embedded) { + if (isCollectionEmpty(ids) || isCollectionEmpty(embeddings)) { + log.info("Empty embeddings - no ops"); + return; + } + ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size"); + ensureTrue(embedded == null || embeddings.size() == embedded.size(), + "embeddings size is not equal to embedded size"); + + try (Connection connection = setupConnection()) { + String query = String.format( + "INSERT INTO %s (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?)" + + "ON CONFLICT (embedding_id) DO UPDATE SET " + + "embedding = EXCLUDED.embedding," + + "text = EXCLUDED.text," + + "metadata = EXCLUDED.metadata;", + table); + + PreparedStatement upsertStmt = connection.prepareStatement(query); + + for (int i = 0; i < ids.size(); ++i) { + upsertStmt.setObject(1, UUID.fromString(ids.get(i))); + upsertStmt.setObject(2, new PGvector(embeddings.get(i).vector())); + + if (embedded != null && embedded.get(i) != null) { + upsertStmt.setObject(3, embedded.get(i).text()); + Map metadata = embedded.get(i).metadata().asMap(); + upsertStmt.setObject(4, objectMapper.writeValueAsString(metadata), Types.OTHER); + } else { + upsertStmt.setNull(3, Types.VARCHAR); + upsertStmt.setNull(4, Types.OTHER); + } + upsertStmt.addBatch(); + } + + upsertStmt.executeBatch(); + upsertStmt.close(); + + } catch (Exception e) { + throw new RuntimeException(e); + } + } +}