From 2e88bf530bb4fa4ffe20efbeaa22796334e0b08e Mon Sep 17 00:00:00 2001 From: Sebastien Blanc Date: Sun, 7 Jan 2024 20:43:21 +0100 Subject: [PATCH] adding opensearch formatting and doc --- .../IngestorExampleWithOpenSearch.java | 41 ++ .../quarkus-langchain4j-opensearch.adoc | 29 ++ docs/modules/ROOT/pages/opensearch-store.adoc | 37 ++ docs/pom.xml | 11 + opensearch/deployment/pom.xml | 79 ++++ .../Langchain4jOpensearchProcessor.java | 60 +++ .../test/Langchain4jOpensearchTest.java | 242 +++++++++++ opensearch/pom.xml | 19 + opensearch/runtime/pom.xml | 63 +++ .../langchain4j/opensearch/Document.java | 81 ++++ .../opensearch/OpenSearchEmbeddingStore.java | 406 ++++++++++++++++++ .../OpenSearchRequestFailedException.java | 16 + .../OpenSearchEmbeddingStoreConfig.java | 20 + .../OpenSearchEmbeddingStoreRecorder.java | 29 ++ .../resources/META-INF/quarkus-extension.yaml | 12 + pom.xml | 1 + 16 files changed, 1146 insertions(+) create mode 100644 docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/IngestorExampleWithOpenSearch.java create mode 100644 docs/modules/ROOT/pages/includes/quarkus-langchain4j-opensearch.adoc create mode 100644 docs/modules/ROOT/pages/opensearch-store.adoc create mode 100644 opensearch/deployment/pom.xml create mode 100644 opensearch/deployment/src/main/java/io/quarkiverse/langchain4j/opensearch/deployment/Langchain4jOpensearchProcessor.java create mode 100644 opensearch/deployment/src/test/java/io/quarkiverse/langchain4j/opensearch/test/Langchain4jOpensearchTest.java create mode 100644 opensearch/pom.xml create mode 100644 opensearch/runtime/pom.xml create mode 100644 opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/Document.java create mode 100644 opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/OpenSearchEmbeddingStore.java create mode 100644 opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/OpenSearchRequestFailedException.java create mode 100644 opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/runtime/OpenSearchEmbeddingStoreConfig.java create mode 100644 opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/runtime/OpenSearchEmbeddingStoreRecorder.java create mode 100644 opensearch/runtime/src/main/resources/META-INF/quarkus-extension.yaml diff --git a/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/IngestorExampleWithOpenSearch.java b/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/IngestorExampleWithOpenSearch.java new file mode 100644 index 000000000..5241705e6 --- /dev/null +++ b/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/IngestorExampleWithOpenSearch.java @@ -0,0 +1,41 @@ +package io.quarkiverse.langchain4j.samples; + +import static dev.langchain4j.data.document.splitter.DocumentSplitters.recursive; + +import java.util.List; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import dev.langchain4j.data.document.Document; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStoreIngestor; +import io.quarkiverse.langchain4j.opensearch.OpenSearchEmbeddingStore; + +@ApplicationScoped +public class IngestorExampleWithOpenSearch { + + /** + * The embedding store (the database). + * The bean is provided by the quarkus-langchain4j-opensearch extension. + */ + @Inject + OpenSearchEmbeddingStore store; + + /** + * The embedding model (how is computed the vector of a document). 
+     * The bean is provided by an LLM provider extension (for example, the OpenAI extension).
+     */
+    @Inject
+    EmbeddingModel embeddingModel;
+
+    public void ingest(List<Document> documents) {
+        EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
+                .embeddingStore(store)
+                .embeddingModel(embeddingModel)
+                .documentSplitter(recursive(500, 0))
+                .build();
+        // Warning - this can take a long time...
+        ingestor.ingest(documents);
+    }
+}
\ No newline at end of file
diff --git a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-opensearch.adoc b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-opensearch.adoc
new file mode 100644
index 000000000..7c716ccab
--- /dev/null
+++ b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-opensearch.adoc
@@ -0,0 +1,29 @@
+
+:summaryTableId: quarkus-langchain4j-opensearch
+[.configuration-legend]
+icon:lock[title=Fixed at build time] Configuration property fixed at build time - All other configuration properties are overridable at runtime
+[.configuration-reference.searchable, cols="80,.^10,.^10"]
+|===
+
+h|[[quarkus-langchain4j-opensearch_configuration]]link:#quarkus-langchain4j-opensearch_configuration[Configuration property]
+
+h|Type
+h|Default
+
+a| [[quarkus-langchain4j-opensearch_quarkus.langchain4j.opensearch.index]]`link:#quarkus-langchain4j-opensearch_quarkus.langchain4j.opensearch.index[quarkus.langchain4j.opensearch.index]`
+
+
+[.description]
+--
+Name of the index that will be used in OpenSearch when searching for related embeddings. If this index doesn't exist, it will be created.
+
+ifdef::add-copy-button-to-env-var[]
+Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENSEARCH_INDEX+++[]
+endif::add-copy-button-to-env-var[]
+ifndef::add-copy-button-to-env-var[]
+Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENSEARCH_INDEX+++`
+endif::add-copy-button-to-env-var[]
+--|string
+|`default`
+
+|===
\ No newline at end of file
diff --git a/docs/modules/ROOT/pages/opensearch-store.adoc b/docs/modules/ROOT/pages/opensearch-store.adoc
new file mode 100644
index 000000000..eb802a2ce
--- /dev/null
+++ b/docs/modules/ROOT/pages/opensearch-store.adoc
@@ -0,0 +1,37 @@
+= OpenSearch Document Store for Retrieval Augmented Generation (RAG)
+
+include::./includes/attributes.adoc[]
+
+When implementing Retrieval Augmented Generation (RAG), a capable document store is necessary. This guide explains how to use OpenSearch as the document store.
+
+== Leveraging the OpenSearch Document Store
+
+To utilize the OpenSearch document store, you'll need to include the following dependency:
+
+[source,xml,subs=attributes+]
+----
+<dependency>
+    <groupId>io.quarkiverse.langchain4j</groupId>
+    <artifactId>quarkus-langchain4j-opensearch</artifactId>
+    <version>{project-version}</version>
+</dependency>
+----
+
+This extension relies on the OpenSearch Java client, so make sure it is configured correctly.
+
+Upon installing the extension, you can use the OpenSearch embedding store with the following code:
+
+[source,java]
+----
+include::{examples-dir}/io/quarkiverse/langchain4j/samples/IngestorExampleWithOpenSearch.java[]
+----
+
+== Configuration Settings
+
+Customize the behavior of the extension by exploring various configuration options:
+
+include::includes/quarkus-langchain4j-opensearch.adoc[leveloffset=+1,opts=optional]
+
+== Under the Hood
+
+Each ingested document is stored as an OpenSearch document in the configured index, with its embedding kept in a `knn_vector` field alongside the segment text and its metadata.
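+
+If you need to query the store directly, for example from a custom retriever, the sketch below shows one possible approach. It is illustrative only: the class name is made up, and it simply combines the injected `OpenSearchEmbeddingStore` with the same `EmbeddingModel` used at ingestion time.
+
+[source,java]
+----
+package io.quarkiverse.langchain4j.samples;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+import jakarta.enterprise.context.ApplicationScoped;
+import jakarta.inject.Inject;
+
+import dev.langchain4j.data.embedding.Embedding;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+import io.quarkiverse.langchain4j.opensearch.OpenSearchEmbeddingStore;
+
+@ApplicationScoped
+public class RetrieverExampleWithOpenSearch {
+
+    @Inject
+    OpenSearchEmbeddingStore store;
+
+    @Inject
+    EmbeddingModel embeddingModel;
+
+    public List<String> findRelatedTexts(String question) {
+        // Embed the question with the same model that embedded the ingested segments.
+        Embedding query = embeddingModel.embed(question).content();
+        // The store runs an exact k-NN script-score search (cosinesimil) against the index;
+        // keep the 5 best matches above a minimal relevance score.
+        return store.findRelevant(query, 5, 0.7).stream()
+                .map(match -> match.embedded().text())
+                .collect(Collectors.toList());
+    }
+}
+----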
diff --git a/docs/pom.xml b/docs/pom.xml index 6ff2d280f..c7b71bf4a 100644 --- a/docs/pom.xml +++ b/docs/pom.xml @@ -41,6 +41,11 @@ quarkus-langchain4j-pinecone ${project.version} + + io.quarkiverse.langchain4j + quarkus-langchain4j-opensearch + ${project.version} + @@ -73,6 +78,11 @@ quarkus-langchain4j-hugging-face-deployment ${project.version} + + io.quarkiverse.langchain4j + quarkus-langchain4j-opensearch-deployment + ${project.version} + @@ -127,6 +137,7 @@ quarkus-langchain4j-chroma.adoc quarkus-langchain4j-pinecone.adoc quarkus-langchain4j-pgvector.adoc + quarkus-langchain4j-opensearch.adoc false diff --git a/opensearch/deployment/pom.xml b/opensearch/deployment/pom.xml new file mode 100644 index 000000000..5125c0841 --- /dev/null +++ b/opensearch/deployment/pom.xml @@ -0,0 +1,79 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-opensearch-parent + 999-SNAPSHOT + + quarkus-langchain4j-opensearch-deployment + Quarkus Langchain4j - Opensearch embedding store - Deployment + + + io.quarkus + quarkus-arc-deployment + + + io.quarkiverse.langchain4j + quarkus-langchain4j-opensearch + ${project.version} + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core-deployment + ${project.version} + + + io.quarkiverse.opensearch + quarkus-opensearch-java-client-deployment + 1.4.0 + + + io.quarkus + quarkus-junit5-internal + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + org.wiremock + wiremock-standalone + ${wiremock.version} + test + + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2-q + ${langchain4j-embeddings.version} + test + + + dev.langchain4j + langchain4j-core + + + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + + diff --git a/opensearch/deployment/src/main/java/io/quarkiverse/langchain4j/opensearch/deployment/Langchain4jOpensearchProcessor.java b/opensearch/deployment/src/main/java/io/quarkiverse/langchain4j/opensearch/deployment/Langchain4jOpensearchProcessor.java new file mode 100644 index 000000000..9f8d491ef --- /dev/null +++ b/opensearch/deployment/src/main/java/io/quarkiverse/langchain4j/opensearch/deployment/Langchain4jOpensearchProcessor.java @@ -0,0 +1,60 @@ +package io.quarkiverse.langchain4j.opensearch.deployment; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Default; + +import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.ClassType; +import org.jboss.jandex.DotName; +import org.jboss.jandex.ParameterizedType; +import org.opensearch.client.opensearch.OpenSearchClient; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingStore; +import io.quarkiverse.langchain4j.deployment.EmbeddingStoreBuildItem; +import io.quarkiverse.langchain4j.opensearch.OpenSearchEmbeddingStore; +import io.quarkiverse.langchain4j.opensearch.runtime.OpenSearchEmbeddingStoreConfig; +import io.quarkiverse.langchain4j.opensearch.runtime.OpenSearchEmbeddingStoreRecorder; +import io.quarkus.arc.deployment.SyntheticBeanBuildItem; +import io.quarkus.deployment.annotations.BuildProducer; +import io.quarkus.deployment.annotations.BuildStep; +import io.quarkus.deployment.annotations.ExecutionTime; +import io.quarkus.deployment.annotations.Record; +import io.quarkus.deployment.builditem.FeatureBuildItem; + +class Langchain4jOpensearchProcessor { + + public static final DotName OPENSEARCH_EMBEDDING_STORE = DotName.createSimple(OpenSearchEmbeddingStore.class); + + private static final String 
FEATURE = "langchain4j-opensearch"; + + @BuildStep + FeatureBuildItem feature() { + return new FeatureBuildItem(FEATURE); + } + + @BuildStep + @Record(ExecutionTime.RUNTIME_INIT) + public void createBean( + BuildProducer beanProducer, + OpenSearchEmbeddingStoreRecorder recorder, + OpenSearchEmbeddingStoreConfig config, + BuildProducer embeddingStoreProducer) { + AnnotationInstance openSearchClientQualifier; + openSearchClientQualifier = AnnotationInstance.builder(Default.class).build(); + + beanProducer.produce(SyntheticBeanBuildItem + .configure(OPENSEARCH_EMBEDDING_STORE) + .types(ClassType.create(EmbeddingStore.class), + ParameterizedType.create(EmbeddingStore.class, ClassType.create(TextSegment.class))) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .addInjectionPoint(ClassType.create(DotName.createSimple(OpenSearchClient.class)), + openSearchClientQualifier) + .createWith(recorder.embeddingStoreFunction(config)) + .done()); + embeddingStoreProducer.produce(new EmbeddingStoreBuildItem()); + } + +} diff --git a/opensearch/deployment/src/test/java/io/quarkiverse/langchain4j/opensearch/test/Langchain4jOpensearchTest.java b/opensearch/deployment/src/test/java/io/quarkiverse/langchain4j/opensearch/test/Langchain4jOpensearchTest.java new file mode 100644 index 000000000..9e521d2d9 --- /dev/null +++ b/opensearch/deployment/src/test/java/io/quarkiverse/langchain4j/opensearch/test/Langchain4jOpensearchTest.java @@ -0,0 +1,242 @@ +package io.quarkiverse.langchain4j.opensearch.test; + +import static dev.langchain4j.internal.Utils.randomUUID; +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.data.Percentage.withPercentage; + +import java.sql.SQLException; +import java.util.List; + +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.CosineSimilarity; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.RelevanceScore; +import io.quarkiverse.langchain4j.opensearch.OpenSearchEmbeddingStore; +import io.quarkus.test.QuarkusUnitTest; + +public class Langchain4jOpensearchTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)); + + @Inject + EmbeddingStore embeddingStore; + + @Inject + OpenSearchEmbeddingStore openSearchEmbeddingStore; + + private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @AfterEach + public void cleanup() throws SQLException { + openSearchEmbeddingStore.deleteAll(); + } + + @Test + void should_add_embedding() { + assertThat(embeddingStore).isSameAs(openSearchEmbeddingStore); + + Embedding embedding = embeddingModel.embed(randomUUID()).content(); + + String id = embeddingStore.add(embedding); + assertThat(id).isNotNull(); + awaitUntilPersisted(); + List> relevant = embeddingStore.findRelevant(embedding, 10); + 
assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isNull(); + } + + @Test + void should_add_embedding_with_id() { + String id = randomUUID(); + Embedding embedding = embeddingModel.embed(randomUUID()).content(); + + embeddingStore.add(id, embedding); + awaitUntilPersisted(); + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isNull(); + } + + @Test + void should_add_embedding_with_segment() { + TextSegment segment = TextSegment.from(randomUUID()); + Embedding embedding = embeddingModel.embed(segment.text()).content(); + + String id = embeddingStore.add(embedding, segment); + assertThat(id).isNotNull(); + awaitUntilPersisted(); + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isEqualTo(segment); + } + + @Test + void should_add_embedding_with_segment_with_metadata() { + TextSegment segment = TextSegment.from(randomUUID(), Metadata.from("test-key", "test-value")); + Embedding embedding = embeddingModel.embed(segment.text()).content(); + + String id = embeddingStore.add(embedding, segment); + assertThat(id).isNotNull(); + awaitUntilPersisted(); + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isEqualTo(segment); + } + + @Test + void should_add_multiple_embeddings() { + Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); + Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); + + List ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding)); + assertThat(ids).hasSize(2); + awaitUntilPersisted(); + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); + assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); + assertThat(firstMatch.embedded()).isNull(); + + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); + assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); + assertThat(secondMatch.embedded()).isNull(); + } + + @Test + void should_add_multiple_embeddings_with_segments() { + TextSegment firstSegment = TextSegment.from(randomUUID()); + Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content(); + TextSegment secondSegment = TextSegment.from(randomUUID()); + Embedding secondEmbedding = 
embeddingModel.embed(secondSegment.text()).content(); + + List ids = embeddingStore.addAll( + asList(firstEmbedding, secondEmbedding), + asList(firstSegment, secondSegment)); + assertThat(ids).hasSize(2); + awaitUntilPersisted(); + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); + assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); + assertThat(firstMatch.embedded()).isEqualTo(firstSegment); + + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); + assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); + assertThat(secondMatch.embedded()).isEqualTo(secondSegment); + } + + @Test + void should_find_with_min_score() { + String firstId = randomUUID(); + Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); + embeddingStore.add(firstId, firstEmbedding); + + String secondId = randomUUID(); + Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); + embeddingStore.add(secondId, secondEmbedding); + awaitUntilPersisted(); + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(firstId); + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(secondId); + + List> relevant2 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score() - 0.01); + assertThat(relevant2).hasSize(2); + assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId); + assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId); + + List> relevant3 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score()); + assertThat(relevant3).hasSize(2); + assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId); + assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId); + + List> relevant4 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score() + 0.01); + assertThat(relevant4).hasSize(1); + assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId); + } + + @Test + void should_return_correct_score() { + Embedding embedding = embeddingModel.embed("hello").content(); + + String id = embeddingStore.add(embedding); + assertThat(id).isNotNull(); + + Embedding referenceEmbedding = embeddingModel.embed("hi").content(); + awaitUntilPersisted(); + List> relevant = embeddingStore.findRelevant(referenceEmbedding, 1); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo( + RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)), + withPercentage(1)); + } + + protected void awaitUntilPersisted() { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } +} diff --git a/opensearch/pom.xml b/opensearch/pom.xml new file mode 100644 index 000000000..c7a128f81 --- /dev/null +++ b/opensearch/pom.xml @@ -0,0 +1,19 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-parent + 999-SNAPSHOT + + quarkus-langchain4j-opensearch-parent + 
999-SNAPSHOT + pom + Quarkus Langchain4j - Opensearch embedding store - Parent + + deployment + runtime + + diff --git a/opensearch/runtime/pom.xml b/opensearch/runtime/pom.xml new file mode 100644 index 000000000..80fa67ac3 --- /dev/null +++ b/opensearch/runtime/pom.xml @@ -0,0 +1,63 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-opensearch-parent + 999-SNAPSHOT + + quarkus-langchain4j-opensearch + Quarkus Langchain4j - Opensearch embedding store - Runtime + Do something useful. + + + io.quarkus + quarkus-arc + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core + ${project.version} + + + io.quarkiverse.opensearch + quarkus-opensearch-java-client + 1.4.0 + + + + + + io.quarkus + quarkus-extension-maven-plugin + ${quarkus.version} + + + compile + + extension-descriptor + + + ${project.groupId}:${project.artifactId}-deployment:${project.version} + + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + + diff --git a/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/Document.java b/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/Document.java new file mode 100644 index 000000000..4586b65c4 --- /dev/null +++ b/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/Document.java @@ -0,0 +1,81 @@ +package io.quarkiverse.langchain4j.opensearch; + +import java.util.Map; + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder; + +import io.quarkus.runtime.annotations.RegisterForReflection; + +@RegisterForReflection +@JsonDeserialize(builder = Document.Builder.class) +class Document { + + private float[] vector; + private String text; + private Map metadata; + + private Document(Builder builder) { + this.vector = builder.vector; + this.text = builder.text; + this.metadata = builder.metadata; + } + + public static Builder builder() { + return new Builder(); + } + + public float[] getVector() { + return vector; + } + + public void setVector(float[] vector) { + this.vector = vector; + } + + public String getText() { + return text; + } + + public void setText(String text) { + this.text = text; + } + + public Map getMetadata() { + return metadata; + } + + public void setMetadata(Map metadata) { + this.metadata = metadata; + } + + @JsonPOJOBuilder(withPrefix = "") + public static final class Builder { + private float[] vector; + private String text; + private Map metadata; + + private Builder() { + + } + + public Builder vector(float[] val) { + vector = val; + return this; + } + + public Builder text(String val) { + text = val; + return this; + } + + public Builder metadata(Map val) { + metadata = val; + return this; + } + + public Document build() { + return new Document(this); + } + } +} diff --git a/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/OpenSearchEmbeddingStore.java b/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/OpenSearchEmbeddingStore.java new file mode 100644 index 000000000..62620d4da --- /dev/null +++ b/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/OpenSearchEmbeddingStore.java @@ -0,0 +1,406 @@ +package io.quarkiverse.langchain4j.opensearch; + +import static dev.langchain4j.internal.Utils.*; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static dev.langchain4j.internal.ValidationUtils.ensureTrue; +import static java.util.Collections.singletonList; +import 
static java.util.stream.Collectors.toList; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.apache.hc.client5.http.auth.AuthScope; +import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; +import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; +import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder; +import org.apache.hc.core5.http.HttpHost; +import org.apache.hc.core5.http.message.BasicHeader; +import org.opensearch.client.json.JsonData; +import org.opensearch.client.json.jackson.JacksonJsonpMapper; +import org.opensearch.client.opensearch.OpenSearchClient; +import org.opensearch.client.opensearch._types.ErrorCause; +import org.opensearch.client.opensearch._types.InlineScript; +import org.opensearch.client.opensearch._types.mapping.Property; +import org.opensearch.client.opensearch._types.mapping.TextProperty; +import org.opensearch.client.opensearch._types.mapping.TypeMapping; +import org.opensearch.client.opensearch._types.query_dsl.Query; +import org.opensearch.client.opensearch._types.query_dsl.ScriptScoreQuery; +import org.opensearch.client.opensearch.core.BulkRequest; +import org.opensearch.client.opensearch.core.BulkResponse; +import org.opensearch.client.opensearch.core.SearchRequest; +import org.opensearch.client.opensearch.core.SearchResponse; +import org.opensearch.client.opensearch.core.bulk.BulkResponseItem; +import org.opensearch.client.opensearch.indices.DeleteIndexRequest; +import org.opensearch.client.transport.OpenSearchTransport; +import org.opensearch.client.transport.aws.AwsSdk2Transport; +import org.opensearch.client.transport.aws.AwsSdk2TransportOptions; +import org.opensearch.client.transport.endpoints.BooleanResponse; +import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder; + +import com.fasterxml.jackson.core.JsonProcessingException; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingStore; +import io.quarkus.logging.Log; +import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.regions.Region; + +/** + * Represents an OpenSearch index as an + * embedding store. This implementation uses K-NN and the cosinesimil space type. + */ +public class OpenSearchEmbeddingStore implements EmbeddingStore { + + private final String indexName; + private final OpenSearchClient client; + + /** + * Creates an instance of OpenSearchEmbeddingStore to connect with + * OpenSearch clusters running locally and network reachable. + * + * @param serverUrl OpenSearch Server URL. + * @param apiKey OpenSearch API key (optional) + * @param userName OpenSearch username (optional) + * @param password OpenSearch password (optional) + * @param indexName OpenSearch index name. 
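+     *
+     *            <p>Illustrative builder usage for this variant (all values below are placeholders):
+     *            <pre>{@code
+     *            OpenSearchEmbeddingStore store = OpenSearchEmbeddingStore.builder()
+     *                    .serverUrl("https://localhost:9200")
+     *                    .userName("admin")
+     *                    .password("admin")
+     *                    .indexName("my-embeddings")
+     *                    .build();
+     *            }</pre>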
+ */ + public OpenSearchEmbeddingStore(String serverUrl, + String apiKey, + String userName, + String password, + String indexName) { + HttpHost openSearchHost; + try { + openSearchHost = HttpHost.create(serverUrl); + } catch (URISyntaxException se) { + Log.error("[I/O OpenSearch Exception]", se); + throw new OpenSearchRequestFailedException(se.getMessage()); + } + + OpenSearchTransport transport = ApacheHttpClient5TransportBuilder + .builder(openSearchHost) + .setMapper(new JacksonJsonpMapper()) + .setHttpClientConfigCallback(httpClientBuilder -> { + + if (!isNullOrBlank(apiKey)) { + httpClientBuilder.setDefaultHeaders(singletonList( + new BasicHeader("Authorization", "ApiKey " + apiKey))); + } + + if (!isNullOrBlank(userName) && !isNullOrBlank(password)) { + BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + credentialsProvider.setCredentials(new AuthScope(openSearchHost), + new UsernamePasswordCredentials(userName, password.toCharArray())); + httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider); + } + + httpClientBuilder.setConnectionManager(PoolingAsyncClientConnectionManagerBuilder.create().build()); + + return httpClientBuilder; + }) + .build(); + + this.client = new OpenSearchClient(transport); + this.indexName = ensureNotNull(indexName, "indexName"); + } + + /** + * Creates an instance of OpenSearchEmbeddingStore to connect with + * OpenSearch clusters running as a fully managed service at AWS. + * + * @param serverUrl OpenSearch Server URL. + * @param serviceName The AWS signing service name, one of `es` (Amazon OpenSearch) or `aoss` (Amazon OpenSearch + * Serverless). + * @param region The AWS region for which requests will be signed. This should typically match the region in `serverUrl`. + * @param options The options to establish connection with the service. It must include which credentials should be used. + * @param indexName OpenSearch index name. + */ + public OpenSearchEmbeddingStore(String serverUrl, + String serviceName, + String region, + AwsSdk2TransportOptions options, + String indexName) { + + Region selectedRegion = Region.of(region); + + SdkHttpClient httpClient = ApacheHttpClient.builder().build(); + OpenSearchTransport transport = new AwsSdk2Transport(httpClient, serverUrl, serviceName, selectedRegion, options); + + this.client = new OpenSearchClient(transport); + this.indexName = ensureNotNull(indexName, "indexName"); + } + + /** + * Creates an instance of OpenSearchEmbeddingStore using provided OpenSearchClient + * + * @param openSearchClient OpenSearch client provided + * @param indexName OpenSearch index name. 
+ */ + public OpenSearchEmbeddingStore(OpenSearchClient openSearchClient, + String indexName) { + + this.client = ensureNotNull(openSearchClient, "openSearchClient"); + this.indexName = ensureNotNull(indexName, "indexName"); + } + + public void deleteAll() { + DeleteIndexRequest deleteRequest = new DeleteIndexRequest.Builder().index(this.indexName).build(); + try { + client.indices().delete(deleteRequest); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String serverUrl; + private String apiKey; + private String userName; + private String password; + private String serviceName; + private String region; + private AwsSdk2TransportOptions options; + private String indexName = "default"; + private OpenSearchClient openSearchClient; + + public Builder serverUrl(String serverUrl) { + this.serverUrl = serverUrl; + return this; + } + + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + public Builder userName(String userName) { + this.userName = userName; + return this; + } + + public Builder password(String password) { + this.password = password; + return this; + } + + public Builder serviceName(String serviceName) { + this.serviceName = serviceName; + return this; + } + + public Builder region(String region) { + this.region = region; + return this; + } + + public Builder options(AwsSdk2TransportOptions options) { + this.options = options; + return this; + } + + public Builder indexName(String indexName) { + this.indexName = indexName; + return this; + } + + public Builder openSearchClient(OpenSearchClient openSearchClient) { + this.openSearchClient = openSearchClient; + return this; + } + + public OpenSearchEmbeddingStore build() { + if (openSearchClient != null) { + return new OpenSearchEmbeddingStore(openSearchClient, indexName); + } + if (!isNullOrBlank(serviceName) && !isNullOrBlank(region) && options != null) { + return new OpenSearchEmbeddingStore(serverUrl, serviceName, region, options, indexName); + } + return new OpenSearchEmbeddingStore(serverUrl, apiKey, userName, password, indexName); + } + + } + + @Override + public String add(Embedding embedding) { + String id = randomUUID(); + add(id, embedding); + return id; + } + + @Override + public void add(String id, Embedding embedding) { + addInternal(id, embedding, null); + } + + @Override + public String add(Embedding embedding, TextSegment textSegment) { + String id = randomUUID(); + addInternal(id, embedding, textSegment); + return id; + } + + @Override + public List addAll(List embeddings) { + List ids = embeddings.stream() + .map(ignored -> randomUUID()) + .collect(toList()); + addAllInternal(ids, embeddings, null); + return ids; + } + + @Override + public List addAll(List embeddings, List embedded) { + List ids = embeddings.stream() + .map(ignored -> randomUUID()) + .collect(toList()); + addAllInternal(ids, embeddings, embedded); + return ids; + } + + /** + * This implementation uses the exact k-NN with scoring script to calculate + * See https://opensearch.org/docs/latest/search-plugins/knn/knn-score-script/ + */ + @Override + public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { + List> matches; + try { + ScriptScoreQuery scriptScoreQuery = buildDefaultScriptScoreQuery(referenceEmbedding.vector(), (float) minScore); + SearchResponse response = client.search( + SearchRequest.of(s -> s.index(indexName) + .query(n -> 
n.scriptScore(scriptScoreQuery)) + .size(maxResults)), + Document.class); + matches = toEmbeddingMatch(response); + } catch (IOException ex) { + Log.error("[I/O OpenSearch Exception]", ex); + throw new OpenSearchRequestFailedException(ex.getMessage()); + } + return matches; + } + + private ScriptScoreQuery buildDefaultScriptScoreQuery(float[] vector, float minScore) throws JsonProcessingException { + + return ScriptScoreQuery.of(q -> q.minScore(minScore) + .query(Query.of(qu -> qu.matchAll(m -> m))) + .script(s -> s.inline(InlineScript.of(i -> i + .source("knn_score") + .lang("knn") + .params("field", JsonData.of("vector")) + .params("query_value", JsonData.of(vector)) + .params("space_type", JsonData.of("cosinesimil"))))) + .boost(0.5f)); + + // ===> From the OpenSearch documentation: + // "Cosine similarity returns a number between -1 and 1, and because OpenSearch + // relevance scores can't be below 0, the k-NN plugin adds 1 to get the final score." + // See https://opensearch.org/docs/latest/search-plugins/knn/knn-score-script + // Thus, the query applies a boost of `0.5` to keep score in the range [0, 1] + } + + private void addInternal(String id, Embedding embedding, TextSegment embedded) { + addAllInternal(singletonList(id), singletonList(embedding), embedded == null ? null : singletonList(embedded)); + } + + private void addAllInternal(List ids, List embeddings, List embedded) { + + if (isNullOrEmpty(ids) || isNullOrEmpty(embeddings)) { + Log.info("[do not add empty embeddings to opensearch]"); + return; + } + + ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size"); + ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size"); + + try { + createIndexIfNotExist(embeddings.get(0).dimension()); + bulk(ids, embeddings, embedded); + } catch (IOException ex) { + Log.error("[I/O OpenSearch Exception]", ex); + throw new OpenSearchRequestFailedException(ex.getMessage()); + } + } + + private void createIndexIfNotExist(int dimension) throws IOException { + BooleanResponse response = client.indices().exists(c -> c.index(indexName)); + if (!response.value()) { + client.indices() + .create(c -> c.index(indexName) + .settings(s -> s.knn(true)) + .mappings(getDefaultMappings(dimension))); + } + } + + private TypeMapping getDefaultMappings(int dimension) { + Map properties = new HashMap<>(4); + properties.put("text", Property.of(p -> p.text(TextProperty.of(t -> t)))); + properties.put("vector", Property.of(p -> p.knnVector( + k -> k.dimension(dimension)))); + return TypeMapping.of(c -> c.properties(properties)); + } + + private void bulk(List ids, List embeddings, List embedded) throws IOException { + + int size = ids.size(); + BulkRequest.Builder bulkBuilder = new BulkRequest.Builder(); + + for (int i = 0; i < size; i++) { + int finalI = i; + Document document = Document.builder() + .vector(embeddings.get(i).vector()) + .text(embedded == null ? null : embedded.get(i).text()) + .metadata(embedded == null ? 
null + : Optional.ofNullable(embedded.get(i).metadata()) + .map(Metadata::asMap) + .orElse(null)) + .build(); + bulkBuilder.operations(op -> op.index( + idx -> idx + .index(indexName) + .id(ids.get(finalI)) + .document(document))); + } + + BulkResponse bulkResponse = client.bulk(bulkBuilder.build()); + + if (bulkResponse.errors()) { + for (BulkResponseItem item : bulkResponse.items()) { + if (item.error() != null) { + ErrorCause errorCause = item.error(); + if (errorCause != null) { + throw new OpenSearchRequestFailedException( + "type: " + errorCause.type() + "," + + "reason: " + errorCause.reason()); + } + } + } + } + } + + private List> toEmbeddingMatch(SearchResponse response) { + return response.hits().hits().stream() + .map(hit -> Optional.ofNullable(hit.source()) + .map(document -> new EmbeddingMatch<>( + hit.score(), + hit.id(), + new Embedding(document.getVector()), + document.getText() == null + ? null + : TextSegment.from(document.getText(), new Metadata(document.getMetadata())))) + .orElse(null)) + .collect(toList()); + } +} diff --git a/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/OpenSearchRequestFailedException.java b/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/OpenSearchRequestFailedException.java new file mode 100644 index 000000000..542e899a6 --- /dev/null +++ b/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/OpenSearchRequestFailedException.java @@ -0,0 +1,16 @@ +package io.quarkiverse.langchain4j.opensearch; + +public class OpenSearchRequestFailedException extends RuntimeException { + + public OpenSearchRequestFailedException() { + super(); + } + + public OpenSearchRequestFailedException(String message) { + super(message); + } + + public OpenSearchRequestFailedException(String message, Throwable cause) { + super(message, cause); + } +} \ No newline at end of file diff --git a/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/runtime/OpenSearchEmbeddingStoreConfig.java b/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/runtime/OpenSearchEmbeddingStoreConfig.java new file mode 100644 index 000000000..69399caf4 --- /dev/null +++ b/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/runtime/OpenSearchEmbeddingStoreConfig.java @@ -0,0 +1,20 @@ +package io.quarkiverse.langchain4j.opensearch.runtime; + +import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME; + +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithDefault; + +@ConfigRoot(phase = RUN_TIME) +@ConfigMapping(prefix = "quarkus.langchain4j.opensearch") +public interface OpenSearchEmbeddingStoreConfig { + + /** + * Name of the index that will be used in OpenSearch when searching for related embeddings. + * If this index doesn't exist, it will be created. 
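+     * <p>
+     * For example: {@code quarkus.langchain4j.opensearch.index=my-embeddings}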
+ */ + @WithDefault("default") + String index(); + +} diff --git a/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/runtime/OpenSearchEmbeddingStoreRecorder.java b/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/runtime/OpenSearchEmbeddingStoreRecorder.java new file mode 100644 index 000000000..5c290b360 --- /dev/null +++ b/opensearch/runtime/src/main/java/io/quarkiverse/langchain4j/opensearch/runtime/OpenSearchEmbeddingStoreRecorder.java @@ -0,0 +1,29 @@ +package io.quarkiverse.langchain4j.opensearch.runtime; + +import java.util.function.Function; + +import jakarta.enterprise.inject.Default; + +import org.opensearch.client.opensearch.OpenSearchClient; + +import io.quarkiverse.langchain4j.opensearch.OpenSearchEmbeddingStore; +import io.quarkus.arc.SyntheticCreationalContext; +import io.quarkus.runtime.annotations.Recorder; + +@Recorder +public class OpenSearchEmbeddingStoreRecorder { + + public Function, OpenSearchEmbeddingStore> embeddingStoreFunction( + OpenSearchEmbeddingStoreConfig config) { + return new Function<>() { + @Override + public OpenSearchEmbeddingStore apply(SyntheticCreationalContext context) { + OpenSearchEmbeddingStore.Builder builder = new OpenSearchEmbeddingStore.Builder(); + OpenSearchClient openSearchClient; + openSearchClient = context.getInjectedReference(OpenSearchClient.class, new Default.Literal()); + builder.openSearchClient(openSearchClient); + return builder.build(); + } + }; + } +} diff --git a/opensearch/runtime/src/main/resources/META-INF/quarkus-extension.yaml b/opensearch/runtime/src/main/resources/META-INF/quarkus-extension.yaml new file mode 100644 index 000000000..0828c6e4a --- /dev/null +++ b/opensearch/runtime/src/main/resources/META-INF/quarkus-extension.yaml @@ -0,0 +1,12 @@ +name: Langchain4j Opensearch embedding store +artifact: ${project.groupId}:${project.artifactId}:${project.version} +description: Provides the Opensearch Embedding store for Langchain4j +metadata: + keywords: + - ai + - langchain4j + - openai + - opensearch + categories: + - "miscellaneous" + status: "experimental" diff --git a/pom.xml b/pom.xml index 4afd152da..9e48e5b1f 100644 --- a/pom.xml +++ b/pom.xml @@ -21,6 +21,7 @@ openai/azure-openai openai/openai-common openai/openai-vanilla + opensearch pinecone redis pgvector