Skip to content

Commit

Permalink
Merge pull request #1050 from sberyozkin/generic_model_auth_provider_…
Browse files Browse the repository at this point in the history
…with_named_models

Allow using generic ModelAuthProvider with named models
  • Loading branch information
geoand authored Nov 6, 2024
2 parents 974e59e + 54606f2 commit 237418a
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package io.quarkiverse.langchain4j.test.auth;

import static org.junit.jupiter.api.Assertions.assertTrue;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkiverse.langchain4j.auth.ModelAuthProvider;
import io.quarkus.test.QuarkusUnitTest;

public class AllModelAuthProvidersTest {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(GeminiModelAuthProvider.class, OpenaiModelAuthProvider.class, GlobalModelAuthProvider.class));

@Test
void testThatGlobalModelAuthProviderIsSelectedWithNullModel() {
assertTrue(ModelAuthProvider.resolve(null).get() instanceof GlobalModelAuthProvider);
}

@Test
void testThatOpenAIModelAuthProviderIsSelectedForOpenaiModel() {
assertTrue(ModelAuthProvider.resolve("openai").get() instanceof OpenaiModelAuthProvider);
}

@Test
void testThatGeminiModelAuthProviderIsSelectedForGeminiModel() {
assertTrue(ModelAuthProvider.resolve("gemini").get() instanceof GeminiModelAuthProvider);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package io.quarkiverse.langchain4j.test.auth;

import jakarta.enterprise.context.ApplicationScoped;

import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.auth.ModelAuthProvider;

@ModelName("gemini")
@ApplicationScoped
public class GeminiModelAuthProvider implements ModelAuthProvider {

@Override
public String getAuthorization(Input input) {
return null;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.quarkiverse.langchain4j.test.auth;

import jakarta.enterprise.context.ApplicationScoped;

import io.quarkiverse.langchain4j.auth.ModelAuthProvider;

@ApplicationScoped
public class GlobalModelAuthProvider implements ModelAuthProvider {

@Override
public String getAuthorization(Input input) {
return null;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package io.quarkiverse.langchain4j.test.auth;

import static org.junit.jupiter.api.Assertions.assertTrue;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkiverse.langchain4j.auth.ModelAuthProvider;
import io.quarkus.test.QuarkusUnitTest;

public class GlobalModelAuthProviderTest {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClass(GlobalModelAuthProvider.class));

@Test
void testThatGlobalModelAuthProviderIsSelectedWithNullModel() {
assertTrue(ModelAuthProvider.resolve(null).get() instanceof GlobalModelAuthProvider);
}

@Test
void testThatGlobalOpenAIModelAuthProviderIsSelectedForOpenaiModel() {
assertTrue(ModelAuthProvider.resolve("openai").get() instanceof GlobalModelAuthProvider);
}

@Test
void testThatGlobalModelAuthProviderIsSelectedForGeminiModel() {
assertTrue(ModelAuthProvider.resolve("gemini").get() instanceof GlobalModelAuthProvider);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package io.quarkiverse.langchain4j.test.auth;

import static org.junit.jupiter.api.Assertions.assertTrue;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkiverse.langchain4j.auth.ModelAuthProvider;
import io.quarkus.test.QuarkusUnitTest;

public class NamedModelAuthProvidersTest {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(OpenaiModelAuthProvider.class, GeminiModelAuthProvider.class));

@Test
void testThatNoModelAuthProviderIsSelectedWithNullModel() {
assertTrue(ModelAuthProvider.resolve(null).isEmpty());
}

@Test
void testThatGlobalOpenAIModelAuthProviderIsSelectedForOpenaiModel() {
assertTrue(ModelAuthProvider.resolve("openai").get() instanceof OpenaiModelAuthProvider);
}

@Test
void testThatGlobalModelAuthProviderIsSelectedForGeminiModel() {
assertTrue(ModelAuthProvider.resolve("gemini").get() instanceof GeminiModelAuthProvider);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package io.quarkiverse.langchain4j.test.auth;

import jakarta.enterprise.context.ApplicationScoped;

import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.auth.ModelAuthProvider;

@ModelName("openai")
@ApplicationScoped
public class OpenaiModelAuthProvider implements ModelAuthProvider {

@Override
public String getAuthorization(Input input) {
return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,69 @@

import io.quarkiverse.langchain4j.ModelName;

/**
* Model authentication providers can be used to supply credentials such as access tokens, API keys, and other type of
* credentials.
*
* Providers which support a specific named model only must be annotated with a {@link ModelName} annotation.
*/
public interface ModelAuthProvider {

/**
* Provide authorization data which will be set as an HTTP Authorization header value.
*
* @param input representation of an HTTP request to the model provider.
* @return authorization data which must include an HTTP Authorization scheme value, for example: "Bearer the_access_token".
*/
String getAuthorization(Input input);

/*
* Representation of an HTTP request to the model provider
*/
interface Input {
/*
* HTTP request method, such as POST or GET
*/
String method();

/*
* HTTP request URI
*/
URI uri();

/*
* HTTP request headers
*/
Map<String, List<Object>> headers();
}

/**
* Resolve ModelAuthProvider.
*
* @param modelName the model name. If the model name is not null then a ModelAuthProvider with a matching {@link ModelName}
* annotation are preferred to a global ModelAuthProvider.
* @return Resolved ModelAuthProvider as an Optional value which will be empty if no ModelAuthProvider is available.
*/
static Optional<ModelAuthProvider> resolve(String modelName) {
Instance<ModelAuthProvider> beanInstance = modelName == null
? CDI.current().select(ModelAuthProvider.class)
: CDI.current().select(ModelAuthProvider.class, ModelName.Literal.of(modelName));

//get the first one without causing a bean1 resolution exception
ModelAuthProvider authorizer = null;
for (var handle : beanInstance.handles()) {
authorizer = handle.get();
break;
// If a model is named then try to find ModelAuthProvider matching this model only
if (modelName != null) {
Instance<ModelAuthProvider> beanInstance = CDI.current().select(ModelAuthProvider.class,
ModelName.Literal.of(modelName));

for (var handle : beanInstance.handles()) {
authorizer = handle.get();
break;
}
}
// Find a generic ModelAuthProvider if no model specific ModelAuthProvider is available
if (authorizer == null) {
Instance<ModelAuthProvider> beanInstance = CDI.current().select(ModelAuthProvider.class);
for (var handle : beanInstance.handles()) {
authorizer = handle.get();
break;
}
}
return Optional.ofNullable(authorizer);
}

static Optional<ModelAuthProvider> resolve() {
return resolve(null);
}

}

0 comments on commit 237418a

Please sign in to comment.