Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show how OIDC ModelAuthProvider can be used with Azure OpenAI #733

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,25 @@ interface Input {
}

static Optional<ModelAuthProvider> resolve(String modelName) {
Instance<ModelAuthProvider> beanInstance = modelName == null
? CDI.current().select(ModelAuthProvider.class)
: CDI.current().select(ModelAuthProvider.class, ModelName.Literal.of(modelName));

//get the first one without causing a bean1 resolution exception
// This will likely need to be refactored again.
// ModelAuthProvider should return a set of supported models (empty by default),
// otherwise the resolution on the main branch does not work for more than one OIDC provider
ModelAuthProvider authorizer = null;
for (var handle : beanInstance.handles()) {
authorizer = handle.get();
break;
if (modelName != null) {
Instance<ModelAuthProvider> beanInstance = CDI.current().select(ModelAuthProvider.class,
ModelName.Literal.of(modelName));

for (var handle : beanInstance.handles()) {
authorizer = handle.get();
break;
}
}
if (authorizer == null) {
Instance<ModelAuthProvider> beanInstance = CDI.current().select(ModelAuthProvider.class);
for (var handle : beanInstance.handles()) {
authorizer = handle.get();
break;
}
}
return Optional.ofNullable(authorizer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
public class DevServicesConfigBuilderCustomizer implements SmallRyeConfigBuilderCustomizer {
@Override
public void configBuilder(final SmallRyeConfigBuilder builder) {
// use a priority of 50 to make sure that this is overridable by any of the standard methods
builder.withSources(
new PropertiesConfigSource(Map.of("quarkus.datasource.devservices.image-name", "pgvector/pgvector:pg17"),
"quarkus-langchain4j-pgvector", 50));
// use a priority of 50 to make sure that this is overridable by any of the
// standard methods
builder.withSources(new PropertiesConfigSource(
Map.of("quarkus.datasource.devservices.image-name", "pgvector/pgvector:pg17"),
"quarkus-langchain4j-pgvector", 50));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
public interface PgVectorEmbeddingStoreBuildTimeConfig {

/**
* The name of the configured Postgres datasource to use for this store. If not set,
* the default datasource from the Agroal extension will be used.
* The name of the configured Postgres datasource to use for this store. If not
* set, the default datasource from the Agroal extension will be used.
*/
Optional<String> datasource();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class PgVectorEmbeddingStoreProcessor {

private static final DotName PG_VECTOR_EMBEDDING_STORE = DotName.createSimple(PgVectorEmbeddingStore.class);
private static final DotName AGROAL_POOL_INTERCEPTOR = DotName.createSimple(AgroalPoolInterceptor.class);
private static final DotName PG_VECTOR_AGROAL_POOL_INTERCEPTOR = DotName.createSimple(PgVectorAgroalPoolInterceptor.class);
private static final DotName PG_VECTOR_AGROAL_POOL_INTERCEPTOR = DotName
.createSimple(PgVectorAgroalPoolInterceptor.class);

private static final String FEATURE = "langchain4j-pgvector";

Expand All @@ -50,39 +51,27 @@ void indexDependencies(BuildProducer<IndexDependencyBuildItem> producer) {

@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
public void createBean(
BuildProducer<SyntheticBeanBuildItem> beanProducer,
PgVectorEmbeddingStoreRecorder recorder,
PgVectorEmbeddingStoreConfig config,
PgVectorEmbeddingStoreBuildTimeConfig buildTimeConfig,
public void createBean(BuildProducer<SyntheticBeanBuildItem> beanProducer, PgVectorEmbeddingStoreRecorder recorder,
PgVectorEmbeddingStoreConfig config, PgVectorEmbeddingStoreBuildTimeConfig buildTimeConfig,
BuildProducer<EmbeddingStoreBuildItem> embeddingStoreProducer) {

AnnotationInstance datasourceQualifier = buildTimeConfig.datasource()
.map(dn -> AnnotationInstance.builder(DataSource.class).add("value", dn).build())
.orElse(AnnotationInstance.builder(Default.class).build());

beanProducer.produce(SyntheticBeanBuildItem
.configure(PG_VECTOR_EMBEDDING_STORE)
beanProducer.produce(SyntheticBeanBuildItem.configure(PG_VECTOR_EMBEDDING_STORE)
.types(ClassType.create(dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore.class),
ClassType.create(EmbeddingStore.class),
ParameterizedType.create(EmbeddingStore.class, ClassType.create(TextSegment.class)))
.setRuntimeInit()
.defaultBean()
.unremovable()
.scope(ApplicationScoped.class)
.setRuntimeInit().defaultBean().unremovable().scope(ApplicationScoped.class)
.createWith(recorder.embeddingStoreFunction(config, buildTimeConfig.datasource().orElse(null)))
.addInjectionPoint(ClassType.create(DotName.createSimple(AgroalDataSource.class)), datasourceQualifier)
.done());

beanProducer.produce(SyntheticBeanBuildItem
.configure(PG_VECTOR_AGROAL_POOL_INTERCEPTOR)
.types(ClassType.create(AGROAL_POOL_INTERCEPTOR))
.setRuntimeInit()
.unremovable()
.scope(ApplicationScoped.class)
.supplier(recorder.pgVectorAgroalPoolInterceptor())
.qualifiers(datasourceQualifier)
.done());
beanProducer.produce(SyntheticBeanBuildItem.configure(PG_VECTOR_AGROAL_POOL_INTERCEPTOR)
.types(ClassType.create(AGROAL_POOL_INTERCEPTOR)).setRuntimeInit().unremovable()
.scope(ApplicationScoped.class).supplier(recorder.pgVectorAgroalPoolInterceptor())
.qualifiers(datasourceQualifier).done());

embeddingStoreProducer.produce(new EmbeddingStoreBuildItem());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@ class ColumnsTest extends LangChain4jPgVectorBaseTest {
@RegisterExtension
static final QuarkusUnitTest test = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addAsResource(new StringAsset(
"quarkus.langchain4j.pgvector.dimension=384\n" +
"quarkus.langchain4j.pgvector.drop-table-first=true\n" +
"quarkus.class-loading.parent-first-artifacts=ai.djl.huggingface:tokenizers\n" +
"quarkus.log.category.\"io.quarkiverse.langchain4j.pgvector\".level=DEBUG\n\n" +
"quarkus.langchain4j.pgvector.metadata.storage-mode=COLUMN_PER_KEY\n" +
"quarkus.langchain4j.pgvector.metadata.column-definitions=key text NULL, name text NULL, " +
"age float NULL, city varchar null, country varchar null\n" +
"quarkus.langchain4j.pgvector.metadata.indexes=key, name, age"),
.addAsResource(new StringAsset("quarkus.langchain4j.pgvector.dimension=384\n"
+ "quarkus.langchain4j.pgvector.drop-table-first=true\n"
+ "quarkus.class-loading.parent-first-artifacts=ai.djl.huggingface:tokenizers\n"
+ "quarkus.log.category.\"io.quarkiverse.langchain4j.pgvector\".level=DEBUG\n\n"
+ "quarkus.langchain4j.pgvector.metadata.storage-mode=COLUMN_PER_KEY\n"
+ "quarkus.langchain4j.pgvector.metadata.column-definitions=key text NULL, name text NULL, "
+ "age float NULL, city varchar null, country varchar null\n"
+ "quarkus.langchain4j.pgvector.metadata.indexes=key, name, age"),
"application.properties"));

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,15 @@
public class JSONBMultiIndexTest extends LangChain4jPgVectorBaseTest {

@RegisterExtension
static final QuarkusUnitTest test = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addAsResource(new StringAsset(
"quarkus.langchain4j.pgvector.dimension=384\n" +
"quarkus.langchain4j.pgvector.drop-table-first=true\n" +
"quarkus.class-loading.parent-first-artifacts=ai.djl.huggingface:tokenizers\n" +
"quarkus.log.category.\"io.quarkiverse.langchain4j.pgvector\".level=DEBUG\n\n" +
"quarkus.langchain4j.pgvector.metadata.storage-mode=COMBINED_JSONB\n" +
"quarkus.langchain4j.pgvector.metadata.column-definitions=metadata_b JSONB NULL\n" +
"quarkus.langchain4j.pgvector.metadata.indexes=(metadata_b->'key'), (metadata_b->'name'), (metadata_b->'age')\n"
+
"quarkus.langchain4j.pgvector.metadata.index-type=GIN"),
"application.properties"));
static final QuarkusUnitTest test = new QuarkusUnitTest().setArchiveProducer(() -> ShrinkWrap
.create(JavaArchive.class)
.addAsResource(new StringAsset("quarkus.langchain4j.pgvector.dimension=384\n"
+ "quarkus.langchain4j.pgvector.drop-table-first=true\n"
+ "quarkus.class-loading.parent-first-artifacts=ai.djl.huggingface:tokenizers\n"
+ "quarkus.log.category.\"io.quarkiverse.langchain4j.pgvector\".level=DEBUG\n\n"
+ "quarkus.langchain4j.pgvector.metadata.storage-mode=COMBINED_JSONB\n"
+ "quarkus.langchain4j.pgvector.metadata.column-definitions=metadata_b JSONB NULL\n"
+ "quarkus.langchain4j.pgvector.metadata.indexes=(metadata_b->'key'), (metadata_b->'name'), (metadata_b->'age')\n"
+ "quarkus.langchain4j.pgvector.metadata.index-type=GIN"), "application.properties"));

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@
public class JSONBTest extends LangChain4jPgVectorBaseTest {

@RegisterExtension
static final QuarkusUnitTest test = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addAsResource(new StringAsset(
"quarkus.langchain4j.pgvector.dimension=384\n" +
"quarkus.langchain4j.pgvector.drop-table-first=true\n" +
"quarkus.class-loading.parent-first-artifacts=ai.djl.huggingface:tokenizers\n" +
"quarkus.log.category.\"io.quarkiverse.langchain4j.pgvector\".level=DEBUG\n\n" +
"quarkus.langchain4j.pgvector.metadata.storage-mode=COMBINED_JSONB\n" +
"quarkus.langchain4j.pgvector.metadata.column-definitions=metadata JSONB NULL\n" +
"quarkus.langchain4j.pgvector.metadata.indexes=metadata"),
"application.properties"));
static final QuarkusUnitTest test = new QuarkusUnitTest().setArchiveProducer(() -> ShrinkWrap
.create(JavaArchive.class)
.addAsResource(new StringAsset("quarkus.langchain4j.pgvector.dimension=384\n"
+ "quarkus.langchain4j.pgvector.drop-table-first=true\n"
+ "quarkus.class-loading.parent-first-artifacts=ai.djl.huggingface:tokenizers\n"
+ "quarkus.log.category.\"io.quarkiverse.langchain4j.pgvector\".level=DEBUG\n\n"
+ "quarkus.langchain4j.pgvector.metadata.storage-mode=COMBINED_JSONB\n"
+ "quarkus.langchain4j.pgvector.metadata.column-definitions=metadata JSONB NULL\n"
+ "quarkus.langchain4j.pgvector.metadata.indexes=metadata"), "application.properties"));

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@ public class JSONTest extends LangChain4jPgVectorBaseTest {

@RegisterExtension
static final QuarkusUnitTest test = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addAsResource(new StringAsset(
"quarkus.langchain4j.pgvector.dimension=384\n" +
"quarkus.langchain4j.pgvector.drop-table-first=true\n" +
"quarkus.class-loading.parent-first-artifacts=ai.djl.huggingface:tokenizers\n" +
"quarkus.log.category.\"io.quarkiverse.langchain4j.pgvector\".level=DEBUG\n\n"),
"application.properties"));
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addAsResource(
new StringAsset("quarkus.langchain4j.pgvector.dimension=384\n"
+ "quarkus.langchain4j.pgvector.drop-table-first=true\n"
+ "quarkus.class-loading.parent-first-artifacts=ai.djl.huggingface:tokenizers\n"
+ "quarkus.log.category.\"io.quarkiverse.langchain4j.pgvector\".level=DEBUG\n\n"),
"application.properties"));

// Default behavior
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ protected EmbeddingStore<TextSegment> embeddingStore() {
}

/**
* Just for information, not real benchmark.
* JSONTest: Ingesting time 50849 ms. Query average 10 ms
* JSONBTest: Ingesting time 56035 ms. Query average 6 ms.
* Just for information, not real benchmark. JSONTest: Ingesting time 50849 ms.
* Query average 10 ms JSONBTest: Ingesting time 56035 ms. Query average 6 ms.
* JSONBMultiIndexTest: Ingesting time 47344 ms. Query average 6 ms.
* ColumnsTest: Ingesting time 49752 ms. Query average 3 ms.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@ public class PgVectorDataSourceTest {

@RegisterExtension
static final QuarkusUnitTest test = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addAsResource(new StringAsset(
// DevServicesConfigBuilderCustomizer overrides the image-name only
// for the default DS, so in this case we have to override it manually
"quarkus.datasource.embeddings-ds.devservices.image-name=pgvector/pgvector:pg16\n" +
"quarkus.langchain4j.pgvector.datasource=embeddings-ds\n" +
"quarkus.langchain4j.pgvector.dimension=1536\n" +
"quarkus.datasource.embeddings-ds.db-kind=postgresql\n"),
"application.properties"));
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addAsResource(new StringAsset(
// DevServicesConfigBuilderCustomizer overrides the image-name only
// for the default DS, so in this case we have to override it manually
"quarkus.datasource.embeddings-ds.devservices.image-name=pgvector/pgvector:pg16\n"
+ "quarkus.langchain4j.pgvector.datasource=embeddings-ds\n"
+ "quarkus.langchain4j.pgvector.dimension=1536\n"
+ "quarkus.datasource.embeddings-ds.db-kind=postgresql\n"),
"application.properties"));

@io.quarkus.agroal.DataSource("embeddings-ds")
DataSource ds;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.image.DisabledImageModel;
import dev.langchain4j.model.image.ImageModel;
import io.quarkiverse.langchain4j.auth.ModelAuthProvider;
import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiChatModel;
import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiEmbeddingModel;
import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiImageModel;
Expand All @@ -35,6 +36,8 @@
import io.quarkiverse.langchain4j.openai.common.QuarkusOpenAiClient;
import io.quarkiverse.langchain4j.openai.common.runtime.AdditionalPropertiesHack;
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.SyntheticCreationalContext;
import io.quarkus.runtime.ShutdownContext;
import io.quarkus.runtime.annotations.Recorder;
Expand Down Expand Up @@ -295,6 +298,12 @@ private LangChain4jAzureOpenAiConfig.AzureAiConfig correspondingAzureOpenAiConfi

private void throwIfApiKeysNotConfigured(String apiKey, String adToken, String configName) {
if ((apiKey != null) == (adToken != null)) {
ArcContainer container = Arc.container();
if (container != null && container.instance(ModelAuthProvider.class).isAvailable()) {
// Perhaps ModelAuthProvider can provide a method with a default implementation returning a value like `ALL`
// to indicate that it applies to all or a specific model only
return;
}
throw new ConfigValidationException(createKeyMisconfigurationProblem(configName));
}
}
Expand Down
2 changes: 1 addition & 1 deletion samples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
<module>fraud-detection</module>
<module>review-triage</module>
<module>secure-fraud-detection</module>
<module>secure-vertex-ai-gemini-poem</module>
<module>secure-poem</module>
<module>sql-chatbot</module>
</modules>

Expand Down
Loading