From a70d5a39de17236c67b99127482c5e114f839e12 Mon Sep 17 00:00:00 2001 From: Colin McCloskey Date: Wed, 4 Oct 2023 16:02:39 -0400 Subject: [PATCH] Formatting, removing comments, removing now-duplicated system message test --- .../meta/cp4m/llm/HuggingFaceLlamaPlugin.java | 10 +- .../llm/HuggingFaceLlamaPromptBuilder.java | 38 +- .../cp4m/llm/HuggingFaceLlamaPluginTest.java | 518 +++++++++--------- 3 files changed, 264 insertions(+), 302 deletions(-) diff --git a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java index e5a678c..abff7df 100644 --- a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java +++ b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java @@ -14,9 +14,11 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import com.meta.cp4m.message.Message; import com.meta.cp4m.message.ThreadState; + import java.io.IOException; import java.net.URI; import java.time.Instant; + import org.apache.hc.client5.http.fluent.Request; import org.apache.hc.client5.http.fluent.Response; import org.apache.hc.core5.http.ContentType; @@ -29,11 +31,11 @@ public class HuggingFaceLlamaPlugin implements LLMPlugin { public HuggingFaceLlamaPlugin(HuggingFaceConfig config) { this.config = config; - this.endpoint = this.config.endpoint(); + this.endpoint = this.config.endpoint(); } - @Override - public T handle(ThreadState threadState) throws IOException { + @Override + public T handle(ThreadState threadState) throws IOException { T fromUser = threadState.tail(); ObjectNode body = MAPPER.createObjectNode(); @@ -48,7 +50,7 @@ public T handle(ThreadState threadState) throws IOException { HuggingFaceLlamaPromptBuilder promptBuilder = new HuggingFaceLlamaPromptBuilder<>(); String prompt = promptBuilder.createPrompt(threadState, config); - if (prompt.equals("I'm sorry but that request was too long for me.")){ + if (prompt.equals("I'm sorry but that request was too long for me.")) { return threadState.newMessageFromBot( Instant.now(), prompt); } diff --git a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPromptBuilder.java b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPromptBuilder.java index 5bb04f1..974d3ef 100644 --- a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPromptBuilder.java +++ b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPromptBuilder.java @@ -50,10 +50,9 @@ public String createPrompt(ThreadState threadState, HuggingFaceConfig config) LOGGER.error("Failed to initialize Llama2 tokenizer from local file", e); } - if(config.systemMessage().isPresent()){ + if (config.systemMessage().isPresent()) { return "[INST] <>\n" + (config.systemMessage().get()) + "\n<>\n\n" + threadState.messages().get(threadState.messages().size() - 1) + " [/INST] "; - } - else{ + } else { return "[INST] " + threadState.messages().get(threadState.messages().size() - 1) + " [/INST] "; } } @@ -68,52 +67,35 @@ private String pruneMessages(ThreadState threadState, HuggingFaceConfig confi int totalTokens = 5; // Account for closing tokens at end of message StringBuilder promptStringBuilder = new StringBuilder(); - if(config.systemMessage().isPresent()){ + if (config.systemMessage().isPresent()) { String systemPrompt = "[INST] <>\n" + config.systemMessage().get() + "\n<>\n\n"; totalTokens += tokenCount(systemPrompt, tokenizer); promptStringBuilder.append("[INST] <>\n").append(config.systemMessage().get()).append("\n<>\n\n"); - } - else { + } else { totalTokens += 6; promptStringBuilder.append("[INST] "); } -// for (int i = list.size() - 1; i >= 0; i--) -// { -// // access elements by their index (position) -// System.out.println(list.get(i)); -// } - -// Okay so we have a system prompt stringbuilder and then a context stringbuilder and we add those together and -// only if context stringbuilde ris empty do we return the "too long" message - - - - // The first user input is _not_ stripped -// boolean hasUserMessage = false; Message.Role nextMessageSender = Message.Role.ASSISTANT; StringBuilder contextStringBuilder = new StringBuilder(); List messages = threadState.messages(); - for (int i = messages.size() - 1; i >= 0; i--) - { + for (int i = messages.size() - 1; i >= 0; i--) { Message message = messages.get(i); StringBuilder messageText = new StringBuilder(); String text = message.message().strip(); Message.Role user = message.role(); boolean isUser = user == Message.Role.USER; - // access elements by their index (position) messageText.append(text); - if (isUser && nextMessageSender == Message.Role.ASSISTANT){ + if (isUser && nextMessageSender == Message.Role.ASSISTANT) { messageText.append(" [/INST] "); - } - else if (user == Message.Role.ASSISTANT && nextMessageSender == Message.Role.USER){ + } else if (user == Message.Role.ASSISTANT && nextMessageSender == Message.Role.USER) { messageText.append(" [INST] "); } totalTokens += tokenCount(messageText.toString(), tokenizer); - if(totalTokens > config.maxInputTokens()){ - if(contextStringBuilder.isEmpty()){ + if (totalTokens > config.maxInputTokens()) { + if (contextStringBuilder.isEmpty()) { return "I'm sorry but that request was too long for me."; } break; @@ -122,7 +104,7 @@ else if (user == Message.Role.ASSISTANT && nextMessageSender == Message.Role.USE nextMessageSender = user; } - if(nextMessageSender == Message.Role.ASSISTANT){ + if (nextMessageSender == Message.Role.ASSISTANT) { contextStringBuilder.append(" ]TSNI/[ "); // Reversed [/INST] to close instructions for when first message after system prompt is not from user } diff --git a/src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java b/src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java index d0f7d62..28a4bd7 100644 --- a/src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java +++ b/src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java @@ -25,6 +25,7 @@ import com.meta.cp4m.store.ChatStore; import com.meta.cp4m.store.MemoryStoreConfig; import io.javalin.Javalin; + import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; @@ -39,6 +40,7 @@ import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.Stream; + import org.apache.hc.client5.http.fluent.Request; import org.apache.hc.core5.http.HttpResponse; import org.apache.hc.core5.net.URIBuilder; @@ -50,296 +52,272 @@ public class HuggingFaceLlamaPluginTest { - private static final ObjectMapper MAPPER = new ObjectMapper(); - public static final ArrayNode SAMPLE_RESPONSE = MAPPER.createArrayNode(); - private static final String PATH = "/"; - private static final String TEST_MESSAGE = "this is a test message"; - private static final String TEST_SYSTEM_MESSAGE = "this is a system message"; - private static final String TEST_PAYLOAD = "[INST] test message [/INST]"; - private static final String TEST_PAYLOAD_WITH_SYSTEM = - "[INST] <>\nthis is a system message\n<>\n\nthis is a test message [/INST]"; - - private static final ThreadState STACK = - ThreadState.of( - MessageFactory.instance(FBMessage.class) - .newMessage( - Instant.now(), - "test message", - Identifier.random(), - Identifier.random(), - Identifier.random(), - Role.USER)); - - static { - SAMPLE_RESPONSE.addObject().put("generated_text", TEST_MESSAGE); - } - - private BlockingQueue HuggingFaceLlamaRequests; - private Javalin app; - private URI endpoint; - private ObjectNode minimalConfig; - - static Stream modelOptions() { - Set non_model_options = Set.of("name", "type", "api_key", "max_input_tokens"); - return HuggingFaceConfigTest.CONFIG_ITEMS.stream() - .filter(c -> !non_model_options.contains(c.key())); - } - - @BeforeEach - void setUp() throws UnknownHostException, URISyntaxException { - HuggingFaceLlamaRequests = new LinkedBlockingDeque<>(); - app = Javalin.create(); - app.before( - PATH, - ctx -> - HuggingFaceLlamaRequests.add( - new OutboundRequest(ctx.body(), ctx.headerMap(), ctx.queryParamMap()))); - app.post(PATH, ctx -> ctx.result(MAPPER.writeValueAsString(SAMPLE_RESPONSE))); - app.start(0); - endpoint = - URIBuilder.localhost().setScheme("http").appendPath(PATH).setPort(app.port()).build(); - } - - @Test - void sampleValid() throws IOException, InterruptedException { - String apiKey = UUID.randomUUID().toString(); - HuggingFaceConfig config = - HuggingFaceConfig.builder(apiKey).endpoint(endpoint.toString()).tokenLimit(100).build(); - HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); - FBMessage message = plugin.handle(STACK); - assertThat(message.message()).isEqualTo(TEST_MESSAGE); - assertThat(message.role()).isSameAs(Role.ASSISTANT); - assertThatCode(() -> STACK.with(message)).doesNotThrowAnyException(); - @Nullable OutboundRequest or = HuggingFaceLlamaRequests.poll(500, TimeUnit.MILLISECONDS); - assertThat(or).isNotNull(); - assertThat(or.headerMap().get("Authorization")).isNotNull().isEqualTo("Bearer " + apiKey); - } - - @Test - void createPayload() { - String apiKey = UUID.randomUUID().toString(); - HuggingFaceConfig config = - HuggingFaceConfig.builder(apiKey).endpoint(endpoint.toString()).tokenLimit(100).build(); - HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); - HuggingFaceLlamaPromptBuilder promptBuilder = new HuggingFaceLlamaPromptBuilder<>(); - String createdPayload = promptBuilder.createPrompt(STACK, config); - assertThat(createdPayload).isEqualTo(TEST_PAYLOAD); - } - - @Test - void createPayloadWithSystemMessage() { - String apiKey = UUID.randomUUID().toString(); - HuggingFaceConfig config = - HuggingFaceConfig.builder(apiKey).endpoint(endpoint.toString()).tokenLimit(100).systemMessage(TEST_SYSTEM_MESSAGE).build(); - HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); - ThreadState stack = - ThreadState.of( - MessageFactory.instance(FBMessage.class) - .newMessage( - Instant.now(), - TEST_MESSAGE, - Identifier.random(), - Identifier.random(), - Identifier.random(), - Role.USER)); - HuggingFaceLlamaPromptBuilder promptBuilder = new HuggingFaceLlamaPromptBuilder<>(); - String createdPayload = promptBuilder.createPrompt(stack, config); - assertThat(createdPayload).isEqualTo(TEST_PAYLOAD_WITH_SYSTEM); - } + private static final ObjectMapper MAPPER = new ObjectMapper(); + public static final ArrayNode SAMPLE_RESPONSE = MAPPER.createArrayNode(); + private static final String PATH = "/"; + private static final String TEST_MESSAGE = "this is a test message"; + private static final String TEST_SYSTEM_MESSAGE = "this is a system message"; + private static final String TEST_PAYLOAD = "[INST] test message [/INST]"; + private static final String TEST_PAYLOAD_WITH_SYSTEM = + "[INST] <>\nthis is a system message\n<>\n\nthis is a test message [/INST]"; - @Test - void createPayloadWithConfigSystemMessage() { - String apiKey = UUID.randomUUID().toString(); - HuggingFaceConfig config = - HuggingFaceConfig.builder(apiKey) - .endpoint(endpoint.toString()) - .tokenLimit(100) - .systemMessage(TEST_SYSTEM_MESSAGE) - .build(); - HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); - ThreadState stack = - ThreadState.of( - MessageFactory.instance(FBMessage.class) - .newMessage( - Instant.now(), - TEST_MESSAGE, - Identifier.random(), - Identifier.random(), - Identifier.random(), - Role.USER)); - HuggingFaceLlamaPromptBuilder promptBuilder = new HuggingFaceLlamaPromptBuilder<>(); - String createdPayload = promptBuilder.createPrompt(stack, config); - assertThat(createdPayload).isEqualTo(TEST_PAYLOAD_WITH_SYSTEM); - } - - @Test - void contextTooBig() throws IOException { - String apiKey = UUID.randomUUID().toString(); - HuggingFaceConfig config = - HuggingFaceConfig.builder(apiKey) - .endpoint(endpoint.toString()) - .tokenLimit(200) - .maxInputTokens(100) - .systemMessage(TEST_SYSTEM_MESSAGE) - .build(); - HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); - ThreadState thread = + private static final ThreadState STACK = ThreadState.of( MessageFactory.instance(FBMessage.class) .newMessage( Instant.now(), - Stream.generate(() -> "0123456789").limit(100).collect(Collectors.joining()), + "test message", Identifier.random(), Identifier.random(), Identifier.random(), Role.USER)); - FBMessage response = plugin.handle(thread); - assertThat(response.message()).isEqualTo("I'm sorry but that request was too long for me."); - } - @Test - void truncatesContext() throws IOException { - String apiKey = UUID.randomUUID().toString(); - HuggingFaceConfig config = - HuggingFaceConfig.builder(apiKey) - .endpoint(endpoint.toString()) - .tokenLimit(200) - .maxInputTokens(100) - .build(); - HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); - ThreadState thread = - ThreadState.of( - MessageFactory.instance(FBMessage.class) - .newMessage( - Instant.now(), - Stream.generate(() -> "0123456789").limit(100).collect(Collectors.joining()), - Identifier.random(), - Identifier.random(), - Identifier.random(), - Role.USER)); - thread = thread.with(thread.newMessageFromUser(Instant.now(), "test message", Identifier.from(2))); - HuggingFaceLlamaPromptBuilder promptBuilder = new HuggingFaceLlamaPromptBuilder<>(); - String createdPayload = promptBuilder.createPrompt(thread, config); - assertThat(createdPayload).isEqualTo(TEST_PAYLOAD); - } + static { + SAMPLE_RESPONSE.addObject().put("generated_text", TEST_MESSAGE); + } - @BeforeEach - void setUpMinConfig() { - minimalConfig = MAPPER.createObjectNode(); - HuggingFaceConfigTest.CONFIG_ITEMS.forEach( - t -> { - if (t.required()) { - minimalConfig.set(t.key(), t.validValue()); - } - }); - } + private BlockingQueue HuggingFaceLlamaRequests; + private Javalin app; + private URI endpoint; + private ObjectNode minimalConfig; - @ParameterizedTest - @MethodSource("modelOptions") - void validConfigValues(HuggingFaceConfigTest.ConfigItem configItem) - throws IOException, InterruptedException { - minimalConfig.set(configItem.key(), configItem.validValue()); - minimalConfig.put("endpoint", endpoint.toString()); // needs the correct endpoint to run - HuggingFaceConfig config = - ConfigurationUtils.jsonMapper().convertValue(minimalConfig, HuggingFaceConfig.class); - HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); - FBMessage message = plugin.handle(STACK); - assertThat(message.message()).isEqualTo(TEST_MESSAGE); - assertThat(message.role()).isSameAs(Role.ASSISTANT); - assertThatCode(() -> STACK.with(message)).doesNotThrowAnyException(); - @Nullable OutboundRequest or = HuggingFaceLlamaRequests.poll(500, TimeUnit.MILLISECONDS); - assertThat(or).isNotNull(); - System.out.println(or); - assertThat(or.headerMap().get("Authorization")) - .isNotNull() - .isEqualTo("Bearer " + config.apiKey()); - } + static Stream modelOptions() { + Set non_model_options = Set.of("name", "type", "api_key", "max_input_tokens"); + return HuggingFaceConfigTest.CONFIG_ITEMS.stream() + .filter(c -> !non_model_options.contains(c.key())); + } - @Test - void orderedCorrectly() throws IOException, InterruptedException { - HuggingFaceConfig config = - HuggingFaceConfig.builder("lkjasdlkjasdf") - .maxInputTokens(100) - .tokenLimit(200) - .endpoint(endpoint.toString()) - .build(); - HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); - ThreadState stack = - ThreadState.of( - MessageFactory.instance(FBMessage.class) - .newMessage( - Instant.now(), - "1", - Identifier.random(), - Identifier.random(), - Identifier.random(), - Role.SYSTEM)); - stack = stack.with(stack.newMessageFromUser(Instant.now(), "2", Identifier.from(2))); - stack = stack.with(stack.newMessageFromUser(Instant.now(), "3", Identifier.from(3))); - stack = stack.with(stack.newMessageFromUser(Instant.now(), "4", Identifier.from(4))); - plugin.handle(stack); - @Nullable OutboundRequest or = HuggingFaceLlamaRequests.poll(500, TimeUnit.MILLISECONDS); - assertThat(or).isNotNull(); - JsonNode body = MAPPER.readTree(or.body()); + @BeforeEach + void setUp() throws UnknownHostException, URISyntaxException { + HuggingFaceLlamaRequests = new LinkedBlockingDeque<>(); + app = Javalin.create(); + app.before( + PATH, + ctx -> + HuggingFaceLlamaRequests.add( + new OutboundRequest(ctx.body(), ctx.headerMap(), ctx.queryParamMap()))); + app.post(PATH, ctx -> ctx.result(MAPPER.writeValueAsString(SAMPLE_RESPONSE))); + app.start(0); + endpoint = + URIBuilder.localhost().setScheme("http").appendPath(PATH).setPort(app.port()).build(); + } - int prevMessageIndex = 0; - for (int i = 0; i < stack.messages().size(); i++) { - FBMessage stackMessage = stack.messages().get(i); - String sentMessage = body.get("inputs").textValue(); - int index = sentMessage.indexOf(stackMessage.message()); - int finalPrevMessageIndex = prevMessageIndex; - assertSoftly(s -> s.assertThat(index).isGreaterThan(finalPrevMessageIndex)); - prevMessageIndex = index; + @Test + void sampleValid() throws IOException, InterruptedException { + String apiKey = UUID.randomUUID().toString(); + HuggingFaceConfig config = + HuggingFaceConfig.builder(apiKey).endpoint(endpoint.toString()).tokenLimit(100).build(); + HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); + FBMessage message = plugin.handle(STACK); + assertThat(message.message()).isEqualTo(TEST_MESSAGE); + assertThat(message.role()).isSameAs(Role.ASSISTANT); + assertThatCode(() -> STACK.with(message)).doesNotThrowAnyException(); + @Nullable OutboundRequest or = HuggingFaceLlamaRequests.poll(500, TimeUnit.MILLISECONDS); + assertThat(or).isNotNull(); + assertThat(or.headerMap().get("Authorization")).isNotNull().isEqualTo("Bearer " + apiKey); } - } - @Test - void inPipeline() throws IOException, URISyntaxException, InterruptedException { - ChatStore store = MemoryStoreConfig.of(1, 1).toStore(); - String appSecret = "app secret"; - String accessToken = "access token"; - String verifyToken = "verify token"; + @Test + void createPayload() { + String apiKey = UUID.randomUUID().toString(); + HuggingFaceConfig config = + HuggingFaceConfig.builder(apiKey).endpoint(endpoint.toString()).tokenLimit(100).build(); + HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); + HuggingFaceLlamaPromptBuilder promptBuilder = new HuggingFaceLlamaPromptBuilder<>(); + String createdPayload = promptBuilder.createPrompt(STACK, config); + assertThat(createdPayload).isEqualTo(TEST_PAYLOAD); + } + + @Test + void createPayloadWithSystemMessage() { + String apiKey = UUID.randomUUID().toString(); + HuggingFaceConfig config = + HuggingFaceConfig.builder(apiKey).endpoint(endpoint.toString()).tokenLimit(100).systemMessage(TEST_SYSTEM_MESSAGE).build(); + HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); + ThreadState stack = + ThreadState.of( + MessageFactory.instance(FBMessage.class) + .newMessage( + Instant.now(), + TEST_MESSAGE, + Identifier.random(), + Identifier.random(), + Identifier.random(), + Role.USER)); + HuggingFaceLlamaPromptBuilder promptBuilder = new HuggingFaceLlamaPromptBuilder<>(); + String createdPayload = promptBuilder.createPrompt(stack, config); + assertThat(createdPayload).isEqualTo(TEST_PAYLOAD_WITH_SYSTEM); + } - BlockingQueue metaRequests = new LinkedBlockingDeque<>(); - String metaPath = "/meta"; - URI messageReceiver = - URIBuilder.localhost().appendPath(metaPath).setScheme("http").setPort(app.port()).build(); - app.post( - metaPath, - ctx -> - metaRequests.put( - new OutboundRequest(ctx.body(), ctx.headerMap(), ctx.queryParamMap()))); - FBMessageHandler handler = - new FBMessageHandler(verifyToken, accessToken, appSecret) - .baseURLFactory(ignored -> messageReceiver); + @Test + void contextTooBig() throws IOException { + String apiKey = UUID.randomUUID().toString(); + HuggingFaceConfig config = + HuggingFaceConfig.builder(apiKey) + .endpoint(endpoint.toString()) + .tokenLimit(200) + .maxInputTokens(100) + .systemMessage(TEST_SYSTEM_MESSAGE) + .build(); + HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); + ThreadState thread = + ThreadState.of( + MessageFactory.instance(FBMessage.class) + .newMessage( + Instant.now(), + Stream.generate(() -> "0123456789").limit(100).collect(Collectors.joining()), + Identifier.random(), + Identifier.random(), + Identifier.random(), + Role.USER)); + FBMessage response = plugin.handle(thread); + assertThat(response.message()).isEqualTo("I'm sorry but that request was too long for me."); + } - String apiKey = "api key"; - HuggingFaceConfig config = - HuggingFaceConfig.builder(apiKey).endpoint(endpoint.toString()).tokenLimit(1000).build(); - HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); + @Test + void truncatesContext() throws IOException { + String apiKey = UUID.randomUUID().toString(); + HuggingFaceConfig config = + HuggingFaceConfig.builder(apiKey) + .endpoint(endpoint.toString()) + .tokenLimit(200) + .maxInputTokens(100) + .build(); + HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); + ThreadState thread = + ThreadState.of( + MessageFactory.instance(FBMessage.class) + .newMessage( + Instant.now(), + Stream.generate(() -> "0123456789").limit(100).collect(Collectors.joining()), + Identifier.random(), + Identifier.random(), + Identifier.random(), + Role.USER)); + thread = thread.with(thread.newMessageFromUser(Instant.now(), "test message", Identifier.from(2))); + HuggingFaceLlamaPromptBuilder promptBuilder = new HuggingFaceLlamaPromptBuilder<>(); + String createdPayload = promptBuilder.createPrompt(thread, config); + assertThat(createdPayload).isEqualTo(TEST_PAYLOAD); + } - String webhookPath = "/webhook"; - Service service = new Service<>(store, handler, plugin, webhookPath); - ServicesRunner runner = ServicesRunner.newInstance().service(service).port(0); - runner.start(); + @BeforeEach + void setUpMinConfig() { + minimalConfig = MAPPER.createObjectNode(); + HuggingFaceConfigTest.CONFIG_ITEMS.forEach( + t -> { + if (t.required()) { + minimalConfig.set(t.key(), t.validValue()); + } + }); + } - // TODO: create test harness - Request request = - FBMessageHandlerTest.createMessageRequest(FBMessageHandlerTest.SAMPLE_MESSAGE, runner); - HttpResponse response = request.execute().returnResponse(); - assertThat(response.getCode()).isEqualTo(200); - @Nullable OutboundRequest or = HuggingFaceLlamaRequests.poll(500, TimeUnit.MILLISECONDS); - assertThat(or).isNotNull(); - assertThat(or.headerMap().get("Authorization")) - .isNotNull() - .isEqualTo("Bearer " + config.apiKey()); - JsonNode body = ConfigurationUtils.jsonMapper().readTree(or.body()); + @ParameterizedTest + @MethodSource("modelOptions") + void validConfigValues(HuggingFaceConfigTest.ConfigItem configItem) + throws IOException, InterruptedException { + minimalConfig.set(configItem.key(), configItem.validValue()); + minimalConfig.put("endpoint", endpoint.toString()); // needs the correct endpoint to run + HuggingFaceConfig config = + ConfigurationUtils.jsonMapper().convertValue(minimalConfig, HuggingFaceConfig.class); + HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); + FBMessage message = plugin.handle(STACK); + assertThat(message.message()).isEqualTo(TEST_MESSAGE); + assertThat(message.role()).isSameAs(Role.ASSISTANT); + assertThatCode(() -> STACK.with(message)).doesNotThrowAnyException(); + @Nullable OutboundRequest or = HuggingFaceLlamaRequests.poll(500, TimeUnit.MILLISECONDS); + assertThat(or).isNotNull(); + System.out.println(or); + assertThat(or.headerMap().get("Authorization")) + .isNotNull() + .isEqualTo("Bearer " + config.apiKey()); + } - or = metaRequests.poll(500, TimeUnit.MILLISECONDS); - // plugin output got back to meta - assertThat(or).isNotNull().satisfies(r -> assertThat(r.body()).contains(TEST_MESSAGE)); - } + @Test + void orderedCorrectly() throws IOException, InterruptedException { + HuggingFaceConfig config = + HuggingFaceConfig.builder("lkjasdlkjasdf") + .maxInputTokens(100) + .tokenLimit(200) + .endpoint(endpoint.toString()) + .build(); + HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); + ThreadState stack = + ThreadState.of( + MessageFactory.instance(FBMessage.class) + .newMessage( + Instant.now(), + "1", + Identifier.random(), + Identifier.random(), + Identifier.random(), + Role.SYSTEM)); + stack = stack.with(stack.newMessageFromUser(Instant.now(), "2", Identifier.from(2))); + stack = stack.with(stack.newMessageFromUser(Instant.now(), "3", Identifier.from(3))); + stack = stack.with(stack.newMessageFromUser(Instant.now(), "4", Identifier.from(4))); + plugin.handle(stack); + @Nullable OutboundRequest or = HuggingFaceLlamaRequests.poll(500, TimeUnit.MILLISECONDS); + assertThat(or).isNotNull(); + JsonNode body = MAPPER.readTree(or.body()); - private record OutboundRequest( - String body, Map headerMap, Map> queryParamMap) {} + int prevMessageIndex = 0; + for (int i = 0; i < stack.messages().size(); i++) { + FBMessage stackMessage = stack.messages().get(i); + String sentMessage = body.get("inputs").textValue(); + int index = sentMessage.indexOf(stackMessage.message()); + int finalPrevMessageIndex = prevMessageIndex; + assertSoftly(s -> s.assertThat(index).isGreaterThan(finalPrevMessageIndex)); + prevMessageIndex = index; + } + } + + @Test + void inPipeline() throws IOException, URISyntaxException, InterruptedException { + ChatStore store = MemoryStoreConfig.of(1, 1).toStore(); + String appSecret = "app secret"; + String accessToken = "access token"; + String verifyToken = "verify token"; + + BlockingQueue metaRequests = new LinkedBlockingDeque<>(); + String metaPath = "/meta"; + URI messageReceiver = + URIBuilder.localhost().appendPath(metaPath).setScheme("http").setPort(app.port()).build(); + app.post( + metaPath, + ctx -> + metaRequests.put( + new OutboundRequest(ctx.body(), ctx.headerMap(), ctx.queryParamMap()))); + FBMessageHandler handler = + new FBMessageHandler(verifyToken, accessToken, appSecret) + .baseURLFactory(ignored -> messageReceiver); + + String apiKey = "api key"; + HuggingFaceConfig config = + HuggingFaceConfig.builder(apiKey).endpoint(endpoint.toString()).tokenLimit(1000).build(); + HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); + + String webhookPath = "/webhook"; + Service service = new Service<>(store, handler, plugin, webhookPath); + ServicesRunner runner = ServicesRunner.newInstance().service(service).port(0); + runner.start(); + + // TODO: create test harness + Request request = + FBMessageHandlerTest.createMessageRequest(FBMessageHandlerTest.SAMPLE_MESSAGE, runner); + HttpResponse response = request.execute().returnResponse(); + assertThat(response.getCode()).isEqualTo(200); + @Nullable OutboundRequest or = HuggingFaceLlamaRequests.poll(500, TimeUnit.MILLISECONDS); + assertThat(or).isNotNull(); + assertThat(or.headerMap().get("Authorization")) + .isNotNull() + .isEqualTo("Bearer " + config.apiKey()); + JsonNode body = ConfigurationUtils.jsonMapper().readTree(or.body()); + + or = metaRequests.poll(500, TimeUnit.MILLISECONDS); + // plugin output got back to meta + assertThat(or).isNotNull().satisfies(r -> assertThat(r.body()).contains(TEST_MESSAGE)); + } + + private record OutboundRequest( + String body, Map headerMap, Map> queryParamMap) { + } }