From 5fa2db0f54627754587febfde3fb5c6cf0210038 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Blanc?= Date: Fri, 24 Nov 2023 15:51:51 +0100 Subject: [PATCH 1/2] Adding pgvector as embedding store adding pgvector as embedding store detect if pgvector is installable make sure the exception is related to the missing extension make dimension config property mandatory Update pgvector/runtime/src/main/resources/META-INF/quarkus-extension.yaml Update pgvector/runtime/src/main/resources/META-INF/quarkus-extension.yaml various refactoring based on feedback and added a real test deleted readme Update pgvector/deployment/src/main/java/io/quarkiverse/langchain4j/pgvector/deployment/Langchain4jPgvectorProcessor.java remove useless description Pinecone embedding store adding pgvector as embedding store add documentation generated config doc and added pgvector to doc's pom pgvector store --- .../samples/IngestorExampleWithPgvector.java | 41 +++ .../quarkus-langchain4j-pgvector.adoc | 114 +++++++ docs/modules/ROOT/pages/pgvector-store.adoc | 42 +++ docs/modules/ROOT/pages/pinecone-store.adoc | 14 - docs/pom.xml | 6 + pgvector/deployment/pom.xml | 79 +++++ .../Langchain4jPgvectorProcessor.java | 51 +++ .../test/Langchain4jPgvectorTest.java | 239 ++++++++++++++ pgvector/pom.xml | 18 ++ pgvector/runtime/pom.xml | 72 +++++ .../pgvector/PgVectorEmbeddingStore.java | 301 ++++++++++++++++++ .../runtime/PgVectorEmbeddingStoreConfig.java | 51 +++ .../PgVectorEmbeddingStoreRecorder.java | 29 ++ .../resources/META-INF/quarkus-extension.yaml | 12 + pom.xml | 2 + 15 files changed, 1057 insertions(+), 14 deletions(-) create mode 100644 docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/IngestorExampleWithPgvector.java create mode 100644 docs/modules/ROOT/pages/includes/quarkus-langchain4j-pgvector.adoc create mode 100644 docs/modules/ROOT/pages/pgvector-store.adoc create mode 100644 pgvector/deployment/pom.xml create mode 100644 pgvector/deployment/src/main/java/io/quarkiverse/langchain4j/pgvector/deployment/Langchain4jPgvectorProcessor.java create mode 100644 pgvector/deployment/src/test/java/io/quarkiverse/langchain4j/pgvector/test/Langchain4jPgvectorTest.java create mode 100644 pgvector/pom.xml create mode 100644 pgvector/runtime/pom.xml create mode 100644 pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/PgVectorEmbeddingStore.java create mode 100644 pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/runtime/PgVectorEmbeddingStoreConfig.java create mode 100644 pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/runtime/PgVectorEmbeddingStoreRecorder.java create mode 100644 pgvector/runtime/src/main/resources/META-INF/quarkus-extension.yaml diff --git a/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/IngestorExampleWithPgvector.java b/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/IngestorExampleWithPgvector.java new file mode 100644 index 000000000..f6754843d --- /dev/null +++ b/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/IngestorExampleWithPgvector.java @@ -0,0 +1,41 @@ +package io.quarkiverse.langchain4j.samples; + +import static dev.langchain4j.data.document.splitter.DocumentSplitters.recursive; + +import java.util.List; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import dev.langchain4j.data.document.Document; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStoreIngestor; +import io.quarkiverse.langchain4j.pgvector.PgVectorEmbeddingStore; + +@ApplicationScoped +public class IngestorExampleWithPgvector { + + /** + * The embedding store (the database). + * The bean is provided by the quarkus-langchain4j-pgvector extension. + */ + @Inject + PgVectorEmbeddingStore store; + + /** + * The embedding model (how is computed the vector of a document). + * The bean is provided by the LLM (like openai) extension. + */ + @Inject + EmbeddingModel embeddingModel; + + public void ingest(List documents) { + EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder() + .embeddingStore(store) + .embeddingModel(embeddingModel) + .documentSplitter(recursive(500, 0)) + .build(); + // Warning - this can take a long time... + ingestor.ingest(documents); + } +} \ No newline at end of file diff --git a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-pgvector.adoc b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-pgvector.adoc new file mode 100644 index 000000000..1e5cfcce8 --- /dev/null +++ b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-pgvector.adoc @@ -0,0 +1,114 @@ + +:summaryTableId: quarkus-langchain4j-pgvector +[.configuration-legend] +icon:lock[title=Fixed at build time] Configuration property fixed at build time - All other configuration properties are overridable at runtime +[.configuration-reference.searchable, cols="80,.^10,.^10"] +|=== + +h|[[quarkus-langchain4j-pgvector_configuration]]link:#quarkus-langchain4j-pgvector_configuration[Configuration property] + +h|Type +h|Default + +a| [[quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.table]]`link:#quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.table[quarkus.langchain4j.pgvector.table]` + + +[.description] +-- +The table name for storing embeddings + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_PGVECTOR_TABLE+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_PGVECTOR_TABLE+++` +endif::add-copy-button-to-env-var[] +--|string +|`embeddings` + + +a| [[quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.dimension]]`link:#quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.dimension[quarkus.langchain4j.pgvector.dimension]` + + +[.description] +-- +The dimension of the embedding vectors. This has to be the same as the dimension of vectors produced by the embedding model that you use. For example, AllMiniLmL6V2QuantizedEmbeddingModel produces vectors of dimension 384. OpenAI's text-embedding-ada-002 produces vectors of dimension 1536. + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_PGVECTOR_DIMENSION+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_PGVECTOR_DIMENSION+++` +endif::add-copy-button-to-env-var[] +--|int +|required icon:exclamation-circle[title=Configuration property is required] + + +a| [[quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.use-index]]`link:#quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.use-index[quarkus.langchain4j.pgvector.use-index]` + + +[.description] +-- +Use index or not + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_PGVECTOR_USE_INDEX+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_PGVECTOR_USE_INDEX+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| [[quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.index-list-size]]`link:#quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.index-list-size[quarkus.langchain4j.pgvector.index-list-size]` + + +[.description] +-- +index size + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_PGVECTOR_INDEX_LIST_SIZE+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_PGVECTOR_INDEX_LIST_SIZE+++` +endif::add-copy-button-to-env-var[] +--|int +|`0` + + +a| [[quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.create-table]]`link:#quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.create-table[quarkus.langchain4j.pgvector.create-table]` + + +[.description] +-- +Create table or not + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_PGVECTOR_CREATE_TABLE+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_PGVECTOR_CREATE_TABLE+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`true` + + +a| [[quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.drop-table-first]]`link:#quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.drop-table-first[quarkus.langchain4j.pgvector.drop-table-first]` + + +[.description] +-- +Drop table or not + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_PGVECTOR_DROP_TABLE_FIRST+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_PGVECTOR_DROP_TABLE_FIRST+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + +|=== \ No newline at end of file diff --git a/docs/modules/ROOT/pages/pgvector-store.adoc b/docs/modules/ROOT/pages/pgvector-store.adoc new file mode 100644 index 000000000..17cd162f8 --- /dev/null +++ b/docs/modules/ROOT/pages/pgvector-store.adoc @@ -0,0 +1,42 @@ += Pgvector Document Store for Retrieval Augmented Generation (RAG) + +include::./includes/attributes.adoc[] + +When implementing Retrieval Augmented Generation (RAG), a capable document store is necessary. This guide will explain how to leverage a pgvector database as the document store. + +== Leveraging the pgvector Document Store + +To utilize the Redis document store, you'll need to include the following dependency: + +[source,xml,subs=attributes+] +---- + + io.quarkiverse.langchain4j + quarkus-langchain4j-pgvector + {project-version} + +---- + +This extension will check for a default datasource, ensure you have defined at least one datasource. For detailed guidance, refer to the link:https://quarkus.io/guides/datasource[CONFIGURE DATA SOURCES IN QUARKUS]. + +IMPORTANT: If you plan to use `devservices` be sure to use this property : `quarkus.datasource.devservices.image-name=ankane/pgvector:v0.5.1`. + +IMPORTANT: The pgvector store requires the dimension of the vector to be set. Add the `quarkus.langchain4j.pgvector.dimension` property to your `application.properties` file and set it to the dimension of the vector. The dimension depends on the embedding model you use. +For example, `AllMiniLmL6V2QuantizedEmbeddingModel` produces vectors of dimension 384. OpenAI’s `text-embedding-ada-002` produces vectors of dimension 1536. + +Upon installing the extension, you can utilize the pgvector store using the following code: + +[source,java] +---- +include::{examples-dir}/io/quarkiverse/langchain4j/samples/IngestorExampleWithPgvector.java[] +---- + +== Configuration Settings + +Customize the behavior of the extension by exploring various configuration options: + +include::includes/quarkus-langchain4j-pgvector.adoc[leveloffset=+1,opts=optional] + +== Under the Hood + +Each ingested document is saved as a row in a Postgres table, containing the embedding column stored as a vector. diff --git a/docs/modules/ROOT/pages/pinecone-store.adoc b/docs/modules/ROOT/pages/pinecone-store.adoc index a536f362b..33aa24759 100644 --- a/docs/modules/ROOT/pages/pinecone-store.adoc +++ b/docs/modules/ROOT/pages/pinecone-store.adoc @@ -16,20 +16,6 @@ To make use of the Pinecone document store, you'll need to include the following ---- -The required configuration properties to make the extension work are -`quarkus.langchain4j.pinecone.api-key`, -`quarkus.langchain4j.pinecone.environment`, -`quarkus.langchain4j.pinecone.index-name`, and -`quarkus.langchain4j.pinecone.project-id`. The specified index will be -created if it doesn't exist yet. - -Upon installing the extension, you can utilize the Pinecone embedding store using the following code: - -[source,java] ----- -include::{examples-dir}/io/quarkiverse/langchain4j/samples/IngestorExampleWithPinecone.java[] ----- - == Configuration Settings Customize the behavior of the extension by exploring various configuration options: diff --git a/docs/pom.xml b/docs/pom.xml index b89de7c36..6ff2d280f 100644 --- a/docs/pom.xml +++ b/docs/pom.xml @@ -63,6 +63,11 @@ quarkus-langchain4j-pinecone-deployment ${project.version} + + io.quarkiverse.langchain4j + quarkus-langchain4j-pgvector-deployment + ${project.version} + io.quarkiverse.langchain4j quarkus-langchain4j-hugging-face-deployment @@ -121,6 +126,7 @@ quarkus-langchain4j-redis.adoc quarkus-langchain4j-chroma.adoc quarkus-langchain4j-pinecone.adoc + quarkus-langchain4j-pgvector.adoc false diff --git a/pgvector/deployment/pom.xml b/pgvector/deployment/pom.xml new file mode 100644 index 000000000..eefcb3e53 --- /dev/null +++ b/pgvector/deployment/pom.xml @@ -0,0 +1,79 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-pgvector-parent + 999-SNAPSHOT + + quarkus-langchain4j-pgvector-deployment + Quarkus langchain4j-pgvector - Deployment + + + io.quarkiverse.langchain4j + quarkus-langchain4j-pgvector + ${project.version} + + + io.quarkus + quarkus-arc-deployment + + + io.quarkus + quarkus-jackson-deployment + + + io.quarkus + quarkus-agroal-deployment + + + io.quarkus + quarkus-jdbc-postgresql-deployment + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core-deployment + ${project.version} + + + io.quarkus + quarkus-junit5-internal + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + org.wiremock + wiremock-standalone + ${wiremock.version} + test + + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2-q + ${langchain4j.version} + test + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + + diff --git a/pgvector/deployment/src/main/java/io/quarkiverse/langchain4j/pgvector/deployment/Langchain4jPgvectorProcessor.java b/pgvector/deployment/src/main/java/io/quarkiverse/langchain4j/pgvector/deployment/Langchain4jPgvectorProcessor.java new file mode 100644 index 000000000..c487d00b4 --- /dev/null +++ b/pgvector/deployment/src/main/java/io/quarkiverse/langchain4j/pgvector/deployment/Langchain4jPgvectorProcessor.java @@ -0,0 +1,51 @@ +package io.quarkiverse.langchain4j.pgvector.deployment; + +import jakarta.enterprise.context.ApplicationScoped; + +import org.jboss.jandex.ClassType; +import org.jboss.jandex.DotName; +import org.jboss.jandex.ParameterizedType; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingStore; +import io.agroal.api.AgroalDataSource; +import io.quarkiverse.langchain4j.pgvector.PgVectorEmbeddingStore; +import io.quarkiverse.langchain4j.pgvector.runtime.PgVectorEmbeddingStoreConfig; +import io.quarkiverse.langchain4j.pgvector.runtime.PgVectorEmbeddingStoreRecorder; +import io.quarkus.arc.deployment.SyntheticBeanBuildItem; +import io.quarkus.deployment.annotations.BuildProducer; +import io.quarkus.deployment.annotations.BuildStep; +import io.quarkus.deployment.annotations.ExecutionTime; +import io.quarkus.deployment.annotations.Record; +import io.quarkus.deployment.builditem.FeatureBuildItem; + +class Langchain4jPgvectorProcessor { + + public static final DotName PGVECTOR_EMBEDDING_STORE = DotName.createSimple(PgVectorEmbeddingStore.class); + + private static final String FEATURE = "langchain4j-pgvector"; + + @BuildStep + FeatureBuildItem feature() { + return new FeatureBuildItem(FEATURE); + } + + @BuildStep + @Record(ExecutionTime.RUNTIME_INIT) + public void createBean( + BuildProducer beanProducer, + PgVectorEmbeddingStoreRecorder recorder, + PgVectorEmbeddingStoreConfig config) { + beanProducer.produce(SyntheticBeanBuildItem + .configure(PGVECTOR_EMBEDDING_STORE) + .types(ClassType.create(EmbeddingStore.class), + ParameterizedType.create(EmbeddingStore.class, ClassType.create(TextSegment.class))) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .addInjectionPoint(ClassType.create(DotName.createSimple(AgroalDataSource.class))) + .createWith(recorder.embeddingStoreFunction(config)) + .done()); + + } +} diff --git a/pgvector/deployment/src/test/java/io/quarkiverse/langchain4j/pgvector/test/Langchain4jPgvectorTest.java b/pgvector/deployment/src/test/java/io/quarkiverse/langchain4j/pgvector/test/Langchain4jPgvectorTest.java new file mode 100644 index 000000000..ae0f1139e --- /dev/null +++ b/pgvector/deployment/src/test/java/io/quarkiverse/langchain4j/pgvector/test/Langchain4jPgvectorTest.java @@ -0,0 +1,239 @@ +package io.quarkiverse.langchain4j.pgvector.test; + +import static dev.langchain4j.internal.Utils.randomUUID; +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.data.Percentage.withPercentage; + +import java.sql.SQLException; +import java.util.List; + +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.CosineSimilarity; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.RelevanceScore; +import io.quarkiverse.langchain4j.pgvector.PgVectorEmbeddingStore; +import io.quarkus.test.QuarkusUnitTest; + +public class Langchain4jPgvectorTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addAsResource(new StringAsset("quarkus.langchain4j.pgvector.dimension=384\n" + + "quarkus.datasource.devservices.image-name=ankane/pgvector:v0.5.1"), + "application.properties")); + + @Inject + EmbeddingStore embeddingStore; + + @Inject + PgVectorEmbeddingStore pgvectorEmbeddingStore; + + private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @AfterEach + public void cleanup() throws SQLException { + pgvectorEmbeddingStore.deleteAll(); + } + + @Test + void should_add_embedding() { + assertThat(embeddingStore).isSameAs(pgvectorEmbeddingStore); + + Embedding embedding = embeddingModel.embed(randomUUID()).content(); + + String id = embeddingStore.add(embedding); + assertThat(id).isNotNull(); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isNull(); + } + + @Test + void should_add_embedding_with_id() { + String id = randomUUID(); + Embedding embedding = embeddingModel.embed(randomUUID()).content(); + + embeddingStore.add(id, embedding); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isNull(); + } + + @Test + void should_add_embedding_with_segment() { + TextSegment segment = TextSegment.from(randomUUID()); + Embedding embedding = embeddingModel.embed(segment.text()).content(); + + String id = embeddingStore.add(embedding, segment); + assertThat(id).isNotNull(); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isEqualTo(segment); + } + + @Test + void should_add_embedding_with_segment_with_metadata() { + TextSegment segment = TextSegment.from(randomUUID(), Metadata.from("test-key", "test-value")); + Embedding embedding = embeddingModel.embed(segment.text()).content(); + + String id = embeddingStore.add(embedding, segment); + assertThat(id).isNotNull(); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isEqualTo(segment); + } + + @Test + void should_add_multiple_embeddings() { + Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); + Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); + + List ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding)); + assertThat(ids).hasSize(2); + + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); + assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); + assertThat(firstMatch.embedded()).isNull(); + + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); + assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); + assertThat(secondMatch.embedded()).isNull(); + } + + @Test + void should_add_multiple_embeddings_with_segments() { + TextSegment firstSegment = TextSegment.from(randomUUID()); + Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content(); + TextSegment secondSegment = TextSegment.from(randomUUID()); + Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content(); + + List ids = embeddingStore.addAll( + asList(firstEmbedding, secondEmbedding), + asList(firstSegment, secondSegment)); + assertThat(ids).hasSize(2); + + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); + assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); + assertThat(firstMatch.embedded()).isEqualTo(firstSegment); + + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); + assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); + assertThat(secondMatch.embedded()).isEqualTo(secondSegment); + } + + @Test + void should_find_with_min_score() { + String firstId = randomUUID(); + Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); + embeddingStore.add(firstId, firstEmbedding); + + String secondId = randomUUID(); + Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); + embeddingStore.add(secondId, secondEmbedding); + + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(firstId); + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(secondId); + + List> relevant2 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score() - 0.01); + assertThat(relevant2).hasSize(2); + assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId); + assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId); + + List> relevant3 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score()); + assertThat(relevant3).hasSize(2); + assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId); + assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId); + + List> relevant4 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score() + 0.01); + assertThat(relevant4).hasSize(1); + assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId); + } + + @Test + void should_return_correct_score() { + Embedding embedding = embeddingModel.embed("hello").content(); + + String id = embeddingStore.add(embedding); + assertThat(id).isNotNull(); + + Embedding referenceEmbedding = embeddingModel.embed("hi").content(); + + List> relevant = embeddingStore.findRelevant(referenceEmbedding, 1); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo( + RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)), + withPercentage(1)); + } + +} diff --git a/pgvector/pom.xml b/pgvector/pom.xml new file mode 100644 index 000000000..4f622034a --- /dev/null +++ b/pgvector/pom.xml @@ -0,0 +1,18 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-parent + 999-SNAPSHOT + + quarkus-langchain4j-pgvector-parent + 999-SNAPSHOT + pom + Quarkus langchain4j-pgvector - Parent + + deployment + runtime + + diff --git a/pgvector/runtime/pom.xml b/pgvector/runtime/pom.xml new file mode 100644 index 000000000..56436c295 --- /dev/null +++ b/pgvector/runtime/pom.xml @@ -0,0 +1,72 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-pgvector-parent + 999-SNAPSHOT + + quarkus-langchain4j-pgvector + Quarkus langchain4j-pgvector - Runtime + + + io.quarkus + quarkus-arc + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core + ${project.version} + + + com.pgvector + pgvector + ${pgvector-java.version} + + + io.quarkus + quarkus-jackson + + + io.quarkus + quarkus-agroal + + + io.quarkus + quarkus-jdbc-postgresql + + + + + + io.quarkus + quarkus-extension-maven-plugin + ${quarkus.version} + + + compile + + extension-descriptor + + + ${project.groupId}:${project.artifactId}-deployment:${project.version} + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + + diff --git a/pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/PgVectorEmbeddingStore.java b/pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/PgVectorEmbeddingStore.java new file mode 100644 index 000000000..60cf84ee9 --- /dev/null +++ b/pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/PgVectorEmbeddingStore.java @@ -0,0 +1,301 @@ +package io.quarkiverse.langchain4j.pgvector; + +import static dev.langchain4j.internal.Utils.*; +import static dev.langchain4j.internal.ValidationUtils.*; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toList; + +import java.sql.*; +import java.util.*; + +import org.postgresql.util.PSQLException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.pgvector.PGvector; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingStore; +import io.agroal.api.AgroalDataSource; +import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; +import io.quarkus.logging.Log; + +/** + * PGVector EmbeddingStore Implementation + *

+ * Only cosine similarity is used. + * Only ivfflat index is used. + */ +public class PgVectorEmbeddingStore implements EmbeddingStore { + + ObjectMapper objectMapper = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER; + private static final Logger log = LoggerFactory.getLogger(PgVectorEmbeddingStore.class); + private static final TypeReference> typeReference = new TypeReference>() { + }; + private final AgroalDataSource datasource; + private final String table; + private Statement statement; + + /** + * All args constructor for PgVectorEmbeddingStore Class + * + * @param datasource , the datasource object + * @param table The database table + * @param dimension The vector dimension + * @param useIndex Should use IVFFlat index + * @param indexListSize The IVFFlat number of lists + * @param createTable Should create table automatically + * @param dropTableFirst Should drop table first, usually for testing + */ + public PgVectorEmbeddingStore( + AgroalDataSource datasource, + String table, + Integer dimension, + Boolean useIndex, + Integer indexListSize, + Boolean createTable, + Boolean dropTableFirst) { + this.datasource = datasource; + this.table = ensureNotBlank(table, "table"); + + useIndex = getOrDefault(useIndex, false); + createTable = getOrDefault(createTable, true); + dropTableFirst = getOrDefault(dropTableFirst, false); + try (Connection connection = setupConnection()) { + if (dropTableFirst) { + statement = connection.createStatement(); + statement.executeUpdate(String.format("DROP TABLE IF EXISTS %s", table)); + statement.close(); + } + + if (createTable) { + statement = connection.createStatement(); + statement.executeUpdate(String.format( + "CREATE TABLE IF NOT EXISTS %s (" + + "embedding_id UUID PRIMARY KEY, " + + "embedding vector(%s), " + + "text TEXT NULL, " + + "metadata JSON NULL" + + ")", + table, ensureGreaterThanZero(dimension, "dimension"))); + statement.close(); + } + + if (useIndex) { + statement = connection.createStatement(); + statement.executeUpdate(String.format( + "CREATE INDEX IF NOT EXISTS ON %s " + + "USING ivfflat (embedding vector_cosine_ops) " + + "WITH (lists = %s)", + table, ensureGreaterThanZero(indexListSize, "indexListSize"))); + statement.close(); + } + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + private Connection setupConnection() throws SQLException { + Connection connection = datasource.getConnection(); + try { + statement = connection.createStatement(); + statement.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector"); + statement.close(); + } catch (PSQLException exception) { + if (exception.getMessage().contains("could not open extension control file")) { + Log.error( + "The PostgreSQL server does not seem to support pgvector." + + "If using containers/devservices we suggest to use quarkus.datasource.devservices.image-name=ankane/pgvector:v0.5.1"); + } else { + throw exception; + } + } + + PGvector.addVectorType(connection); + return connection; + } + + public void deleteAll() throws SQLException { + try (Connection connection = setupConnection()) { + statement = connection.createStatement(); + statement.executeUpdate(String.format("TRUNCATE TABLE %s", table)); + statement.close(); + } + } + + /** + * Adds a given embedding to the store. + * + * @param embedding The embedding to be added to the store. + * @return The auto-generated ID associated with the added embedding. + */ + @Override + public String add(Embedding embedding) { + String id = randomUUID(); + addInternal(id, embedding, null); + return id; + } + + /** + * Adds a given embedding to the store. + * + * @param id The unique identifier for the embedding to be added. + * @param embedding The embedding to be added to the store. + */ + @Override + public void add(String id, Embedding embedding) { + addInternal(id, embedding, null); + } + + /** + * Adds a given embedding and the corresponding content that has been embedded to the store. + * + * @param embedding The embedding to be added to the store. + * @param textSegment Original content that was embedded. + * @return The auto-generated ID associated with the added embedding. + */ + @Override + public String add(Embedding embedding, TextSegment textSegment) { + String id = randomUUID(); + addInternal(id, embedding, textSegment); + return id; + } + + /** + * Adds multiple embeddings to the store. + * + * @param embeddings A list of embeddings to be added to the store. + * @return A list of auto-generated IDs associated with the added embeddings. + */ + @Override + public List addAll(List embeddings) { + List ids = embeddings.stream().map(ignored -> randomUUID()).collect(toList()); + addAllInternal(ids, embeddings, null); + return ids; + } + + /** + * Adds multiple embeddings and their corresponding contents that have been embedded to the store. + * + * @param embeddings A list of embeddings to be added to the store. + * @param embedded A list of original contents that were embedded. + * @return A list of auto-generated IDs associated with the added embeddings. + */ + @Override + public List addAll(List embeddings, List embedded) { + List ids = embeddings.stream().map(ignored -> randomUUID()).collect(toList()); + addAllInternal(ids, embeddings, embedded); + return ids; + } + + /** + * Finds the most relevant (closest in space) embeddings to the provided reference embedding. + * + * @param referenceEmbedding The embedding used as a reference. Returned embeddings should be relevant (closest) to this + * one. + * @param maxResults The maximum number of embeddings to be returned. + * @param minScore The minimum relevance score, ranging from 0 to 1 (inclusive). + * Only embeddings with a score of this value or higher will be returned. + * @return A list of embedding matches. + * Each embedding match includes a relevance score (derivative of cosine distance), + * ranging from 0 (not relevant) to 1 (highly relevant). + */ + @Override + public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { + List> result = new ArrayList<>(); + try (Connection connection = setupConnection()) { + String referenceVector = Arrays.toString(referenceEmbedding.vector()); + String query = String.format( + "WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, metadata FROM %s) SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;", + referenceVector, table, minScore, maxResults); + PreparedStatement selectStmt = connection.prepareStatement(query); + + ResultSet resultSet = selectStmt.executeQuery(); + while (resultSet.next()) { + double score = resultSet.getDouble("score"); + String embeddingId = resultSet.getString("embedding_id"); + + PGvector vector = (PGvector) resultSet.getObject("embedding"); + Embedding embedding = new Embedding(vector.toArray()); + + String text = resultSet.getString("text"); + TextSegment textSegment = null; + if (isNotNullOrBlank(text)) { + String metadataJson = Optional.ofNullable(resultSet.getString("metadata")).orElse("{}"); + Map metadataMap = objectMapper.readValue(metadataJson, typeReference); + Metadata metadata = new Metadata(new HashMap<>(metadataMap)); + textSegment = TextSegment.from(text, metadata); + } + result.add(new EmbeddingMatch<>(score, embeddingId, embedding, textSegment)); + } + selectStmt.close(); + resultSet.close(); + } catch (SQLException e) { + throw new RuntimeException(e); + } catch (JsonMappingException e) { + throw new RuntimeException(e); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + + return result; + } + + private void addInternal(String id, Embedding embedding, TextSegment embedded) { + addAllInternal( + singletonList(id), + singletonList(embedding), + embedded == null ? null : singletonList(embedded)); + } + + private void addAllInternal( + List ids, List embeddings, List embedded) { + if (isCollectionEmpty(ids) || isCollectionEmpty(embeddings)) { + log.info("Empty embeddings - no ops"); + return; + } + ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size"); + ensureTrue(embedded == null || embeddings.size() == embedded.size(), + "embeddings size is not equal to embedded size"); + + try (Connection connection = setupConnection()) { + String query = String.format( + "INSERT INTO %s (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?)" + + "ON CONFLICT (embedding_id) DO UPDATE SET " + + "embedding = EXCLUDED.embedding," + + "text = EXCLUDED.text," + + "metadata = EXCLUDED.metadata;", + table); + + PreparedStatement upsertStmt = connection.prepareStatement(query); + + for (int i = 0; i < ids.size(); ++i) { + upsertStmt.setObject(1, UUID.fromString(ids.get(i))); + upsertStmt.setObject(2, new PGvector(embeddings.get(i).vector())); + + if (embedded != null && embedded.get(i) != null) { + upsertStmt.setObject(3, embedded.get(i).text()); + Map metadata = embedded.get(i).metadata().asMap(); + upsertStmt.setObject(4, objectMapper.writeValueAsString(metadata), Types.OTHER); + } else { + upsertStmt.setNull(3, Types.VARCHAR); + upsertStmt.setNull(4, Types.OTHER); + } + upsertStmt.addBatch(); + } + + upsertStmt.executeBatch(); + upsertStmt.close(); + + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/runtime/PgVectorEmbeddingStoreConfig.java b/pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/runtime/PgVectorEmbeddingStoreConfig.java new file mode 100644 index 000000000..3d8c6fd3b --- /dev/null +++ b/pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/runtime/PgVectorEmbeddingStoreConfig.java @@ -0,0 +1,51 @@ +package io.quarkiverse.langchain4j.pgvector.runtime; + +import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME; + +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithDefault; + +@ConfigRoot(phase = RUN_TIME) +@ConfigMapping(prefix = "quarkus.langchain4j.pgvector") +public interface PgVectorEmbeddingStoreConfig { + + /** + * The table name for storing embeddings + */ + @WithDefault("embeddings") + String table(); + + /** + * The dimension of the embedding vectors. This has to be the same as the dimension of vectors produced by + * the embedding model that you use. For example, AllMiniLmL6V2QuantizedEmbeddingModel produces vectors of dimension 384. + * OpenAI's text-embedding-ada-002 produces vectors of dimension 1536. + */ + Integer dimension(); + + /** + * Use index or not + */ + @WithDefault("false") + Boolean useIndex(); + + /** + * + * index size + */ + @WithDefault("0") + Integer indexListSize(); + + /** + * Create table or not + */ + @WithDefault("true") + Boolean createTable(); + + /** + * Drop table or not + */ + @WithDefault("false") + Boolean dropTableFirst(); + +} diff --git a/pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/runtime/PgVectorEmbeddingStoreRecorder.java b/pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/runtime/PgVectorEmbeddingStoreRecorder.java new file mode 100644 index 000000000..42ec4fada --- /dev/null +++ b/pgvector/runtime/src/main/java/io/quarkiverse/langchain4j/pgvector/runtime/PgVectorEmbeddingStoreRecorder.java @@ -0,0 +1,29 @@ +package io.quarkiverse.langchain4j.pgvector.runtime; + +import java.util.function.Function; + +import jakarta.enterprise.inject.Default; + +import io.agroal.api.AgroalDataSource; +import io.quarkiverse.langchain4j.pgvector.PgVectorEmbeddingStore; +import io.quarkus.arc.SyntheticCreationalContext; +import io.quarkus.runtime.annotations.Recorder; + +@Recorder +public class PgVectorEmbeddingStoreRecorder { + + public Function, PgVectorEmbeddingStore> embeddingStoreFunction( + PgVectorEmbeddingStoreConfig config) { + return new Function<>() { + @Override + public PgVectorEmbeddingStore apply(SyntheticCreationalContext context) { + AgroalDataSource dataSource; + //TODO handle named datasources + dataSource = context.getInjectedReference(AgroalDataSource.class, new Default.Literal()); + return new PgVectorEmbeddingStore(dataSource, config.table(), config.dimension(), config.useIndex(), + config.indexListSize(), config.createTable(), config.dropTableFirst()); + } + }; + } + +} diff --git a/pgvector/runtime/src/main/resources/META-INF/quarkus-extension.yaml b/pgvector/runtime/src/main/resources/META-INF/quarkus-extension.yaml new file mode 100644 index 000000000..b136c392a --- /dev/null +++ b/pgvector/runtime/src/main/resources/META-INF/quarkus-extension.yaml @@ -0,0 +1,12 @@ +name: Quarkus Langchain4j pgvector embedding store +artifact: ${project.groupId}:${project.artifactId}:${project.version} +description: Provides the pgvector Embedding store for Quarkus Langchain4j +metadata: + keywords: + - ai + - langchain4j + - openai + - pgvector + categories: + - "miscellaneous" + status: "preview" diff --git a/pom.xml b/pom.xml index 1f2ec2f38..77a8e6cda 100644 --- a/pom.xml +++ b/pom.xml @@ -22,6 +22,7 @@ openai/openai-vanilla pinecone redis + pgvector scm:git:git@github.com:quarkiverse/quarkus-langchain4j.git @@ -39,6 +40,7 @@ 2.0.4 3.24.2 3.3.1 + 0.1.3 From 4e2a2055ce6cf33d117e02580e912f2a3c6faf4d Mon Sep 17 00:00:00 2001 From: Jan Martiska Date: Mon, 4 Dec 2023 14:54:27 +0100 Subject: [PATCH 2/2] Make pgvector work in native mode --- .../pgvector/deployment/Langchain4jPgvectorProcessor.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pgvector/deployment/src/main/java/io/quarkiverse/langchain4j/pgvector/deployment/Langchain4jPgvectorProcessor.java b/pgvector/deployment/src/main/java/io/quarkiverse/langchain4j/pgvector/deployment/Langchain4jPgvectorProcessor.java index c487d00b4..59f579d23 100644 --- a/pgvector/deployment/src/main/java/io/quarkiverse/langchain4j/pgvector/deployment/Langchain4jPgvectorProcessor.java +++ b/pgvector/deployment/src/main/java/io/quarkiverse/langchain4j/pgvector/deployment/Langchain4jPgvectorProcessor.java @@ -6,6 +6,8 @@ import org.jboss.jandex.DotName; import org.jboss.jandex.ParameterizedType; +import com.pgvector.PGvector; + import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingStore; import io.agroal.api.AgroalDataSource; @@ -18,6 +20,7 @@ import io.quarkus.deployment.annotations.ExecutionTime; import io.quarkus.deployment.annotations.Record; import io.quarkus.deployment.builditem.FeatureBuildItem; +import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem; class Langchain4jPgvectorProcessor { @@ -48,4 +51,9 @@ public void createBean( .done()); } + + @BuildStep + public ReflectiveClassBuildItem reflectiveClass() { + return ReflectiveClassBuildItem.builder(PGvector.class).build(); + } }