Skip to content

Commit

Permalink
Allow injecting non-default datasource for PgVector store, produce an…
Browse files Browse the repository at this point in the history
… EmbeddingStoreBuildItem
  • Loading branch information
jmartisk committed Dec 8, 2023
1 parent 685203f commit 72a945b
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 7 deletions.
17 changes: 17 additions & 0 deletions docs/modules/ROOT/pages/includes/quarkus-langchain4j-pgvector.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@ h|[[quarkus-langchain4j-pgvector_configuration]]link:#quarkus-langchain4j-pgvect
h|Type
h|Default

a|icon:lock[title=Fixed at build time] [[quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.datasource]]`link:#quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.datasource[quarkus.langchain4j.pgvector.datasource]`


[.description]
--
The name of the configured Postgres datasource to use for this store. If not set, the default datasource from the Agroal extension will be used.

ifdef::add-copy-button-to-env-var[]
Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_PGVECTOR_DATASOURCE+++[]
endif::add-copy-button-to-env-var[]
ifndef::add-copy-button-to-env-var[]
Environment variable: `+++QUARKUS_LANGCHAIN4J_PGVECTOR_DATASOURCE+++`
endif::add-copy-button-to-env-var[]
--|string
|


a| [[quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.table]]`link:#quarkus-langchain4j-pgvector_quarkus.langchain4j.pgvector.table[quarkus.langchain4j.pgvector.table]`


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package io.quarkiverse.langchain4j.pgvector.deployment;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Default;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.ClassType;
import org.jboss.jandex.DotName;
import org.jboss.jandex.ParameterizedType;
Expand All @@ -11,9 +13,11 @@
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.agroal.api.AgroalDataSource;
import io.quarkiverse.langchain4j.deployment.EmbeddingStoreBuildItem;
import io.quarkiverse.langchain4j.pgvector.PgVectorEmbeddingStore;
import io.quarkiverse.langchain4j.pgvector.runtime.PgVectorEmbeddingStoreConfig;
import io.quarkiverse.langchain4j.pgvector.runtime.PgVectorEmbeddingStoreRecorder;
import io.quarkus.agroal.DataSource;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
Expand All @@ -38,18 +42,30 @@ FeatureBuildItem feature() {
public void createBean(
BuildProducer<SyntheticBeanBuildItem> beanProducer,
PgVectorEmbeddingStoreRecorder recorder,
PgVectorEmbeddingStoreConfig config) {
PgVectorEmbeddingStoreConfig config,
PgVectorEmbeddingStoreBuildTimeConfig buildTimeConfig,
BuildProducer<EmbeddingStoreBuildItem> embeddingStoreProducer) {
String datasourceName = buildTimeConfig.datasource().orElse(null);
AnnotationInstance datasourceQualifier;
if (datasourceName == null) {
datasourceQualifier = AnnotationInstance.builder(Default.class).build();

} else {
datasourceQualifier = AnnotationInstance.builder(DataSource.class)
.add("value", datasourceName)
.build();
}
beanProducer.produce(SyntheticBeanBuildItem
.configure(PGVECTOR_EMBEDDING_STORE)
.types(ClassType.create(EmbeddingStore.class),
ParameterizedType.create(EmbeddingStore.class, ClassType.create(TextSegment.class)))
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.addInjectionPoint(ClassType.create(DotName.createSimple(AgroalDataSource.class)))
.createWith(recorder.embeddingStoreFunction(config))
.createWith(recorder.embeddingStoreFunction(config, buildTimeConfig.datasource().orElse(null)))
.addInjectionPoint(ClassType.create(DotName.createSimple(AgroalDataSource.class)), datasourceQualifier)
.done());

embeddingStoreProducer.produce(new EmbeddingStoreBuildItem());
}

@BuildStep
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package io.quarkiverse.langchain4j.pgvector.deployment;

import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME;

import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigRoot;
import io.smallrye.config.ConfigMapping;

@ConfigRoot(phase = BUILD_TIME)
@ConfigMapping(prefix = "quarkus.langchain4j.pgvector")
public interface PgVectorEmbeddingStoreBuildTimeConfig {

/**
* The name of the configured Postgres datasource to use for this store. If not set,
* the default datasource from the Agroal extension will be used.
*/
Optional<String> datasource();

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,24 @@

import io.agroal.api.AgroalDataSource;
import io.quarkiverse.langchain4j.pgvector.PgVectorEmbeddingStore;
import io.quarkus.agroal.DataSource.DataSourceLiteral;
import io.quarkus.arc.SyntheticCreationalContext;
import io.quarkus.runtime.annotations.Recorder;

@Recorder
public class PgVectorEmbeddingStoreRecorder {

public Function<SyntheticCreationalContext<PgVectorEmbeddingStore>, PgVectorEmbeddingStore> embeddingStoreFunction(
PgVectorEmbeddingStoreConfig config) {
PgVectorEmbeddingStoreConfig config, String datasourceName) {
return new Function<>() {
@Override
public PgVectorEmbeddingStore apply(SyntheticCreationalContext<PgVectorEmbeddingStore> context) {
AgroalDataSource dataSource;
//TODO handle named datasources
dataSource = context.getInjectedReference(AgroalDataSource.class, new Default.Literal());
if (datasourceName == null) {
dataSource = context.getInjectedReference(AgroalDataSource.class, new Default.Literal());
} else {
dataSource = context.getInjectedReference(AgroalDataSource.class, new DataSourceLiteral(datasourceName));
}
return new PgVectorEmbeddingStore(dataSource, config.table(), config.dimension(), config.useIndex(),
config.indexListSize(), config.createTable(), config.dropTableFirst());
}
Expand Down

0 comments on commit 72a945b

Please sign in to comment.