diff --git a/milvus/deployment/pom.xml b/milvus/deployment/pom.xml new file mode 100644 index 000000000..3e47cdbff --- /dev/null +++ b/milvus/deployment/pom.xml @@ -0,0 +1,72 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-milvus-parent + 999-SNAPSHOT + + quarkus-langchain4j-milvus-deployment + Quarkus Langchain4j - Milvus embedding store - Deployment + + + io.quarkus + quarkus-arc-deployment + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core-deployment + ${project.version} + + + io.quarkiverse.langchain4j + quarkus-langchain4j-milvus + ${project.version} + + + dev.langchain4j + langchain4j-milvus + ${langchain4j.version} + + + io.quarkus + quarkus-devservices-deployment + + + io.quarkus + quarkus-junit5-internal + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2-q + ${langchain4j.version} + test + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + + + diff --git a/milvus/deployment/src/main/java/io/quarkiverse/langchain4j/milvus/MilvusBuildConfig.java b/milvus/deployment/src/main/java/io/quarkiverse/langchain4j/milvus/MilvusBuildConfig.java new file mode 100644 index 000000000..8940ce938 --- /dev/null +++ b/milvus/deployment/src/main/java/io/quarkiverse/langchain4j/milvus/MilvusBuildConfig.java @@ -0,0 +1,67 @@ +package io.quarkiverse.langchain4j.milvus; + +import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME; + +import java.util.OptionalInt; + +import io.quarkus.runtime.annotations.ConfigGroup; +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithDefault; + +@ConfigRoot(phase = BUILD_TIME) +@ConfigMapping(prefix = "quarkus.langchain4j.milvus") +public interface MilvusBuildConfig { + + /** + * Configuration for DevServices. DevServices allows Quarkus to automatically start a database in dev and test mode. + */ + MilvusDevServicesBuildTimeConfig devservices(); + + @ConfigGroup + interface MilvusDevServicesBuildTimeConfig { + + /** + * Whether Dev Services for Milvus are enabled or not. + */ + @WithDefault("true") + boolean enabled(); + + /** + * Container image for Milvus. + */ + @WithDefault("docker.io/milvusdb/milvus:v2.3.3") + String milvusImageName(); + + /** + * Container image for etcd. + */ + @WithDefault("quay.io/coreos/etcd:v3.5.5") + String etcdImageName(); + + /** + * Container image for minio. + */ + @WithDefault("docker.io/minio/minio:RELEASE.2023-12-13T23-28-55Z") + String minioImageName(); + + /** + * Optional fixed port the Milvus dev service will listen to. + * If not defined, the port will be chosen randomly. + */ + OptionalInt port(); + + /** + * Indicates if the Dev Service containers managed by Quarkus for Milvus are shared. + */ + @WithDefault("true") + boolean shared(); + + /** + * Service label to apply to created Dev Services containers. + */ + @WithDefault("milvus") + String serviceName(); + + } +} diff --git a/milvus/deployment/src/main/java/io/quarkiverse/langchain4j/milvus/MilvusDevServicesProcessor.java b/milvus/deployment/src/main/java/io/quarkiverse/langchain4j/milvus/MilvusDevServicesProcessor.java new file mode 100644 index 000000000..737068f6a --- /dev/null +++ b/milvus/deployment/src/main/java/io/quarkiverse/langchain4j/milvus/MilvusDevServicesProcessor.java @@ -0,0 +1,492 @@ +package io.quarkiverse.langchain4j.milvus; + +import java.io.Closeable; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.function.Supplier; + +import org.jboss.logging.Logger; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy; +import org.testcontainers.utility.DockerImageName; + +import io.quarkus.bootstrap.classloading.QuarkusClassLoader; +import io.quarkus.deployment.IsNormal; +import io.quarkus.deployment.annotations.BuildStep; +import io.quarkus.deployment.annotations.BuildSteps; +import io.quarkus.deployment.builditem.DevServicesResultBuildItem; +import io.quarkus.deployment.builditem.DockerStatusBuildItem; +import io.quarkus.deployment.builditem.LaunchModeBuildItem; +import io.quarkus.deployment.console.ConsoleInstalledBuildItem; +import io.quarkus.deployment.console.StartupLogCompressor; +import io.quarkus.deployment.dev.devservices.GlobalDevServicesConfig; +import io.quarkus.deployment.logging.LoggingSetupBuildItem; +import io.quarkus.devservices.common.ConfigureUtil; +import io.quarkus.devservices.common.ContainerLocator; +import io.quarkus.runtime.LaunchMode; +import io.quarkus.runtime.configuration.ConfigUtils; + +@SuppressWarnings("OptionalUsedAsFieldOrParameterType") +@BuildSteps(onlyIfNot = IsNormal.class, onlyIf = GlobalDevServicesConfig.Enabled.class) +public class MilvusDevServicesProcessor { + + private static final Logger log = Logger.getLogger(MilvusDevServicesProcessor.class); + + /** + * Label to add to shared Dev Service for Chroma running in containers. + * This allows other applications to discover the running service and use it instead of starting a new instance. + */ + private static final String DEV_SERVICE_LABEL = "quarkus-dev-service-milvus"; + + private static final String ETCD_IMAGE_NAME = "docker.io/coreos/etcd"; + private static final String MINIO_IMAGE_NAME = "docker.io/minio/minio"; + private static final String MILVUS_IMAGE_NAME = "docker.io/milvusdb/milvus"; + + private static final int MILVUS_PORT = 19530; + private static final int MINIO_PORT = 9000; + private static final int ETCD_PORT = 2379; + + private static final ContainerLocator containerLocator = new ContainerLocator(DEV_SERVICE_LABEL, MILVUS_PORT); + static volatile DevServicesResultBuildItem.RunningDevService milvusDevService; + static volatile DevServicesResultBuildItem.RunningDevService minioDevService; + static volatile DevServicesResultBuildItem.RunningDevService etcdDevService; + static volatile MilvusDevServiceCfg cfg; + static volatile boolean first = true; + + @BuildStep + public List startMilvusDevServices( + DockerStatusBuildItem dockerStatusBuildItem, + LaunchModeBuildItem launchMode, + MilvusBuildConfig milvusBuildConfig, + Optional consoleInstalledBuildItem, + LoggingSetupBuildItem loggingSetupBuildItem, + GlobalDevServicesConfig devServicesConfig) { + + List result = new ArrayList<>(); + MilvusDevServiceCfg configuration = getConfiguration(milvusBuildConfig); + + if (milvusDevService != null || etcdDevService != null || minioDevService != null) { + boolean shouldShutdown = !configuration.equals(cfg); + if (!shouldShutdown) { + result.add(milvusDevService.toBuildItem()); + result.add(etcdDevService.toBuildItem()); + result.add(minioDevService.toBuildItem()); + return result; + } + shutdownContainers(); + cfg = null; + } + + if (!milvusBuildConfig.devservices().enabled()) { + // explicitly disabled + log.debug("Not starting Dev Services for Milvus, as it has been disabled in the config."); + return Collections.emptyList(); + } + // if connection to Milvus was explicitly specified, don't start Dev Services + if (ConfigUtils.isPropertyPresent("quarkus.langchain4j.milvus.host")) { + return Collections.emptyList(); + } + StartupLogCompressor compressor = new StartupLogCompressor( + (launchMode.isTest() ? "(test) " : "") + "Milvus Dev Services Starting:", consoleInstalledBuildItem, + loggingSetupBuildItem); + try { + DevServicesResultBuildItem.RunningDevService newEtcdDevService = startEtcdContainer( + dockerStatusBuildItem, configuration, launchMode, + devServicesConfig.timeout); + if (newEtcdDevService != null) { + etcdDevService = newEtcdDevService; + if (etcdDevService.isOwner()) { + log.info("Dev Services instance of Etcd started."); + } + } + if (etcdDevService == null) { + compressor.closeAndDumpCaptured(); + } else { + compressor.close(); + } + DevServicesResultBuildItem.RunningDevService newMinioDevService = startMinioContainer( + dockerStatusBuildItem, configuration, launchMode, + devServicesConfig.timeout); + if (newMinioDevService != null) { + minioDevService = newMinioDevService; + if (minioDevService.isOwner()) { + log.info("Dev Services instance of Minio started."); + } + } + if (minioDevService == null) { + compressor.closeAndDumpCaptured(); + } else { + compressor.close(); + } + DevServicesResultBuildItem.RunningDevService newMilvusDevService = startMilvusContainer( + dockerStatusBuildItem, configuration, launchMode, + devServicesConfig.timeout, + newMinioDevService.getConfig().get("minio-host"), newMinioDevService.getConfig().get("minio-port"), + newEtcdDevService.getConfig().get("etcd-host"), newEtcdDevService.getConfig().get("etcd-port")); + if (newMilvusDevService != null) { + milvusDevService = newMilvusDevService; + if (milvusDevService.isOwner()) { + log.info("Dev Services instance of Milvus started."); + } + } + if (milvusDevService == null) { + compressor.closeAndDumpCaptured(); + } else { + compressor.close(); + } + } catch (Throwable t) { + compressor.closeAndDumpCaptured(); + throw new RuntimeException(t); + } + + if (milvusDevService == null || etcdDevService == null || minioDevService == null) { + return Collections.emptyList(); + } + + // Configure the watch dog + if (first) { + first = false; + Runnable closeTask = () -> { + shutdownContainers(); + first = true; + cfg = null; + }; + QuarkusClassLoader cl = (QuarkusClassLoader) Thread.currentThread().getContextClassLoader(); + ((QuarkusClassLoader) cl.parent()).addCloseTask(closeTask); + } + cfg = configuration; + result.add(milvusDevService.toBuildItem()); + result.add(etcdDevService.toBuildItem()); + result.add(minioDevService.toBuildItem()); + return result; + } + + private void shutdownContainers() { + if (milvusDevService != null) { + try { + milvusDevService.close(); + } catch (Throwable e) { + log.error("Failed to stop the Milvus server", e); + } finally { + milvusDevService = null; + } + } + if (etcdDevService != null) { + try { + etcdDevService.close(); + } catch (Throwable e) { + log.error("Failed to stop the Etcd server", e); + } finally { + etcdDevService = null; + } + } + if (minioDevService != null) { + try { + minioDevService.close(); + } catch (Throwable e) { + log.error("Failed to stop the Minio server", e); + } finally { + minioDevService = null; + } + } + } + + private DevServicesResultBuildItem.RunningDevService startMilvusContainer(DockerStatusBuildItem dockerStatusBuildItem, + MilvusDevServiceCfg config, LaunchModeBuildItem launchMode, + Optional timeout, String minioHost, String minioPort, String etcdHost, String etcdPort) { + if (!dockerStatusBuildItem.isDockerAvailable()) { + log.warn("Docker isn't working, please configure the Milvus server location."); + return null; + } + + ConfiguredMilvusContainer container = new ConfiguredMilvusContainer( + DockerImageName.parse(config.milvusImageName).asCompatibleSubstituteFor(MILVUS_IMAGE_NAME), + config.fixedMilvusPort, + launchMode.getLaunchMode() == LaunchMode.DEVELOPMENT ? config.serviceName : null); + + final Supplier defaultMilvusSupplier = () -> { + + // Starting the broker + timeout.ifPresent(container::withStartupTimeout); + container.addEnv("ETCD_ENDPOINTS", etcdHost + ":" + etcdPort); + container.addEnv("MINIO_ADDRESS", minioHost + ":" + minioPort); + container.start(); + return getRunningMilvusDevService( + container.getContainerId(), + container::close, + container.getHost(), + container.getPort()); + }; + return containerLocator + .locateContainer( + config.serviceName, + config.shared, + launchMode.getLaunchMode()) + .map(containerAddress -> getRunningMilvusDevService( + containerAddress.getId(), + null, + containerAddress.getHost(), + containerAddress.getPort())) + .orElseGet(defaultMilvusSupplier); + } + + private DevServicesResultBuildItem.RunningDevService startMinioContainer(DockerStatusBuildItem dockerStatusBuildItem, + MilvusDevServiceCfg config, LaunchModeBuildItem launchMode, Optional timeout) { + + ConfiguredMinioContainer container = new ConfiguredMinioContainer( + DockerImageName.parse(config.minioImageName).asCompatibleSubstituteFor(MINIO_IMAGE_NAME), + launchMode.getLaunchMode() == LaunchMode.DEVELOPMENT ? config.serviceName : null); + + final Supplier defaultMinioSupplier = () -> { + + // Starting the broker + timeout.ifPresent(container::withStartupTimeout); + container.start(); + return getRunningMinioDevService( + container.getContainerId(), + container::close, + container.getHost(), + container.getPort()); + }; + + return containerLocator + .locateContainer( + config.serviceName, + config.shared, + launchMode.getLaunchMode()) + .map(containerAddress -> getRunningMinioDevService( + containerAddress.getId(), + null, + containerAddress.getHost(), + containerAddress.getPort())) + .orElseGet(defaultMinioSupplier); + } + + private DevServicesResultBuildItem.RunningDevService startEtcdContainer(DockerStatusBuildItem dockerStatusBuildItem, + MilvusDevServiceCfg config, LaunchModeBuildItem launchMode, + Optional timeout) { + ConfiguredEtcdContainer container = new ConfiguredEtcdContainer( + DockerImageName.parse(config.etcdImageName).asCompatibleSubstituteFor(ETCD_IMAGE_NAME), + launchMode.getLaunchMode() == LaunchMode.DEVELOPMENT ? config.serviceName : null); + + final Supplier defaultEtcdSupplier = () -> { + + timeout.ifPresent(container::withStartupTimeout); + container.addEnv("ETCD_AUTO_COMPACTION_MODE", "revision"); + container.addEnv("ETCD_AUTO_COMPACTION_RETENTION", "1000"); + container.addEnv("ETCD_QUOTA_BACKEND_BYTES", "4294967296"); + container.addEnv("ETCD_SNAPSHOT_COUNT", "50000"); + container.setCommand("etcd", "-advertise-client-urls=http://127.0.0.1:2379", + "-listen-client-urls=http://0.0.0.0:2379", + "--data-dir=/etcd"); + container.start(); + return getRunningEtcdDevService( + container.getContainerId(), + container::close, + container.getHost(), + container.getPort()); + }; + + return containerLocator + .locateContainer( + config.serviceName, + config.shared, + launchMode.getLaunchMode()) + .map(containerAddress -> getRunningEtcdDevService( + containerAddress.getId(), + null, + containerAddress.getHost(), + containerAddress.getPort())) + .orElseGet(defaultEtcdSupplier); + } + + private DevServicesResultBuildItem.RunningDevService getRunningMilvusDevService( + String containerId, + Closeable closeable, + String host, + int port) { + Map configMap = Map.of("quarkus.langchain4j.milvus.host", "localhost", + "quarkus.langchain4j.milvus.port", String.valueOf(port)); + return new DevServicesResultBuildItem.RunningDevService(MilvusProcessor.FEATURE, + containerId, closeable, configMap); + } + + private DevServicesResultBuildItem.RunningDevService getRunningMinioDevService( + String containerId, + Closeable closeable, + String host, + int port) { + Map configMap = new HashMap<>(); + configMap.put("minio-host", host); + configMap.put("minio-port", String.valueOf(port)); + return new DevServicesResultBuildItem.RunningDevService(MilvusProcessor.FEATURE, + containerId, closeable, configMap); + } + + private DevServicesResultBuildItem.RunningDevService getRunningEtcdDevService( + String containerId, + Closeable closeable, + String host, + int port) { + Map configMap = new HashMap<>(); + configMap.put("etcd-host", host); + configMap.put("etcd-port", String.valueOf(port)); + return new DevServicesResultBuildItem.RunningDevService(MilvusProcessor.FEATURE, + containerId, closeable, configMap); + } + + private MilvusDevServiceCfg getConfiguration(MilvusBuildConfig cfg) { + return new MilvusDevServiceCfg(cfg.devservices()); + } + + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + private static final class MilvusDevServiceCfg { + + public OptionalInt fixedEtcdPort; + private boolean devServicesEnabled; + private OptionalInt fixedMilvusPort; + private String milvusImageName; + private String etcdImageName; + private String minioImageName; + private String serviceName; + private boolean shared; + + public MilvusDevServiceCfg(MilvusBuildConfig.MilvusDevServicesBuildTimeConfig devservices) { + this.devServicesEnabled = devservices.enabled(); + this.fixedMilvusPort = devservices.port(); + this.milvusImageName = devservices.milvusImageName(); + this.etcdImageName = devservices.etcdImageName(); + this.minioImageName = devservices.minioImageName(); + this.serviceName = devservices.serviceName(); + this.shared = devservices.shared(); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + MilvusDevServiceCfg that = (MilvusDevServiceCfg) o; + return devServicesEnabled == that.devServicesEnabled && + shared == that.shared && + Objects.equals(fixedMilvusPort, that.fixedMilvusPort) && + Objects.equals(milvusImageName, that.milvusImageName) && + Objects.equals(etcdImageName, that.etcdImageName) && + Objects.equals(minioImageName, that.minioImageName) && + Objects.equals(serviceName, that.serviceName); + } + + @Override + public int hashCode() { + return Objects.hash(devServicesEnabled, fixedMilvusPort, milvusImageName, + etcdImageName, minioImageName, serviceName, shared); + } + } + + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + private static class ConfiguredMilvusContainer extends GenericContainer { + private final OptionalInt fixedExposedPort; + private String hostName = null; + + public ConfiguredMilvusContainer(DockerImageName dockerImageName, + OptionalInt fixedExposedPort, + String serviceName) { + super(dockerImageName); + this.fixedExposedPort = fixedExposedPort; + if (serviceName != null) { + withLabel(DEV_SERVICE_LABEL, serviceName); + } + } + + @Override + protected void configure() { + super.configure(); + this.setCommand("milvus", "run", "standalone"); + setWaitStrategy(new LogMessageWaitStrategy().withRegEx(".*QueryNode successfully started.*\\s")); + hostName = ConfigureUtil.configureSharedNetwork(this, "milvus"); + if (fixedExposedPort.isPresent()) { + addFixedExposedPort(fixedExposedPort.getAsInt(), MILVUS_PORT); + } else { + addExposedPort(MILVUS_PORT); + } + } + + public int getPort() { + if (fixedExposedPort.isPresent()) { + return fixedExposedPort.getAsInt(); + } + return super.getMappedPort(MILVUS_PORT); + } + + @Override + public String getHost() { + return hostName; + } + + } + + private static class ConfiguredMinioContainer extends GenericContainer { + + private String hostName = null; + + public ConfiguredMinioContainer(DockerImageName dockerImageName, + String serviceName) { + super(dockerImageName); + if (serviceName != null) { + withLabel(DEV_SERVICE_LABEL, serviceName); + } + } + + @Override + protected void configure() { + super.configure(); + this.setCommand("server", "--console-address", ":9001", "/data"); + hostName = ConfigureUtil.configureSharedNetwork(this, "minio"); + } + + public int getPort() { + return MINIO_PORT; + } + + @Override + public String getHost() { + return hostName; + } + } + + private static class ConfiguredEtcdContainer extends GenericContainer { + + private String hostName = null; + + public ConfiguredEtcdContainer(DockerImageName dockerImageName, + String serviceName) { + super(dockerImageName); + if (serviceName != null) { + withLabel(DEV_SERVICE_LABEL, serviceName); + } + } + + @Override + protected void configure() { + super.configure(); + hostName = ConfigureUtil.configureSharedNetwork(this, "etcd"); + } + + public int getPort() { + return ETCD_PORT; + } + + @Override + public String getHost() { + return hostName; + } + } +} diff --git a/milvus/deployment/src/main/java/io/quarkiverse/langchain4j/milvus/MilvusProcessor.java b/milvus/deployment/src/main/java/io/quarkiverse/langchain4j/milvus/MilvusProcessor.java new file mode 100644 index 000000000..e7363284e --- /dev/null +++ b/milvus/deployment/src/main/java/io/quarkiverse/langchain4j/milvus/MilvusProcessor.java @@ -0,0 +1,52 @@ +package io.quarkiverse.langchain4j.milvus; + +import jakarta.enterprise.context.ApplicationScoped; + +import org.jboss.jandex.ClassType; +import org.jboss.jandex.DotName; +import org.jboss.jandex.ParameterizedType; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore; +import io.quarkiverse.langchain4j.deployment.EmbeddingStoreBuildItem; +import io.quarkiverse.langchain4j.milvus.runtime.MilvusRecorder; +import io.quarkiverse.langchain4j.milvus.runtime.MilvusRuntimeConfig; +import io.quarkus.arc.deployment.SyntheticBeanBuildItem; +import io.quarkus.deployment.annotations.BuildProducer; +import io.quarkus.deployment.annotations.BuildStep; +import io.quarkus.deployment.annotations.ExecutionTime; +import io.quarkus.deployment.annotations.Record; +import io.quarkus.deployment.builditem.FeatureBuildItem; + +public class MilvusProcessor { + + public static final DotName MILVUS_EMBEDDING_STORE = DotName.createSimple(MilvusEmbeddingStore.class); + public static final String FEATURE = "langchain4j-milvus"; + + @BuildStep + FeatureBuildItem feature() { + return new FeatureBuildItem(FEATURE); + } + + @BuildStep + @Record(ExecutionTime.RUNTIME_INIT) + public void createBean( + BuildProducer beanProducer, + MilvusRecorder recorder, + MilvusRuntimeConfig config, + BuildProducer embeddingStoreProducer) { + beanProducer.produce(SyntheticBeanBuildItem + .configure(MILVUS_EMBEDDING_STORE) + .types(ClassType.create(EmbeddingStore.class), + ParameterizedType.create(EmbeddingStore.class, ClassType.create(TextSegment.class))) + .defaultBean() + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.milvusStoreSupplier(config)) + .done()); + embeddingStoreProducer.produce(new EmbeddingStoreBuildItem()); + } + +} diff --git a/milvus/deployment/src/test/java/io/quarkiverse/langchain4j/milvus/deployment/MilvusEmbeddingStoreTest.java b/milvus/deployment/src/test/java/io/quarkiverse/langchain4j/milvus/deployment/MilvusEmbeddingStoreTest.java new file mode 100644 index 000000000..d4d89b360 --- /dev/null +++ b/milvus/deployment/src/test/java/io/quarkiverse/langchain4j/milvus/deployment/MilvusEmbeddingStoreTest.java @@ -0,0 +1,261 @@ +package io.quarkiverse.langchain4j.milvus.deployment; + +import static dev.langchain4j.internal.Utils.randomUUID; +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.data.Percentage.withPercentage; + +import java.util.List; + +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.CosineSimilarity; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.RelevanceScore; +import io.milvus.client.MilvusClient; +import io.milvus.client.MilvusServiceClient; +import io.milvus.grpc.MutationResult; +import io.milvus.param.ConnectParam; +import io.milvus.param.R; +import io.milvus.param.collection.LoadCollectionParam; +import io.milvus.param.dml.DeleteParam; +import io.quarkus.logging.Log; +import io.quarkus.test.QuarkusUnitTest; + +public class MilvusEmbeddingStoreTest { + + public static final String COLLECTION_NAME = "test_embeddings"; + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addAsResource(new StringAsset( + "quarkus.langchain4j.milvus.collection-name=" + COLLECTION_NAME + "\n" + + "quarkus.langchain4j.milvus.devservices.port=19530\n" + + "quarkus.langchain4j.milvus.dimension=384"), + "application.properties")); + + @Inject + EmbeddingStore embeddingStore; + private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + /** + * Delete all embeddings from the collection before each test. + */ + @AfterEach + public void cleanup() { + ConnectParam connectParam = ConnectParam.newBuilder() + .withHost("localhost") + .withPort(19530) + .build(); + MilvusClient client = new MilvusServiceClient(connectParam); + client.loadCollection(LoadCollectionParam.newBuilder().withCollectionName(COLLECTION_NAME).build()); + R deleteResult = client.delete(DeleteParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + // seems we can't just say "delete all entries", but + // can provide a predicate that is always false + .withExpr("id != 'BLABLA'") + .build()); + Log.info("Deleted: " + deleteResult.getData().getDeleteCnt()); + client.close(); + } + + @Test + void should_add_embedding() { + Embedding embedding = embeddingModel.embed(randomUUID()).content(); + + String id = embeddingStore.add(embedding); + assertThat(id).isNotNull(); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isNull(); + } + + @Test + void should_add_embedding_with_id() { + String id = randomUUID(); + Embedding embedding = embeddingModel.embed(randomUUID()).content(); + + embeddingStore.add(id, embedding); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isNull(); + } + + @Test + void should_add_embedding_with_segment() { + TextSegment segment = TextSegment.from(randomUUID()); + Embedding embedding = embeddingModel.embed(segment.text()).content(); + + String id = embeddingStore.add(embedding, segment); + assertThat(id).isNotNull(); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isEqualTo(segment); + } + + @Disabled("Milvus store doesn't support storing metadata yet") + @Test + void should_add_embedding_with_segment_with_metadata() { + TextSegment segment = TextSegment.from(randomUUID(), Metadata.from("test-key", "test-value")); + Embedding embedding = embeddingModel.embed(segment.text()).content(); + + String id = embeddingStore.add(embedding, segment); + + assertThat(id).isNotNull(); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isEqualTo(segment); + } + + @Test + void should_add_multiple_embeddings() { + Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); + Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); + + List ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding)); + assertThat(ids).hasSize(2); + + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); + assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); + assertThat(firstMatch.embedded()).isNull(); + + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); + assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); + assertThat(secondMatch.embedded()).isNull(); + } + + @Test + void should_add_multiple_embeddings_with_segments() { + TextSegment firstSegment = TextSegment.from(randomUUID()); + Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content(); + TextSegment secondSegment = TextSegment.from(randomUUID()); + Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content(); + + List ids = embeddingStore.addAll( + asList(firstEmbedding, secondEmbedding), + asList(firstSegment, secondSegment)); + assertThat(ids).hasSize(2); + + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); + assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); + assertThat(firstMatch.embedded()).isEqualTo(firstSegment); + + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); + assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); + assertThat(secondMatch.embedded()).isEqualTo(secondSegment); + } + + @Test + void should_find_with_min_score() { + String firstId = randomUUID(); + Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); + embeddingStore.add(firstId, firstEmbedding); + + String secondId = randomUUID(); + Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); + embeddingStore.add(secondId, secondEmbedding); + + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(firstId); + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(secondId); + + List> relevant2 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score() - 0.01); + assertThat(relevant2).hasSize(2); + assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId); + assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId); + + List> relevant3 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score()); + assertThat(relevant3).hasSize(2); + assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId); + assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId); + + List> relevant4 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score() + 0.01); + assertThat(relevant4).hasSize(1); + assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId); + } + + @Test + void should_return_correct_score() { + Embedding embedding = embeddingModel.embed("hello").content(); + + String id = embeddingStore.add(embedding); + assertThat(id).isNotNull(); + + Embedding referenceEmbedding = embeddingModel.embed("hi").content(); + + List> relevant = embeddingStore.findRelevant(referenceEmbedding, 1); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo( + RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)), + withPercentage(1)); + } +} diff --git a/milvus/pom.xml b/milvus/pom.xml new file mode 100644 index 000000000..77a755d14 --- /dev/null +++ b/milvus/pom.xml @@ -0,0 +1,20 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-parent + 999-SNAPSHOT + + quarkus-langchain4j-milvus-parent + Quarkus Langchain4j - Milvus embedding store - Parent + pom + + + deployment + runtime + + + diff --git a/milvus/runtime/pom.xml b/milvus/runtime/pom.xml new file mode 100644 index 000000000..4ea514e8a --- /dev/null +++ b/milvus/runtime/pom.xml @@ -0,0 +1,83 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-milvus-parent + 999-SNAPSHOT + + quarkus-langchain4j-milvus + Quarkus Langchain4j - Milvus embedding store - Runtime + + + io.quarkus + quarkus-arc + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core + ${project.version} + + + dev.langchain4j + langchain4j-milvus + ${langchain4j.version} + + + + + + io.quarkus + quarkus-extension-maven-plugin + ${quarkus.version} + + + compile + + extension-descriptor + + + ${project.groupId}:${project.artifactId}-deployment:${project.version} + + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + maven-jar-plugin + + + generate-codestart-jar + generate-resources + + jar + + + ${project.basedir}/src/main + + codestarts/** + + codestarts + true + + + + + + + + diff --git a/milvus/runtime/src/main/java/io/quarkiverse/langchain4j/milvus/runtime/MilvusRecorder.java b/milvus/runtime/src/main/java/io/quarkiverse/langchain4j/milvus/runtime/MilvusRecorder.java new file mode 100644 index 000000000..b961f59ba --- /dev/null +++ b/milvus/runtime/src/main/java/io/quarkiverse/langchain4j/milvus/runtime/MilvusRecorder.java @@ -0,0 +1,32 @@ +package io.quarkiverse.langchain4j.milvus.runtime; + +import java.util.function.Supplier; + +import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore; +import io.quarkus.runtime.annotations.Recorder; + +@Recorder +public class MilvusRecorder { + + public Supplier milvusStoreSupplier(MilvusRuntimeConfig config) { + return new Supplier<>() { + @Override + public MilvusEmbeddingStore get() { + return new MilvusEmbeddingStore.Builder() + .host(config.host()) + .port(config.port()) + .collectionName(config.collectionName()) + .dimension(config.dimension().orElse(null)) + .indexType(config.indexType()) + .metricType(config.metricType()) + .token(config.token().orElse(null)) + .username(config.username().orElse(null)) + .password(config.password().orElse(null)) + .consistencyLevel(config.consistencyLevel()) + .retrieveEmbeddingsOnSearch(true) + .databaseName(config.dbName()) + .build(); + } + }; + } +} diff --git a/milvus/runtime/src/main/java/io/quarkiverse/langchain4j/milvus/runtime/MilvusRuntimeConfig.java b/milvus/runtime/src/main/java/io/quarkiverse/langchain4j/milvus/runtime/MilvusRuntimeConfig.java new file mode 100644 index 000000000..742a2698d --- /dev/null +++ b/milvus/runtime/src/main/java/io/quarkiverse/langchain4j/milvus/runtime/MilvusRuntimeConfig.java @@ -0,0 +1,107 @@ +package io.quarkiverse.langchain4j.milvus.runtime; + +import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME; + +import java.time.Duration; +import java.util.Optional; + +import io.milvus.common.clientenum.ConsistencyLevelEnum; +import io.milvus.param.IndexType; +import io.milvus.param.MetricType; +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithDefault; + +@ConfigRoot(phase = RUN_TIME) +@ConfigMapping(prefix = "quarkus.langchain4j.milvus") +public interface MilvusRuntimeConfig { + + /** + * The URL of the Milvus server. + */ + String host(); + + /** + * The port of the Milvus server. + */ + Integer port(); + + /** + * The authentication token for the Milvus server. + */ + Optional token(); + + /** + * The username for the Milvus server. + */ + Optional username(); + + /** + * The password for the Milvus server. + */ + Optional password(); + + /** + * The timeout duration for the Milvus client. If not specified, 5 seconds will be used. + */ + Optional timeout(); + + /** + * Name of the database. + */ + @WithDefault("default") + String dbName(); + + /** + * Create the collection if it does not exist yet. + */ + @WithDefault("true") + boolean createCollection(); + + /** + * Name of the collection. + */ + @WithDefault("embeddings") + String collectionName(); + + /** + * Dimension of the vectors. Only applicable when the collection yet has to be created. + */ + Optional dimension(); + + /** + * TODO + */ + @WithDefault("id") + String primaryField(); + + /** + * Name of the field to store the vector in. + */ + @WithDefault("vector") + String vectorField(); + + /** + * Description of the collection. + */ + Optional description(); + + /** + * The index type to use for the collection. + */ + @WithDefault("FLAT") + IndexType indexType(); + + /** + * The metric type to use for searching. + */ + @WithDefault("COSINE") + MetricType metricType(); + + /** + * The consistency level. + */ + @WithDefault("EVENTUALLY") + ConsistencyLevelEnum consistencyLevel(); + +} diff --git a/milvus/runtime/src/main/resources/META-INF/beans.xml b/milvus/runtime/src/main/resources/META-INF/beans.xml new file mode 100644 index 000000000..e69de29bb diff --git a/milvus/runtime/src/main/resources/META-INF/quarkus-extension.yaml b/milvus/runtime/src/main/resources/META-INF/quarkus-extension.yaml new file mode 100644 index 000000000..cdaea0ca5 --- /dev/null +++ b/milvus/runtime/src/main/resources/META-INF/quarkus-extension.yaml @@ -0,0 +1,12 @@ +name: Langchain4j Milvus embedding store +artifact: ${project.groupId}:${project.artifactId}:${project.version} +description: Provides the Milvus Embedding store for Langchain4j +metadata: + keywords: + - ai + - langchain4j + - openai + - milvus + categories: + - "miscellaneous" + status: "experimental" \ No newline at end of file diff --git a/pom.xml b/pom.xml index bb1367977..dc4af1c09 100644 --- a/pom.xml +++ b/pom.xml @@ -16,6 +16,7 @@ core docs hugging-face + milvus ollama openai/azure-openai openai/openai-common