From 7591433e712db9064cb02da3f0f4cab2f6baac03 Mon Sep 17 00:00:00 2001 From: Clement Escoffier Date: Wed, 18 Dec 2024 10:42:32 +0100 Subject: [PATCH] Add JUnit 5 `ScorerExtension` for AI model evaluation and associated library - Implement `ScorerExtension` to inject and manage Scorer instances in tests. - Support field and parameter injection for Scorer using `@ScorerConfiguration`. - Add support for parameter injection of samples via `@SampleLocation` annotation. - Provide built-in evaluation strategies: - `SemanticSimilarityStrategy` (cosine similarity-based evaluation). - `AiJudgeStrategy` (AI-powered evaluation with customizable prompts). - Add tests for ScorerExtension: - Validate field and parameter injection of Scorer. - Test sample injection from YAML files. - Verify evaluation strategies and reporting. - Document ScorerExtension: - Explain concepts: Scorer, Samples, Evaluation Strategies, Reports. - Usage examples for field/parameter injection and evaluation. - Guide for using built-in strategies and creating custom strategies. 
--- docs/modules/ROOT/nav.adoc | 1 + docs/modules/ROOT/pages/testing.adoc | 634 ++++++++++++++++++ pom.xml | 1 + testing/pom.xml | 23 + testing/scorer/pom.xml | 26 + testing/scorer/scorer-core/pom.xml | 64 ++ .../testing/scorer/EvaluationReport.java | 80 +++ .../testing/scorer/EvaluationSample.java | 150 +++++ .../testing/scorer/EvaluationStrategy.java | 19 + .../langchain4j/testing/scorer/Parameter.java | 78 +++ .../testing/scorer/Parameters.java | 158 +++++ .../langchain4j/testing/scorer/Samples.java | 62 ++ .../langchain4j/testing/scorer/Scorer.java | 80 +++ .../testing/scorer/YamlLoader.java | 83 +++ .../testing/scorer/EvaluationReportTest.java | 82 +++ .../testing/scorer/EvaluationSampleTest.java | 147 ++++ .../testing/scorer/ParameterTest.java | 110 +++ .../testing/scorer/ParametersTest.java | 129 ++++ .../testing/scorer/SamplesTest.java | 121 ++++ .../testing/scorer/ScorerTest.java | 105 +++ .../testing/scorer/YamlLoaderTest.java | 149 ++++ testing/scorer/scorer-junit5/pom.xml | 53 ++ .../langchain4j/scorer/junit5/AiScorer.java | 18 + .../scorer/junit5/SampleLocation.java | 22 + .../scorer/junit5/ScorerConfiguration.java | 23 + .../scorer/junit5/ScorerExtension.java | 85 +++ .../junit5/test/ScorerExtensionTest.java | 50 ++ .../src/test/resources/test-samples.yaml | 14 + .../scorer/scorer-strategies/ai-judge/pom.xml | 32 + .../testing/scorer/judge/AiJudgeStrategy.java | 61 ++ .../semantic-similarity/pom.xml | 32 + .../SemanticSimilarityStrategy.java | 88 +++ 32 files changed, 2780 insertions(+) create mode 100644 docs/modules/ROOT/pages/testing.adoc create mode 100644 testing/pom.xml create mode 100644 testing/scorer/pom.xml create mode 100644 testing/scorer/scorer-core/pom.xml create mode 100644 testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReport.java create mode 100644 testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationSample.java create mode 100644 
testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationStrategy.java create mode 100644 testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Parameter.java create mode 100644 testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Parameters.java create mode 100644 testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Samples.java create mode 100644 testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java create mode 100644 testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/YamlLoader.java create mode 100644 testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReportTest.java create mode 100644 testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationSampleTest.java create mode 100644 testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ParameterTest.java create mode 100644 testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ParametersTest.java create mode 100644 testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/SamplesTest.java create mode 100644 testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java create mode 100644 testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/YamlLoaderTest.java create mode 100644 testing/scorer/scorer-junit5/pom.xml create mode 100644 testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/AiScorer.java create mode 100644 testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/SampleLocation.java create mode 100644 testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerConfiguration.java create mode 100644 
testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerExtension.java create mode 100644 testing/scorer/scorer-junit5/src/test/java/io/quarkiverse/langchain4j/scorer/junit5/test/ScorerExtensionTest.java create mode 100644 testing/scorer/scorer-junit5/src/test/resources/test-samples.yaml create mode 100644 testing/scorer/scorer-strategies/ai-judge/pom.xml create mode 100644 testing/scorer/scorer-strategies/ai-judge/src/main/java/io/quarkiverse/langchain4j/testing/scorer/judge/AiJudgeStrategy.java create mode 100644 testing/scorer/scorer-strategies/semantic-similarity/pom.xml create mode 100644 testing/scorer/scorer-strategies/semantic-similarity/src/main/java/io/quarkiverse/langchain4j/testing/scorer/similarity/SemanticSimilarityStrategy.java diff --git a/docs/modules/ROOT/nav.adoc b/docs/modules/ROOT/nav.adoc index e0b12a371..27ba4133d 100644 --- a/docs/modules/ROOT/nav.adoc +++ b/docs/modules/ROOT/nav.adoc @@ -8,6 +8,7 @@ ** xref:prompt-generation.adoc[Prompt Generation] ** xref:guardrails.adoc[Guardrails] ** xref:response-augmenter.adoc[Response Augmenter] +** xref:testing.adoc[Testing] * LLMs ** xref:llms.adoc[LLMs] diff --git a/docs/modules/ROOT/pages/testing.adoc b/docs/modules/ROOT/pages/testing.adoc new file mode 100644 index 000000000..3ac0cb453 --- /dev/null +++ b/docs/modules/ROOT/pages/testing.adoc @@ -0,0 +1,634 @@ += Testing AI-Infused Applications + +The `quarkus-langchain4j-testing-scorer-junit5` extension provides a pragmatic and extensible testing framework for evaluating AI-infused applications. +It integrates with JUnit 5 and offers tools for automating evaluation processes, scoring outputs, and generating evaluation reports using customizable evaluation strategies. 
+ +== Maven Dependency + +To use the `ScorerExtension`, include the following Maven dependency in your pom.xml: + +[source, xml] +---- +<dependency> + <groupId>io.quarkiverse.langchain4j</groupId> + <artifactId>quarkus-langchain4j-testing-scorer-junit5</artifactId> + <scope>test</scope> +</dependency> +---- + +== Using the extension + +To use the extension, annotate your test class with `@ExtendWith(ScorerExtension.class)` or `@AiScorer`: + +[source,java] +---- +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import io.quarkiverse.langchain4j.scorer.junit5.ScorerExtension; + +@ExtendWith(ScorerExtension.class) +public class MyScorerTests { + + // Test cases go here +} +---- + +Or, you can use the `@AiScorer` annotation: + +[source,java] +---- +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import io.quarkiverse.langchain4j.scorer.junit5.AiScorer; + +@AiScorer +public class MyScorerTests { + + // Test cases go here +} +---- + +This JUnit 5 extension can be combined with `@QuarkusTest` to test Quarkus applications: + +[source,java] +---- +import io.quarkus.test.junit.QuarkusTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import io.quarkiverse.langchain4j.scorer.junit5.AiScorer; + +@QuarkusTest +@AiScorer +public class MyScorerTests { + + // Test cases go here +} +---- + +== Concepts + +=== Scorer + +The `Scorer` (`io.quarkiverse.langchain4j.testing.scorer.Scorer`) is a utility that evaluates a _set of samples_ (represented by `io.quarkiverse.langchain4j.testing.scorer.Samples`) against a function (part of the application) and a set of evaluation strategies. +It can run evaluations concurrently and produces an `EvaluationReport` summarizing the results and providing the _score_. + +The _score_ is the percentage of passed evaluations (between 0.0 and 100.0). +It is calculated as the ratio of the number of passed evaluations to the total number of evaluations. 
+ +In general, tests using the `Scorer` follow this pattern: + +[source,java] +---- + +@Inject CustomerSupportAssistant assistant; // The AI Service to test + + @Test +void testAiService(@ScorerConfiguration(concurrency = 5) Scorer scorer, // The scorer instance, with concurrency set to 5 + @SampleLocation("src/test/resources/samples.yaml") Samples<String> samples) { // The samples loaded from a YAML file + + // Define the function that will be evaluated + // The parameters come from the sample + // The output of this function will be compared to the expected output in the samples + Function<Parameters, String> function = parameters -> { + return assistant.chat(parameters.get(0)); + }; + + EvaluationReport report = scorer.evaluate(samples, function, + new SemanticSimilarityStrategy(0.8)); // The evaluation strategy + assertThat(report.score()).isGreaterThanOrEqualTo(70); // Assert the score +} +---- + +=== Samples + +A `Sample` (`io.quarkiverse.langchain4j.testing.scorer.EvaluationSample`) represents a single input-output test case. +It includes: +- a name: the name of the sample, +- the parameters: the parameter data for the test, +- the expected output: the expected result that will be evaluated, +- the tags: metadata that can categorize the sample for targeted evaluation (tags are optional). + +When tags are set, the score can be calculated per tag (in addition to the global score). + +A list of samples is represented by `Samples` (`io.quarkiverse.langchain4j.testing.scorer.Samples`). 
+ +Samples can be defined using a builder pattern: + +[source, java] +---- +var s1 = EvaluationSample.<String>builder() + .withName("sample1") + .withParameter("value1") + .withExpectedOutput("my expected result2") + .build(); + + var s2 = EvaluationSample.<String>builder() + .withName("sample2") + .withParameter("value2") + .withExpectedOutput("my expected results") + .build(); + + Samples<String> samples = new Samples<>(List.of(s1, s2)); +---- + +Alternatively, samples can be loaded from a YAML file using the `@SampleLocation` annotation: + +[source, yaml] +---- +- name: Sample1 + parameters: + - "parameter1" + expectedOutput: "expected1" + tags: ["tag1"] +- name: Sample2 + parameters: + - "parameter2" + expectedOutput: "expected2" + tags: ["tag1"] +---- + +=== Evaluation Strategy + +An `EvaluationStrategy` (`io.quarkiverse.langchain4j.testing.scorer.EvaluationStrategy`) defines how to evaluate a sample. +The framework includes ready-to-use strategies (detailed below), and you can implement custom ones. + +[source, java] +---- +/** + * A strategy to evaluate the output of a model. + * @param <T> the type of the output. + */ +public interface EvaluationStrategy<T> { + + /** + * Evaluate the output of a model. + * @param sample the sample to evaluate. + * @param output the output of the model. + * @return {@code true} if the output is correct, {@code false} otherwise. + */ + boolean evaluate(EvaluationSample<T> sample, T output); + +} +---- + +=== Evaluation Report + +The `EvaluationReport` aggregates the results of all evaluations. It provides: + +- a global score (percentage of passed evaluations). +- the scores per tag. +- the possibility to dump the report as Markdown. 
+ +== Writing Tests with Scorer + +=== Example Test Using Field Injection + +[source, java] +---- +@ExtendWith(ScorerExtension.class) +public class ScorerFieldInjectionTest { + + @ScorerConfiguration(concurrency = 4) + private Scorer scorer; + + @Test + void evaluateSamples() { + // Define test samples + Samples samples = new Samples<>( + EvaluationSample.builder().withName("Sample1").withParameter("p1").withExpectedOutput("expected1").build(), + EvaluationSample.builder().withName("Sample2").withParameter("p2").withExpectedOutput("expected2").build() + ); + + // Define evaluation strategies + EvaluationStrategy strategy = new SemanticSimilarityStrategy(0.85); + + // Evaluate samples + EvaluationReport report = scorer.evaluate(samples, parameters -> { + // Replace with your function under test + return "actualOutput"; + }, strategy); + + // Assert results + assertThat(report.score()).isGreaterThan(50.0); + } +} +---- + +=== Example Test Using Parameter Injection + +[source, java] +---- + +@ExtendWith(ScorerExtension.class) +public class ScorerParameterInjectionTest { + + // .... + + @Test + void evaluateWithInjectedScorer( + @ScorerConfiguration(concurrency = 2) Scorer scorer, + @SampleLocation("test-samples.yaml") Samples samples + ) { + // Use an evaluation strategy + EvaluationStrategy strategy = new AiJudgeStrategy(myChatLanguageModel); + + // Evaluate samples + EvaluationReport report = scorer.evaluate(samples, parameters -> { + // Replace with your function under test + return "actualOutput"; + }, strategy); + + // Assert results + assertThat(report.evaluations()).isNotEmpty(); + assertThat(report.score()).isGreaterThan(50.0); + } +} +---- + +== Built-in Evaluation Strategies + +=== Semantic Similarity + +The `SemanticSimilarityStrategy` (`io.quarkiverse.langchain4j.testing.scorer.similarity.SemanticSimilarityStrategy`) evaluates the similarity between the actual output and the expected output using cosine similarity. 
It requires an embedding model and a minimum similarity threshold. + +**Maven Dependency:** + +[source, xml] +---- +<dependency> + <groupId>io.quarkiverse.langchain4j</groupId> + <artifactId>quarkus-langchain4j-testing-scorer-semantic-similarity</artifactId> + <scope>test</scope> +</dependency> +---- + +**Examples:** + +[source, java] +---- +EvaluationStrategy strategy = new SemanticSimilarityStrategy(0.9); +EvaluationStrategy strategy2 = new SemanticSimilarityStrategy(embeddingModel, 0.85); +---- + +=== AI Judge + +The `AiJudgeStrategy` leverages an AI model to determine if the actual output matches the expected output. +It uses a configurable evaluation prompt and `ChatModel`. + +**Maven Dependency:** + +[source, xml] +---- +<dependency> + <groupId>io.quarkiverse.langchain4j</groupId> + <artifactId>quarkus-langchain4j-testing-scorer-ai-judge</artifactId> + <scope>test</scope> +</dependency> +---- + +**Example:** + +[source, java] +---- +EvaluationStrategy strategy = new AiJudgeStrategy(myChatLanguageModel, """ + You are an AI evaluating a response and the expected output. + You need to evaluate whether the model response is correct or not. + Return true if the response is correct, false otherwise. 
+ + Response to evaluate: {response} + Expected output: {expected_output} + + """); +---- + +== Creating a Custom Evaluation Strategy + +To implement your own evaluation strategy, implement the `EvaluationStrategy` interface: + +[source, java] +---- +import io.quarkiverse.langchain4j.testing.scorer.*; + +public class MyCustomStrategy implements EvaluationStrategy<String> { + + @Override + public boolean evaluate(EvaluationSample<String> sample, String output) { + // Custom evaluation logic + return output.equalsIgnoreCase(sample.expectedOutput()); + } +} +---- + +Then, use the custom strategy in your test: + +[source, java] +---- +EvaluationStrategy<String> strategy = new MyCustomStrategy(); +EvaluationReport report = scorer.evaluate(samples, parameters -> { + return "actualOutput"; +}, strategy); +---- + +Here is an example of a custom strategy that can be used to verify the correctness of a vector search: + +[source, java] +---- +public class TextSegmentEvaluationStrategy implements EvaluationStrategy<List<String>> { + + @Override + public boolean evaluate(EvaluationSample<List<String>> sample, List<String> response) { + List<String> expected = sample.expectedOutput(); + int found = 0; + for (String seg : expected) { + // Make sure that the response contains the expected segment + boolean segFound = false; + for (String s : response) { + if (s.toLowerCase().contains(seg.toLowerCase())) { + segFound = true; + found++; + break; + } + } + if (!segFound) { + System.out.println("Segment not found: " + seg); + } + } + return found == expected.size(); + } + + } +---- + +== Injecting Samples + +You can load samples directly from a YAML file using the `@SampleLocation` annotation: + +[source, yaml] +---- +- name: Sample1 + parameters: + - "value1" + expectedOutput: "expected1" + tags: ["tag1"] +- name: Sample2 + parameters: + - "value2" + expectedOutput: "expected2" + tags: ["tag2"] +---- + +Then, inject the samples into your test method: + +[source, java] +---- +@Test +void evaluateWithSamples(@SampleLocation("test-samples.yaml") 
Samples samples) { + // Use samples in your test +} +---- + +== Example of tests using Quarkus + +Let's imagine an _AI Service_ used by a Chatbot to generate responses. +Let's also imagine that this _AI Service_ has access to a (RAG) _content retriever_. +The associated tests could be: + +[source, java] +---- +package dev.langchain4j.quarkus; + +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.quarkus.workshop.CustomerSupportAssistant; +import dev.langchain4j.rag.AugmentationRequest; +import dev.langchain4j.rag.RetrievalAugmentor; +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.rag.query.Metadata; +import io.quarkiverse.langchain4j.scorer.junit5.AiScorer; +import io.quarkiverse.langchain4j.scorer.junit5.SampleLocation; +import io.quarkiverse.langchain4j.scorer.junit5.ScorerConfiguration; +import io.quarkiverse.langchain4j.testing.scorer.EvaluationReport; +import io.quarkiverse.langchain4j.testing.scorer.EvaluationSample; +import io.quarkiverse.langchain4j.testing.scorer.EvaluationStrategy; +import io.quarkiverse.langchain4j.testing.scorer.Parameters; +import io.quarkiverse.langchain4j.testing.scorer.Samples; +import io.quarkiverse.langchain4j.testing.scorer.Scorer; +import io.quarkiverse.langchain4j.testing.scorer.judge.AiJudgeStrategy; +import io.quarkiverse.langchain4j.testing.scorer.similarity.SemanticSimilarityStrategy; +import io.quarkus.test.junit.QuarkusTest; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.UUID; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThat; + +@QuarkusTest +@AiScorer +public class AssistantTest { + + // Just a function calling the AI Service and returning the response as a String. 
+ @Inject + AiServiceEvaluation aiServiceEvaluation; + + // The content retriever from the RAG pattern I want to test + @Inject + RetrievalAugmentor retriever; + + // Test the AI Service using the Semantic Similarity Strategy + @Test + void testAiService(@ScorerConfiguration(concurrency = 5) Scorer scorer, + @SampleLocation("src/test/resources/samples.yaml") Samples samples) { + + EvaluationReport report = scorer.evaluate(samples, aiServiceEvaluation, + new SemanticSimilarityStrategy(0.8)); + assertThat(report.score()).isGreaterThanOrEqualTo(70); + } + + // Test the AI Service using the AI Judge Strategy + @Test + void testAiServiceUsingAiJudge(Scorer scorer, + @SampleLocation("src/test/resources/samples.yaml") Samples samples) { + var judge = OpenAiChatModel.builder() + .baseUrl("http://localhost:11434/v1") // Ollama + .modelName("mistral") + .build(); + EvaluationReport report = scorer.evaluate(samples, aiServiceEvaluation, + new AiJudgeStrategy(judge)); + assertThat(report.score()).isGreaterThanOrEqualTo(70); + } + + // Evaluation strategy can be CDI beans (which means they can easily be injected) + @Inject + TextSegmentEvaluationStrategy textSegmentEvaluationStrategy; + + // Test of the RAG retriever + @Test + void testRagRetriever(Scorer scorer, @SampleLocation("src/test/resources/content-retriever-samples.yaml") Samples> samples) { + EvaluationReport report = scorer.evaluate(samples, i -> runRetriever(i.get(0)), + textSegmentEvaluationStrategy); + assertThat(report.score()).isEqualTo(100); // Expect full success + } + + private List runRetriever(String query) { + UserMessage message = UserMessage.userMessage(query); + AugmentationRequest request = new AugmentationRequest(message, + new Metadata(message, UUID.randomUUID().toString(), List.of())); + var res = retriever.augment(request); + return res.contents().stream().map(Content::textSegment).map(TextSegment::text).toList(); + } + + @Singleton + public static class AiServiceEvaluation implements Function { + 
+ @Inject + CustomerSupportAssistant assistant; + + @ActivateRequestContext + @Override + public String apply(Parameters params) { + return assistant.chat(UUID.randomUUID().toString(), params.get(0)).collect() + .in(StringBuilder::new, StringBuilder::append).map(StringBuilder::toString).await().indefinitely(); + } + } + + @Singleton + public static class TextSegmentEvaluationStrategy implements EvaluationStrategy> { + + @Override + public boolean evaluate(EvaluationSample> sample, List response) { + List expected = sample.expectedOutput(); + int found = 0; + for (String seg : expected) { + // Make sure that the response contains the expected segment + boolean segFound = false; + for (String s : response) { + if (s.toLowerCase().contains(seg.toLowerCase())) { + segFound = true; + found++; + break; + } + } + if (!segFound) { + System.out.println("Segment not found: " + seg); + } + } + return found == expected.size(); + } + + } +} +---- + +This test class demonstrates how to use the `ScorerExtension` to evaluate an AI Service and a RAG retriever using different strategies. +The associated samples are: + +[source, yaml] +---- +--- +- name: "car types" + parameters: + - "What types of cars do you offer for rental?" + expected-output: | + We offer three categories of cars: + 1. Compact Commuter – Ideal for city driving, fuel-efficient, and budget-friendly. Example: Toyota Corolla, Honda Civic. + 2. Family Explorer SUV – Perfect for family trips with spacious seating for up to 7 passengers. Example: Toyota RAV4, Hyundai Santa Fe. + 3. Luxury Cruiser – Designed for traveling in style with premium features. Example: Mercedes-Benz E-Class, BMW 5 Series. +- name: "cancellation" + parameters: + - "Can I cancel my car rental booking at any time?" + expected-output: | + Our cancellation policy states that reservations can be canceled up to 11 days prior to the start of the booking period. If the booking period is less than 4 days, cancellations are not permitted. 
+- name: "teaching" + parameters: + - "Am I allowed to use the rental car to teach someone how to drive?" + expected-output: | + No, rental cars from Miles of Smiles cannot be used for teaching someone to drive, as outlined in our Terms of Use under “Use of Vehicle.” +- name: "damages" + parameters: + - "What happens if the car is damaged during my rental period?" + expected-output: | + You will be held liable for any damage, loss, or theft that occurs during the rental period, as stated in our Terms of Use under “Liability.” +- name: "requirements" + parameters: + - "What are the requirements for making a car rental booking?" + expected-output: | + To make a booking, you need to provide accurate, current, and complete information during the reservation process. All bookings are also subject to vehicle availability. +- name: "race" + parameters: + - "Can I use the rental car for a race or rally?" + expected-output: | + No, rental cars must not be used for any race, rally, or contest. This is prohibited as per our Terms of Use under “Use of Vehicle.” +- name: "family" + parameters: + - "Do you offer cars suitable for long family trips?" + expected-output: | + Yes, we recommend the Family Explorer SUV for long family trips. It offers spacious seating for up to seven passengers, ample cargo space, and advanced driver-assistance features. +- name: "alcohol" + parameters: + - "Is there any restriction on alcohol consumption while using the rental car?" + expected-output: | + Yes, you are not allowed to drive the rental car while under the influence of alcohol or drugs. This is strictly prohibited as stated in our Terms of Use. +- name: "other questions" + parameters: + - What should I do if I have questions unrelated to car rentals? + expected-output: | + For questions unrelated to car rentals, I recommend contacting the appropriate department. I’m here to assist with any car rental-related inquiries! 
+- name: "categories" + parameters: + - "Which car category is best for someone who values luxury and comfort?" + expected-output: | + If you value luxury and comfort, the Luxury Cruiser is the perfect choice. It offers premium interiors, cutting-edge technology, and unmatched comfort for a first-class driving experience. +---- + +and for the content retriever: + +[source, yaml] +---- +--- +- name: cancellation_policy_test + parameters: + - What is the cancellation policy for car rentals? + expected-outputs: + - "Reservations can be cancelled up to 11 days prior to the start of the booking period." + - "If the booking period is less than 4 days, cancellations are not permitted." + +- name: vehicle_restrictions_test + parameters: + - What are the restrictions on how the rental car can be used? + expected-outputs: + - "All cars rented from Miles of Smiles must not be used:" + - "for any illegal purpose or in connection with any criminal offense." + - "for teaching someone to drive." + - "in any race, rally or contest." + - "while under the influence of alcohol or drugs." + +- name: car_types_test + parameters: + - What types of cars are available for rent? + expected-outputs: + - "Compact Commuter" + - "Perfect for city driving and short commutes, this fuel-efficient and easy-to-park car is your ideal companion for urban adventures" + - "Family Explorer SUV" + - "Designed for road trips, family vacations, or adventures with friends, this spacious and versatile SUV offers ample cargo space, comfortable seating for up to seven passengers" + - "Luxury Cruiser" + - "For those who want to travel in style, the Luxury Cruiser delivers unmatched comfort, cutting-edge technology, and a touch of elegance" + +- name: car_damage_liability_test + parameters: + - What happens if I damage the car during my rental period? 
+ expected-outputs: + - "Users will be held liable for any damage, loss, or theft that occurs during the rental period" + +- name: governing_law_test + parameters: + - Under what law are the terms and conditions governed? + expected-outputs: + - "These terms will be governed by and construed in accordance with the laws of the United States of America" + - "Any disputes relating to these terms will be subject to the exclusive jurisdiction of the courts of United States" +---- \ No newline at end of file diff --git a/pom.xml b/pom.xml index c4e6774d8..900506053 100644 --- a/pom.xml +++ b/pom.xml @@ -22,6 +22,7 @@ tools codestarts testing-internal + testing scm:git:git@github.com:quarkiverse/quarkus-langchain4j.git diff --git a/testing/pom.xml b/testing/pom.xml new file mode 100644 index 000000000..bf77933c1 --- /dev/null +++ b/testing/pom.xml @@ -0,0 +1,23 @@ + + + 4.0.0 + + + io.quarkiverse.langchain4j + quarkus-langchain4j-parent + 999-SNAPSHOT + + + + quarkus-langchain4j-testing + Quarkus LangChain4j - Testing + pom + + + scorer + + + + \ No newline at end of file diff --git a/testing/scorer/pom.xml b/testing/scorer/pom.xml new file mode 100644 index 000000000..243c31605 --- /dev/null +++ b/testing/scorer/pom.xml @@ -0,0 +1,26 @@ + + + 4.0.0 + + + io.quarkiverse.langchain4j + quarkus-langchain4j-testing + 999-SNAPSHOT + + + quarkus-langchain4j-testing-scorer-parent + Quarkus LangChain4j - Testing - Scorer Parent + + pom + + + scorer-core + scorer-strategies/semantic-similarity + scorer-strategies/ai-judge + scorer-junit5 + + + + \ No newline at end of file diff --git a/testing/scorer/scorer-core/pom.xml b/testing/scorer/scorer-core/pom.xml new file mode 100644 index 000000000..172682581 --- /dev/null +++ b/testing/scorer/scorer-core/pom.xml @@ -0,0 +1,64 @@ + + + 4.0.0 + + + io.quarkiverse.langchain4j + quarkus-langchain4j-testing-scorer-parent + 999-SNAPSHOT + + + quarkus-langchain4j-testing-scorer-core + Quarkus LangChain4j - Testing - Scorer Core + Provides the 
core of the scorer testing utilities + + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core + ${project.version} + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + + + + org.assertj + assertj-core + 3.26.3 + test + + + org.junit-pioneer + junit-pioneer + 2.2.0 + test + + + org.junit.jupiter + junit-jupiter-params + test + + + org.junit.jupiter + junit-jupiter-api + test + + + org.junit.jupiter + junit-jupiter-engine + test + + + org.junit.jupiter + junit-jupiter + test + + + + + + \ No newline at end of file diff --git a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReport.java b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReport.java new file mode 100644 index 000000000..b6994fae4 --- /dev/null +++ b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReport.java @@ -0,0 +1,80 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.util.List; + +/** + * Report of the evaluation of a set of samples. + */ +public class EvaluationReport { + + private final List> evaluations; + private final double score; + + /** + * Create a new evaluation report and computes the global score. + * + * @param evaluations the evaluations, must not be {@code null}, must not be empty. + */ + public EvaluationReport(List> evaluations) { + this.evaluations = evaluations; + this.score = 100.0 * evaluations.stream().filter(Scorer.EvaluationResult::passed).count() / evaluations.size(); + } + + /** + * @return the global score, between 0.0 and 100.0. + */ + public double score() { + return score; + } + + /** + * @return the evaluations + */ + public List> evaluations() { + return evaluations; + } + + /** + * Compute the score for a given tag. 
+ * + * @param tag the tag, must not be {@code null} + * @return the score for the given tag, between 0.0 and 100.0. + */ + public double scoreForTag(String tag) { + return 100.0 * evaluations.stream().filter(e -> e.sample().tags().contains(tag)) + .filter(Scorer.EvaluationResult::passed).count() + / evaluations.stream().filter(e -> e.sample().tags().contains(tag)).count(); + } + + /** + * Write the report to a file using the Markdown syntax. + * + * @param output the output file, must not be {@code null} + * @throws IOException if an error occurs while writing the report + */ + public void writeReport(File output) throws IOException { + StringBuilder buffer = new StringBuilder(); + buffer.append("# Evaluation Report\n\n"); + buffer.append("**Global Score**: ").append(score).append("\n\n"); + + List tags = evaluations.stream().flatMap(e -> e.sample().tags().stream()).distinct().toList(); + if (!tags.isEmpty()) { + buffer.append("## Score per tags\n\n"); + for (String tag : tags) { + buffer.append("- **").append(tag).append("**: ").append(scoreForTag(tag)).append("\n"); + } + } + + buffer.append("\n## Details\n\n"); + for (Scorer.EvaluationResult evaluation : evaluations) { + buffer.append("- ").append(evaluation.sample().name()).append(": ") + .append(evaluation.passed() ? "PASSED" : "FAILED").append("\n"); + } + + Files.write(output.toPath(), buffer.toString().getBytes()); + } + +} diff --git a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationSample.java b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationSample.java new file mode 100644 index 000000000..5dde868b4 --- /dev/null +++ b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationSample.java @@ -0,0 +1,150 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * A sample for evaluation. 
+ * + * @param the type of the expected output. + */ +public record EvaluationSample(String name, Parameters parameters, T expectedOutput, List tags) { + + /** + * Create a new builder. + * + * @param the type of the expected output. + * @return a new builder. + */ + public static Builder builder() { + return new Builder<>(); + } + + /** + * Builder for {@link EvaluationSample}. + * + * @param the type of the expected output. + */ + public static class Builder { + + private String name; + private Parameters parameters = new Parameters(); + private List tags = new ArrayList<>(); + private O expectedOutput; + + /** + * Set the name of the sample. + * + * @param name the name of the sample, must not be {@code null}. + * @return this builder. + */ + public Builder withName(String name) { + this.name = name; + return this; + } + + /** + * Set the parameters of the sample. + * The parameters are the data that will be passed to the function to evaluate. + * + * @param parameters the parameters, must not be {@code null}. + * @return this builder. + */ + public Builder withParameters(Parameters parameters) { + this.parameters = parameters; + return this; + } + + /** + * Adds a parameter to the sample. + * The parameters are the data that will be passed to the function to evaluate. + * Order matters when using index-based parameters. + * + * @param parameter the parameter, must not be {@code null}. + * @return this builder. + */ + public Builder withParameter(Parameter parameter) { + this.parameters.add(parameter); + return this; + } + + /** + * Adds a {@code String} parameter to the sample. + * This is a convenient helper method for adding unnamed string parameters. + * The parameters are the data that will be passed to the function to evaluate. + * Order matters when using index-based parameters. + * + * @param value the parameter value, must not be {@code null}. + * @return this builder. 
+ */ + public Builder withParameter(String value) { + return withParameter(new Parameter.UnnamedParameter(value)); + } + + /** + * Set the expected output of the sample. + * + * @param expectedOutput the expected output, must not be {@code null}. + * @return this builder. + */ + public Builder withExpectedOutput(O expectedOutput) { + this.expectedOutput = expectedOutput; + return this; + } + + /** + * Set the tags of the sample. + * + * @param tags the tags, must not be {@code null}. + * @return this builder. + */ + public Builder withTags(List tags) { + this.tags = new ArrayList<>(tags); + return this; + } + + /** + * Adds a tag to the sample. + * + * @param tag the tag, must not be {@code null}. + * @return this builder. + */ + public Builder withTag(String tag) { + this.tags.add(tag); + return this; + } + + /** + * Adds tags to the sample. + * + * @param tags the tags, must not be {@code null}. + * @return this builder. + */ + public Builder withTags(String... tags) { + return withTags(Arrays.stream(tags).toList()); + } + + /** + * Build the sample. + * + * @return the sample. 
+ */ + public EvaluationSample build() { + if (name == null) { + throw new IllegalArgumentException("Name must not be null"); + } + if (parameters == null) { + throw new IllegalArgumentException("Parameters must not be null"); + } + if (parameters.size() == 0) { + throw new IllegalArgumentException("Parameters must not be empty"); + } + if (expectedOutput == null) { + throw new IllegalArgumentException("Expected output must not be null"); + } + return new EvaluationSample<>(name, parameters, expectedOutput, tags); + } + } + +} diff --git a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationStrategy.java b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationStrategy.java new file mode 100644 index 000000000..ec6ce658e --- /dev/null +++ b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationStrategy.java @@ -0,0 +1,19 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +/** + * A strategy to evaluate the output of a model. + * + * @param the type of the output. + */ +public interface EvaluationStrategy { + + /** + * Evaluate the output of a model. + * + * @param sample the sample to evaluate. + * @param output the output of the model. + * @return {@code true} if the output is correct, {@code false} otherwise. 
+ */ + boolean evaluate(EvaluationSample sample, T output); + +} diff --git a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Parameter.java b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Parameter.java new file mode 100644 index 000000000..20538b6ae --- /dev/null +++ b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Parameter.java @@ -0,0 +1,78 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +import org.eclipse.microprofile.config.spi.Converter; + +import io.smallrye.config.Converters; + +/** + * Represents a parameter passed to the function to evaluate. + */ +public interface Parameter { + /** + * Get the value of the parameter. + * + * @return the value of the parameter, can be {@code null}. + */ + Object value(); + + /** + * Get the value of the parameter and convert it to the given type. + *

+ * If the value is already an instance of the given type, it is returned as is. + * Otherwise, the value is converted to the given type using a converter (from {@link Converters}). + * + * @param clazz the class of the expected value, must not be {@code null} + * @param the type of the expected value + * @return the value of the parameter converted to the given type + * @throws IllegalArgumentException if the value cannot be converted to the given type + */ + default T as(Class clazz) { + if (clazz.isInstance(value())) { + return clazz.cast(value()); + } else { + Converter converter = Converters.getImplicitConverter(clazz); + if (converter != null) { + return converter.convert(value().toString()); + } else { + throw new ClassCastException("Cannot convert " + value() + " to " + clazz); + } + } + } + + /** + * Cast the value of the parameter to the given type. + *

+ * This method is a shortcut for {@code (T) value()}. + * + * @param the type of the expected value + * @return the value of the parameter casted to the given type + */ + @SuppressWarnings("unchecked") + default T cast() { + return (T) value(); + } + + /** + * Create a named parameter. The name is used to identify the parameter. + * + * @param name the name, must not be {@code null} + * @param value the value, can be {@code null} + */ + record NamedParameter(String name, Object value) implements Parameter { + + public NamedParameter { + if (name == null) { + throw new IllegalArgumentException("Name must not be null"); + } + } + } + + /** + * Create an unnamed parameter. + * + * @param value the value, can be {@code null} + */ + record UnnamedParameter(Object value) implements Parameter { + } + +} diff --git a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Parameters.java b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Parameters.java new file mode 100644 index 000000000..c4b4936c8 --- /dev/null +++ b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Parameters.java @@ -0,0 +1,158 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + +/** + * A list of parameters. + * These are parameter passed to the function to evaluate. + * Parameters can be named or not. + */ +public class Parameters implements Iterable { + + /** + * Create a new set of unnamed parameters with the given values. + * + * @param values the values, must not be {@code null}, must not be empty. + * @return the parameters, never {@code null} + */ + public static Parameters of(Object... 
values) { + if (values == null) { + throw new IllegalArgumentException("Values must not be null"); + } + if (values.length == 0) { + throw new IllegalArgumentException("Values must not be empty"); + } + Parameters parameters = new Parameters(); + for (Object value : values) { + parameters.parameters.add(new Parameter.UnnamedParameter(value)); + } + return parameters; + } + + private final List parameters = new CopyOnWriteArrayList<>(); + + /** + * The number of parameters. + */ + public int size() { + return parameters.size(); + } + + /** + * Get the parameter at the given index. + * The index is the position of the parameter in the list. + * It can be either a named or an unnamed parameter. + * + * @param index the index, must be greater or equal to 0 and less than {@link #size()} + * @param the type of the expected value + * @return the value at the given index + */ + public T get(int index) { + return parameters.get(index).cast(); + } + + /** + * Get the parameter at the given index. + * The index is the position of the parameter in the list. + * It can be either a named or an unnamed parameter. + * + * @param index the index, must be greater or equal to 0 and less than {@link #size()} + * @param clazz the type of the expected value + * @param the type of the expected value + * @return the value at the given index + */ + public T get(int index, Class clazz) { + return parameters.get(index).as(clazz); + } + + /** + * Get the named parameter using its name. + * + * @param name the name of the parameter, must not be {@code null} + * @return the value, can be {@code null} if the value is {@code null}. 
+ * @param the type of the expected value + * @throws IllegalArgumentException if the parameter is not found + */ + public T get(String name) { + for (Parameter parameter : parameters) { + if (parameter instanceof Parameter.NamedParameter namedParam && namedParam.name().equals(name)) { + return namedParam.cast(); + } + } + throw new IllegalArgumentException("Parameter not found: " + name); + } + + /** + * Get the named parameter using its name. + * + * @param name the name of the parameter, must not be {@code null} + * @param clazz the type of the expected value, must not be {@code null} + * @return the value, can be {@code null} if the value is {@code null}. + * @param the type of the expected value + * @throws IllegalArgumentException if the parameter is not found + */ + public T get(String name, Class clazz) { + for (Parameter parameter : parameters) { + if (parameter instanceof Parameter.NamedParameter namedParameter && namedParameter.name().equals(name)) { + return namedParameter.as(clazz); + } + } + throw new IllegalArgumentException("Parameter not found: " + name); + } + + /** + * Get the iterator over the parameter values. + * + * @return the iterator, never {@code null} + */ + @SuppressWarnings("NullableProblems") + @Override + public Iterator iterator() { + return parameters.stream().map(Parameter::value).iterator(); + } + + /** + * Get the array of parameter values. + * + * @return the array, never {@code null} + */ + public Object[] toArray() { + return parameters.stream().map(Parameter::value).toArray(); + } + + /** + * Add a named parameter. + * + * @param name the name, must not be {@code null} + * @param value the value, can be {@code null} + * @return this builder. + */ + public Parameters add(String name, Object value) { + parameters.add(new Parameter.NamedParameter(name, value)); + return this; + } + + /** + * Add an parameter. + * + * @param parameter the parameter, must not be {@code null} + * @return this builder. 
+ */ + public Parameters add(Parameter parameter) { + parameters.add(parameter); + return this; + } + + /** + * Add an unnamed parameter with the given value. + * + * @param value the value, can be {@code null} + * @return this builder. + */ + public Parameters add(Object value) { + parameters.add(new Parameter.UnnamedParameter(value)); + return this; + } +} diff --git a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Samples.java b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Samples.java new file mode 100644 index 000000000..55c0c336b --- /dev/null +++ b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Samples.java @@ -0,0 +1,62 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +import java.util.AbstractList; +import java.util.List; + +/** + * A list of {@link EvaluationSample} instances. + * + * @param the type of the expected output as all samples from the set should have the same type. + */ +public class Samples extends AbstractList> + implements List> { + + private final List> samples; + + /** + * Create a new set of samples. + * + * @param samples the samples, must not be {@code null}, must not be empty. + */ + public Samples(List> samples) { + if (samples == null) { + throw new IllegalArgumentException("Samples must not be null"); + } + if (samples.isEmpty()) { + throw new IllegalArgumentException("Samples must not be empty"); + } + this.samples = samples; + } + + /** + * Create a new set of samples. + * + * @param samples the samples, must not be {@code null}, must not be empty. + */ + @SafeVarargs + public Samples(EvaluationSample... samples) { + if (samples == null) { + throw new IllegalArgumentException("Samples must not be null"); + } + if (samples.length == 0) { + throw new IllegalArgumentException("Samples must not be empty"); + } + this.samples = List.of(samples); + } + + /** + * Gets the sample at the given index. 
+ */ + @Override + public EvaluationSample get(int index) { + return samples.get(index); + } + + /** + * Gets the number of samples. + */ + @Override + public int size() { + return samples.size(); + } +} diff --git a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java new file mode 100644 index 000000000..34153814e --- /dev/null +++ b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java @@ -0,0 +1,80 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +import java.io.Closeable; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.Function; + +import org.jboss.logging.Logger; + +public class Scorer implements Closeable { + + private static final Logger LOG = Logger.getLogger(Scorer.class); + private final ExecutorService executor; + + public Scorer(int concurrency) { + if (concurrency > 1) { + executor = Executors.newFixedThreadPool(concurrency); + } else { + executor = Executors.newSingleThreadExecutor(); + } + } + + public Scorer() { + this(1); + } + + @SuppressWarnings({ "unchecked" }) + public EvaluationReport evaluate(Samples samples, Function function, + EvaluationStrategy... strategies) { + List> evaluations = new CopyOnWriteArrayList<>(); + CountDownLatch latch = new CountDownLatch(samples.size()); + for (EvaluationSample sample : samples) { + // TODO Should we handle the context somehow. 
+ executor.submit(() -> { + try { + var response = execute(sample, function); + LOG.infof("Evaluating sample `%s`", sample.name()); + for (EvaluationStrategy strategy : strategies) { + EvaluationResult evaluation = new EvaluationResult<>(sample, + strategy.evaluate(sample, response)); + LOG.infof("Evaluation of sample `%s` with strategy `%s`: %s", sample.name(), + strategy.getClass().getSimpleName(), + evaluation.passed() ? "OK" : "KO"); + evaluations.add(evaluation); + } + } catch (Throwable e) { + LOG.errorf(e, "Failed to evaluate sample `%s`", sample.name()); + evaluations.add(new EvaluationResult<>(sample, false)); + } finally { + latch.countDown(); + } + }); + } + try { + latch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return new EvaluationReport(evaluations); + } + + public void close() { + executor.shutdown(); + } + + public record EvaluationResult(EvaluationSample sample, boolean passed) { + } + + private T execute(EvaluationSample sample, Function function) { + try { + return function.apply(sample.parameters()); + } catch (Exception e) { + throw new AssertionError("Failed to execute sample " + sample, e); + } + } + +} diff --git a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/YamlLoader.java b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/YamlLoader.java new file mode 100644 index 000000000..01d758e59 --- /dev/null +++ b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/YamlLoader.java @@ -0,0 +1,83 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +import java.io.File; +import java.io.FileReader; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.yaml.snakeyaml.Yaml; + +/** + * Utility to load samples from a YAML file. + */ +public class YamlLoader { + + private YamlLoader() { + // Avoid direct instantiation + } + + /** + * Load samples from a YAML file. 
+ * + * @param path the path to the YAML file, must not be {@code null} + * @return the samples, never {@code null} + * @param the type of the expected output from the samples. + */ + @SuppressWarnings("unchecked") + public static Samples load(String path) { + if (path == null) { + throw new IllegalArgumentException("Path must not be null"); + } + if (path.isBlank()) { + throw new IllegalArgumentException("Path must not be blank"); + } + File file = new File(path); + if (!file.exists()) { + throw new IllegalArgumentException("File not found: " + path); + } + + Yaml yaml = new Yaml(); + Iterable list; + List> samples = new ArrayList<>(); + try (var reader = new FileReader(file)) { + list = yaml.load(reader); + if (list == null) { + throw new RuntimeException("Failed to load sample from " + path); + } + for (Object o : list) { + // Expect Map + Map map = (Map) o; + String name = (String) map.get("name"); + List params = (List) map.get("parameters"); + String expected = (String) map.get("expected-output"); + List expectedList = null; + if (expected == null) { + expectedList = (List) map.get("expected-outputs"); + } + List tags = (List) map.get("tags"); + if (tags == null) { + tags = List.of(); + } + Parameters in = new Parameters(); + if (params == null) { + throw new RuntimeException("Parameters not found for sample " + name); + } + for (String p : params) { + in.add(new Parameter.UnnamedParameter(p)); + } + if (expectedList == null && expected == null) { + throw new RuntimeException("Expected output not found for sample " + name); + } + if (expectedList != null) { + samples.add(new EvaluationSample<>(name, in, (T) expectedList, tags)); + } else { + samples.add(new EvaluationSample<>(name, in, (T) expected, tags)); + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + return new Samples<>(samples); + } +} diff --git a/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReportTest.java 
b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReportTest.java new file mode 100644 index 000000000..74fcfe322 --- /dev/null +++ b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReportTest.java @@ -0,0 +1,82 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +import static org.assertj.core.api.Assertions.*; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.util.List; + +import org.junit.jupiter.api.Test; + +class EvaluationReportTest { + + @Test + void globalScoreShouldBeCorrect() { + // Create mock evaluations. + Scorer.EvaluationResult result1 = new Scorer.EvaluationResult<>( + new EvaluationSample<>("Sample1", new Parameters(), "expected", List.of("tag1")), + true); + + Scorer.EvaluationResult result2 = new Scorer.EvaluationResult<>( + new EvaluationSample<>("Sample2", new Parameters(), "expected", List.of("tag2")), + false); + + EvaluationReport report = new EvaluationReport(List.of(result1, result2)); + + // Assertions + assertThat(report.score()).isEqualTo(50.0); // 1 passed out of 2. + } + + @Test + void scoreForTagShouldBeCorrect() { + // Create mock evaluations. + Scorer.EvaluationResult result1 = new Scorer.EvaluationResult<>( + new EvaluationSample<>("Sample1", new Parameters(), "expected", List.of("tag1")), + true); + + Scorer.EvaluationResult result2 = new Scorer.EvaluationResult<>( + new EvaluationSample<>("Sample2", new Parameters(), "expected", List.of("tag2")), + false); + + Scorer.EvaluationResult result3 = new Scorer.EvaluationResult<>( + new EvaluationSample<>("Sample3", new Parameters(), "expected", List.of("tag1", "tag2")), + true); + + EvaluationReport report = new EvaluationReport(List.of(result1, result2, result3)); + + // Assertions + assertThat(report.scoreForTag("tag1")).isEqualTo(100.0); // Both tag1 samples passed. 
+ assertThat(report.scoreForTag("tag2")).isEqualTo(50.0); // 1 passed out of 2 for tag2. + } + + @Test + void writeReportShouldGenerateMarkdownFile() throws IOException { + // Create mock evaluations. + Scorer.EvaluationResult result1 = new Scorer.EvaluationResult<>( + new EvaluationSample<>("Sample1", new Parameters(), "expected", List.of("tag1")), + true); + + Scorer.EvaluationResult result2 = new Scorer.EvaluationResult<>( + new EvaluationSample<>("Sample2", new Parameters(), "expected", List.of("tag2")), + false); + + EvaluationReport report = new EvaluationReport(List.of(result1, result2)); + + // Write the report to a temporary file. + File tempFile = File.createTempFile("evaluation-report", ".md"); + report.writeReport(tempFile); + + // Assertions + assertThat(tempFile).exists(); + String content = Files.readString(tempFile.toPath()); + assertThat(content).contains("# Evaluation Report"); + assertThat(content).contains("**Global Score**: 50.0"); + assertThat(content).contains("## Score per tags"); + assertThat(content).contains("- **tag1**: 100.0"); + assertThat(content).contains("- **tag2**: 0.0"); + assertThat(content).contains("## Details"); + assertThat(content).contains("- Sample1: PASSED"); + assertThat(content).contains("- Sample2: FAILED"); + } +} diff --git a/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationSampleTest.java b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationSampleTest.java new file mode 100644 index 000000000..b7e7f2082 --- /dev/null +++ b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationSampleTest.java @@ -0,0 +1,147 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +import static org.assertj.core.api.Assertions.*; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +class EvaluationSampleTest { + + @Test + void builderShouldCreateEvaluationSample() { + Parameters parameters = 
Parameters.of(1, "test", 3.14); + EvaluationSample sample = EvaluationSample. builder() + .withName("Sample1") + .withParameters(parameters) + .withExpectedOutput("Output") + .withTags("tag1", "tag2") + .build(); + + assertThat(sample.name()).isEqualTo("Sample1"); + assertThat(sample.parameters()).isEqualTo(parameters); + assertThat(sample.expectedOutput()).isEqualTo("Output"); + assertThat(sample.tags()).containsExactly("tag1", "tag2"); + } + + @Test + void builderShouldThrowExceptionIfNameIsNull() { + Parameters parameters = Parameters.of(1, "test", 3.14); + + assertThatThrownBy(() -> EvaluationSample. builder() + .withParameters(parameters) + .withExpectedOutput("Output") + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Name must not be null"); + } + + @Test + void builderShouldThrowExceptionIfParametersAreNull() { + assertThatThrownBy(() -> EvaluationSample. builder() + .withName("Sample1") + .withParameters(null) + .withExpectedOutput("Output") + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Parameters must not be null"); + } + + @Test + void builderShouldThrowExceptionIfParametersAreEmpty() { + assertThatThrownBy(() -> EvaluationSample. builder() + .withName("Sample1") + .withParameters(new Parameters()) + .withExpectedOutput("Output") + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Parameters must not be empty"); + } + + @Test + void builderShouldThrowExceptionIfExpectedOutputIsNull() { + Parameters parameters = Parameters.of(1, "test", 3.14); + + assertThatThrownBy(() -> EvaluationSample. builder() + .withName("Sample1") + .withParameters(parameters) + .withExpectedOutput(null) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Expected output must not be null"); + } + + @Test + void builderWithParametersShouldAddParametersToEvaluationSample() { + Parameters parameters = new Parameters().add(1).add("test"); + EvaluationSample sample = EvaluationSample. 
builder() + .withName("Sample2") + .withParameters(parameters) + .withExpectedOutput(42) + .build(); + + assertThat(sample.parameters()).isEqualTo(parameters); + assertThat(sample.parameters().get(0, Integer.class)).isEqualTo(1); + assertThat(sample.parameters().get(1, String.class)).isEqualTo("test"); + } + + @Test + void builderWithParameterShouldAddIndividualParameter() { + EvaluationSample sample = EvaluationSample. builder() + .withName("Sample3") + .withParameter(new Parameter.UnnamedParameter("test")) + .withExpectedOutput(42) + .build(); + + assertThat(sample.parameters().size()).isEqualTo(1); + assertThat((String) sample.parameters().get(0)).isEqualTo("test"); + } + + @Test + void builderWithParameterStringShouldAddStringParameter() { + EvaluationSample sample = EvaluationSample. builder() + .withName("Sample4") + .withParameter("string-param") + .withExpectedOutput(42) + .build(); + + assertThat(sample.parameters().size()).isEqualTo(1); + assertThat((String) sample.parameters().get(0)).isEqualTo("string-param"); + } + + @Test + void builderWithTagsShouldAddTagsToEvaluationSample() { + EvaluationSample sample = EvaluationSample. builder() + .withName("Sample5") + .withParameters(Parameters.of(1, 2)) + .withExpectedOutput("result") + .withTags(List.of("tag1", "tag2")) + .build(); + + assertThat(sample.tags()).containsExactly("tag1", "tag2"); + } + + @Test + void builderWithTagShouldAddSingleTag() { + EvaluationSample sample = EvaluationSample. builder() + .withName("Sample6") + .withParameters(Parameters.of(1, 2)) + .withExpectedOutput("result") + .withTag("tag1") + .build(); + + assertThat(sample.tags()).containsExactly("tag1"); + } + + @Test + void builderWithTagsVarargsShouldAddMultipleTags() { + EvaluationSample sample = EvaluationSample. 
builder() + .withName("Sample7") + .withParameters(Parameters.of(1, 2)) + .withExpectedOutput("result") + .withTags("tag1", "tag2", "tag3") + .build(); + + assertThat(sample.tags()).containsExactly("tag1", "tag2", "tag3"); + } +} diff --git a/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ParameterTest.java b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ParameterTest.java new file mode 100644 index 000000000..c4d376c73 --- /dev/null +++ b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ParameterTest.java @@ -0,0 +1,110 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +import static org.assertj.core.api.Assertions.*; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +class ParameterTest { + + @Test + void unnamedParameterShouldReturnValue() { + Parameter parameter = new Parameter.UnnamedParameter("testValue"); + Object value = parameter.value(); + assertThat(value).isEqualTo("testValue"); + } + + @SuppressWarnings("CastCanBeRemovedNarrowingVariableType") + @Test + void namedParameterShouldReturnNameAndValue() { + Parameter parameter = new Parameter.NamedParameter("testName", 42); + String name = ((Parameter.NamedParameter) parameter).name(); + Object value = parameter.value(); + + assertThat(name).isEqualTo("testName"); + assertThat(value).isEqualTo(42); + } + + @Test + void asMethodShouldReturnValueIfTypeMatches() { + Parameter parameter = new Parameter.UnnamedParameter(123); + Integer value = parameter.as(Integer.class); + assertThat(value).isEqualTo(123); + } + + @Test + void asMethodShouldConvertValueToTargetTypeIfPossible() { + Parameter parameter = new Parameter.UnnamedParameter("123"); + Integer value = parameter.as(Integer.class); + assertThat(value).isEqualTo(123); + } + + @Test + void asMethodShouldThrowExceptionIfValueCannotBeConverted() { + Parameter parameter = new Parameter.UnnamedParameter("notANumber"); + // When / Then 
+        assertThatThrownBy(() -> parameter.as(Integer.class))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("SRCFG00020");
+    }
+
+    @Test
+    void asMethodShouldThrowExceptionForMissingConverter() {
+        Parameter parameter = new Parameter.UnnamedParameter("123");
+
+        assertThatThrownBy(() -> parameter.as(List.class))
+                .isInstanceOf(ClassCastException.class);
+    }
+
+    @Test
+    void castMethodShouldReturnCastedValue() {
+        Parameter parameter = new Parameter.UnnamedParameter("testValue");
+        String value = parameter.cast();
+        assertThat(value).isEqualTo("testValue");
+    }
+
+    @Test
+    void castMethodShouldThrowClassCastExceptionForInvalidCast() {
+        Parameter parameter = new Parameter.UnnamedParameter(123); // an Integer, deliberately used as a String below
+        assertThatThrownBy(() -> parameter.<String> cast().length()) // exercises cast(), not as(): the call site checks the type
+                .isInstanceOf(ClassCastException.class);
+    }
+
+    @Test
+    void namedParameterShouldStoreNameAndValueCorrectly() {
+        Parameter.NamedParameter parameter = new Parameter.NamedParameter("testName", "testValue");
+        String name = parameter.name();
+        Object value = parameter.value();
+        assertThat(name).isEqualTo("testName");
+        assertThat(value).isEqualTo("testValue");
+    }
+
+    @Test
+    void unnamedParameterShouldStoreValueCorrectly() {
+        Parameter.UnnamedParameter parameter = new Parameter.UnnamedParameter("testValue");
+        Object value = parameter.value();
+        assertThat(value).isEqualTo("testValue");
+    }
+
+    @Test
+    void unnamedParameterShouldAllowNullValue() {
+        Parameter.UnnamedParameter parameter = new Parameter.UnnamedParameter(null);
+        Object value = parameter.value();
+        assertThat(value).isNull();
+    }
+
+    @Test
+    void namedParameterShouldAllowNullValue() {
+        Parameter.NamedParameter parameter = new Parameter.NamedParameter("testName", null);
+        Object value = parameter.value();
+        assertThat(value).isNull();
+    }
+
+    @Test
+    void namedParameterShouldThrowExceptionForNullName() {
+        assertThatThrownBy(() -> new Parameter.NamedParameter(null, "testValue"))
+                .isInstanceOf(IllegalArgumentException.class)
+
.hasMessageContaining("Name"); + } +} diff --git a/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ParametersTest.java b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ParametersTest.java new file mode 100644 index 000000000..b880b8d25 --- /dev/null +++ b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ParametersTest.java @@ -0,0 +1,129 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Iterator; + +import org.junit.jupiter.api.Test; + +class ParametersTest { + + @Test + void ofShouldCreateParametersWithUnnamedValues() { + Parameters parameters = Parameters.of(1, "test", 3.14); + assertThat(parameters.size()).isEqualTo(3); + assertThat((Integer) parameters.get(0)).isEqualTo(1); + assertThat((String) parameters.get(1)).isEqualTo("test"); + assertThat((double) parameters.get(2)).isEqualTo(3.14); + } + + @SuppressWarnings("DataFlowIssue") + @Test + void ofShouldThrowExceptionForNullValues() { + assertThatThrownBy(() -> Parameters.of((Object[]) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Values must not be null"); + } + + @Test + void ofShouldThrowExceptionForEmptyValues() { + assertThatThrownBy(Parameters::of) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Values must not be empty"); + } + + @Test + void getByIndexShouldReturnCorrectValue() { + Parameters parameters = Parameters.of(1, "test", 3.14); + Integer intValue = parameters.get(0); + String stringValue = parameters.get(1); + Double doubleValue = parameters.get(2); + assertThat(intValue).isEqualTo(1); + assertThat(stringValue).isEqualTo("test"); + assertThat(doubleValue).isEqualTo(3.14); + } + + @Test + void getByIndexShouldThrowExceptionForInvalidIndex() { + Parameters parameters = Parameters.of(1, "test", 3.14); + 
assertThatThrownBy(() -> parameters.get(-1)) + .isInstanceOf(IndexOutOfBoundsException.class); + assertThatThrownBy(() -> parameters.get(3)) + .isInstanceOf(IndexOutOfBoundsException.class); + } + + @Test + void getByIndexWithTypeShouldConvertValue() { + Parameters parameters = Parameters.of("123"); + Integer value = parameters.get(0, Integer.class); + assertThat(value).isEqualTo(123); + } + + @Test + void getByNameShouldReturnCorrectValue() { + Parameters parameters = new Parameters().add("name1", 42).add("name2", "test"); + Integer intValue = parameters.get("name1"); + String stringValue = parameters.get("name2"); + assertThat(intValue).isEqualTo(42); + assertThat(stringValue).isEqualTo("test"); + } + + @Test + void getByNameShouldThrowExceptionIfNameNotFound() { + Parameters parameters = new Parameters().add("name1", 42); + assertThatThrownBy(() -> parameters.get("unknown")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Parameter not found: unknown"); + } + + @Test + void getByNameWithTypeShouldConvertValue() { + Parameters parameters = new Parameters().add("name", "123"); + Integer value = parameters.get("name", Integer.class); + assertThat(value).isEqualTo(123); + } + + @Test + void iteratorShouldIterateOverParameterValues() { + Parameters parameters = Parameters.of(1, "test", 3.14); + Iterator iterator = parameters.iterator(); + assertThat(iterator).toIterable() + .containsExactly(1, "test", 3.14); + } + + @Test + void toArrayShouldReturnAllParameterValues() { + Parameters parameters = Parameters.of(1, "test", 3.14); + Object[] values = parameters.toArray(); + assertThat(values).containsExactly(1, "test", 3.14); + } + + @Test + void addUnnamedParameterShouldIncreaseSize() { + Parameters parameters = new Parameters(); + parameters.add(42).add("test").add(3.14); + assertThat(parameters.size()).isEqualTo(3); + assertThat((int) parameters.get(0)).isEqualTo(42); + assertThat((String) parameters.get(1)).isEqualTo("test"); + assertThat((double) 
parameters.get(2)).isEqualTo(3.14); + } + + @Test + void addNamedParameterShouldStoreNameAndValue() { + Parameters parameters = new Parameters(); + parameters.add("name1", 42).add("name2", "test"); + assertThat(parameters.size()).isEqualTo(2); + assertThat((int) parameters.get("name1")).isEqualTo(42); + assertThat((String) parameters.get("name2")).isEqualTo("test"); + } + + @Test + void addParameterShouldAcceptCustomParameter() { + Parameters parameters = new Parameters(); + Parameter parameter = new Parameter.NamedParameter("name", 42); + parameters.add(parameter); + assertThat(parameters.size()).isEqualTo(1); + assertThat(parameters.get("name", Integer.class)).isEqualTo(42); + } +} diff --git a/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/SamplesTest.java b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/SamplesTest.java new file mode 100644 index 000000000..fd4fe88e8 --- /dev/null +++ b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/SamplesTest.java @@ -0,0 +1,121 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +import static org.assertj.core.api.Assertions.*; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +class SamplesTest { + + @Test + void constructorShouldCreateSamplesFromList() { + EvaluationSample sample1 = EvaluationSample. builder() + .withName("Sample1") + .withParameters(Parameters.of("param1")) + .withExpectedOutput("Output1") + .build(); + + EvaluationSample sample2 = EvaluationSample. builder() + .withName("Sample2") + .withParameters(Parameters.of("param2")) + .withExpectedOutput("Output2") + .build(); + + Samples samples = new Samples<>(List.of(sample1, sample2)); + + assertThat(samples).hasSize(2); + assertThat(samples.get(0)).isEqualTo(sample1); + assertThat(samples.get(1)).isEqualTo(sample2); + } + + @Test + void constructorShouldCreateSamplesFromVarargs() { + EvaluationSample sample1 = EvaluationSample. 
builder() + .withName("Sample1") + .withParameters(Parameters.of("param1")) + .withExpectedOutput("Output1") + .build(); + + EvaluationSample sample2 = EvaluationSample. builder() + .withName("Sample2") + .withParameters(Parameters.of("param2")) + .withExpectedOutput("Output2") + .build(); + + Samples samples = new Samples<>(sample1, sample2); + + assertThat(samples).hasSize(2); + assertThat(samples.get(0)).isEqualTo(sample1); + assertThat(samples.get(1)).isEqualTo(sample2); + } + + @SuppressWarnings("DataFlowIssue") + @Test + void constructorShouldThrowExceptionIfListIsNull() { + assertThatThrownBy(() -> new Samples<>((List>) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Samples must not be null"); + } + + @Test + void constructorShouldThrowExceptionIfVarargsAreNull() { + assertThatThrownBy(() -> new Samples<>((EvaluationSample[]) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Samples must not be null"); + } + + @Test + void constructorShouldThrowExceptionIfListIsEmpty() { + assertThatThrownBy(() -> new Samples<>(List.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Samples must not be empty"); + } + + @SuppressWarnings("unchecked") + @Test + void constructorShouldThrowExceptionIfVarargsAreEmpty() { + assertThatThrownBy(Samples::new) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Samples must not be empty"); + } + + @Test + void getShouldReturnSampleAtSpecifiedIndex() { + EvaluationSample sample1 = EvaluationSample. builder() + .withName("Sample1") + .withParameters(Parameters.of("param1")) + .withExpectedOutput("Output1") + .build(); + + EvaluationSample sample2 = EvaluationSample. 
builder() + .withName("Sample2") + .withParameters(Parameters.of("param2")) + .withExpectedOutput("Output2") + .build(); + + Samples samples = new Samples<>(sample1, sample2); + + assertThat(samples.get(0)).isEqualTo(sample1); + assertThat(samples.get(1)).isEqualTo(sample2); + } + + @Test + void sizeShouldReturnNumberOfSamples() { + EvaluationSample sample1 = EvaluationSample. builder() + .withName("Sample1") + .withParameters(Parameters.of("param1")) + .withExpectedOutput("Output1") + .build(); + + EvaluationSample sample2 = EvaluationSample. builder() + .withName("Sample2") + .withParameters(Parameters.of("param2")) + .withExpectedOutput("Output2") + .build(); + + Samples samples = new Samples<>(sample1, sample2); + + assertThat(samples.size()).isEqualTo(2); + } +} diff --git a/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java new file mode 100644 index 000000000..c039b762f --- /dev/null +++ b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java @@ -0,0 +1,105 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.function.Function; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +class ScorerTest { + + private Scorer scorer; + + @AfterEach + void tearDown() { + if (scorer != null) { + scorer.close(); + } + } + + @SuppressWarnings("unchecked") + @Test + void evaluateShouldReturnCorrectReport() { + scorer = new Scorer(2); + + EvaluationSample sample1 = new EvaluationSample<>( + "Sample1", + new Parameters().add(new Parameter.UnnamedParameter("param1")), + "expected1", + List.of("tag1", "tag2")); + + EvaluationSample sample2 = new EvaluationSample<>( + "Sample2", + new Parameters().add(new Parameter.UnnamedParameter("param2")), + 
"expected2", + List.of("tag2")); + + Function mockFunction = params -> "expected1"; + EvaluationStrategy strategy = (sample, actual) -> actual.equals(sample.expectedOutput()); + + Samples samples = new Samples<>(sample1, sample2); + EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy); + + assertThat(report).isNotNull(); + assertThat(report.score()).isEqualTo(50.0); // Only one sample should pass. + assertThat(report.evaluations()).hasSize(2); + + Scorer.EvaluationResult result1 = report.evaluations().get(0); + assertThat(result1.passed()).isTrue(); + + Scorer.EvaluationResult result2 = report.evaluations().get(1); + assertThat(result2.passed()).isFalse(); + } + + @Test + @SuppressWarnings("unchecked") + void evaluateShouldHandleExceptionsInFunction() { + scorer = new Scorer(); + EvaluationSample sample = new EvaluationSample<>( + "Sample1", + new Parameters().add(new Parameter.UnnamedParameter("param1")), + "expected", + List.of()); + + Function mockFunction = params -> { + throw new RuntimeException("Test exception"); + }; + + EvaluationStrategy strategy = (s, actual) -> false; + + Samples samples = new Samples<>(sample); + EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy); + + assertThat(report).isNotNull(); + assertThat(report.score()).isEqualTo(0.0); // All evaluations should fail. 
+ assertThat(report.evaluations()).hasSize(1); + assertThat(report.evaluations().get(0).passed()).isFalse(); + } + + @Test + @SuppressWarnings("unchecked") + void evaluateShouldHandleMultipleStrategies() { + scorer = new Scorer(); + + EvaluationSample sample = new EvaluationSample<>( + "Sample1", + new Parameters().add(new Parameter.UnnamedParameter("param1")), + "expected", + List.of()); + + Function mockFunction = params -> "expected"; + + EvaluationStrategy strategy1 = (s, actual) -> actual.equals("expected"); + EvaluationStrategy strategy2 = (s, actual) -> actual.length() > 3; + + Samples samples = new Samples<>(sample); + EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy1, strategy2); + + assertThat(report).isNotNull(); + assertThat(report.score()).isEqualTo(100.0); // Both strategies should pass for the sample. + assertThat(report.evaluations()).hasSize(2); // One result per strategy. + report.evaluations().forEach(e -> assertThat(e.passed()).isTrue()); + } +} diff --git a/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/YamlLoaderTest.java b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/YamlLoaderTest.java new file mode 100644 index 000000000..e4e83e46a --- /dev/null +++ b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/YamlLoaderTest.java @@ -0,0 +1,149 @@ +package io.quarkiverse.langchain4j.testing.scorer; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.File; +import java.io.FileWriter; +import java.nio.file.Path; +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +class YamlLoaderTest { + + @SuppressWarnings("DataFlowIssue") + @Test + void loadShouldThrowExceptionWhenPathIsNull() { + assertThatThrownBy(() -> YamlLoader.load(null)) + .isInstanceOf(IllegalArgumentException.class) 
+ .hasMessage("Path must not be null"); + } + + @Test + void loadShouldThrowExceptionWhenPathIsBlank() { + assertThatThrownBy(() -> YamlLoader.load(" ")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Path must not be blank"); + } + + @Test + void loadShouldThrowExceptionWhenFileDoesNotExist() { + assertThatThrownBy(() -> YamlLoader.load("non-existent-file.yaml")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("File not found:"); + } + + @Test + void loadShouldThrowExceptionWhenYamlIsEmpty(@TempDir Path tempDir) throws Exception { + File emptyYaml = tempDir.resolve("empty.yaml").toFile(); + try (FileWriter writer = new FileWriter(emptyYaml)) { + writer.write(""); // Write an empty YAML file + } + + assertThatThrownBy(() -> YamlLoader.load(emptyYaml.getPath())) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Failed to load sample from"); + } + + @Test + void loadShouldThrowExceptionWhenExpectedOutputIsMissing(@TempDir Path tempDir) throws Exception { + File invalidYaml = tempDir.resolve("invalid.yaml").toFile(); + try (FileWriter writer = new FileWriter(invalidYaml)) { + writer.write( + """ + - name: Sample1 + parameters: + - param1 + """); // Missing "expected-output" or "expected-outputs" + } + + assertThatThrownBy(() -> YamlLoader.load(invalidYaml.getPath())) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Expected output not found for sample Sample1"); + } + + @SuppressWarnings("unchecked") + @Test + void loadShouldLoadValidYamlFile(@TempDir Path tempDir) throws Exception { + File validYaml = tempDir.resolve("valid.yaml").toFile(); + try (FileWriter writer = new FileWriter(validYaml)) { + writer.write( + """ + - name: Sample1 + parameters: + - p1 + - p2 + expected-output: ExpectedOutput1 + tags: + - tag1 + - tag2 + - name: Sample2 + parameters: + - p3 + expected-outputs: + - Output1 + - Output2 + """); + } + + Samples samples = YamlLoader.load(validYaml.getPath()); + + 
assertThat(samples).hasSize(2); + + // Validate Sample1 + EvaluationSample sample1 = samples.get(0); + assertThat(sample1.name()).isEqualTo("Sample1"); + assertThat(sample1.parameters()).containsExactly("p1", "p2"); + assertThat(sample1.expectedOutput()).isEqualTo("ExpectedOutput1"); + assertThat(sample1.tags()).containsExactly("tag1", "tag2"); + + // Validate Sample2 + EvaluationSample sample2 = samples.get(1); + assertThat(sample2.name()).isEqualTo("Sample2"); + assertThat(sample2.parameters()).containsExactly("p3"); + assertThat(sample2.expectedOutput()).isInstanceOf(List.class); + assertThat((List) sample2.expectedOutput()).containsExactly("Output1", "Output2"); + assertThat(sample2.tags()).isEmpty(); + } + + @Test + void loadShouldHandleYamlWithMissingTags(@TempDir Path tempDir) throws Exception { + File yamlWithoutTags = tempDir.resolve("no-tags.yaml").toFile(); + try (FileWriter writer = new FileWriter(yamlWithoutTags)) { + writer.write( + """ + - name: Sample1 + parameters: + - p1 + expected-output: ExpectedOutput1 + """); // No "tags" field + } + + Samples samples = YamlLoader.load(yamlWithoutTags.getPath()); + + assertThat(samples).hasSize(1); + + EvaluationSample sample = samples.get(0); + assertThat(sample.name()).isEqualTo("Sample1"); + assertThat(sample.parameters()).containsExactly("p1"); + assertThat(sample.expectedOutput()).isEqualTo("ExpectedOutput1"); + assertThat(sample.tags()).isEmpty(); // Ensure tags default to an empty list + } + + @Test + void loadShouldThrowExceptionOnInvalidYamlFormat(@TempDir Path tempDir) throws Exception { + File invalidYaml = tempDir.resolve("invalid-format.yaml").toFile(); + try (FileWriter writer = new FileWriter(invalidYaml)) { + writer.write( + """ + - invalid: + key: value + """); // Invalid structure for samples + } + + assertThatThrownBy(() -> YamlLoader.load(invalidYaml.getPath())) + .isInstanceOf(RuntimeException.class); + } +} diff --git a/testing/scorer/scorer-junit5/pom.xml 
b/testing/scorer/scorer-junit5/pom.xml new file mode 100644 index 000000000..80eda1647 --- /dev/null +++ b/testing/scorer/scorer-junit5/pom.xml @@ -0,0 +1,53 @@ + + + 4.0.0 + + + io.quarkiverse.langchain4j + quarkus-langchain4j-testing-scorer-parent + 999-SNAPSHOT + + + quarkus-langchain4j-testing-scorer-junit5 + Quarkus LangChain4j - Testing - Scorer Junit5 Extension + Provides the Junit5 extension to use the scorer testing utilities + + + + + org.junit.jupiter + junit-jupiter + compile + + + org.junit.jupiter + junit-jupiter-api + compile + + + org.junit.jupiter + junit-jupiter-params + compile + + + io.quarkiverse.langchain4j + quarkus-langchain4j-testing-scorer-core + 999-SNAPSHOT + + + + org.assertj + assertj-core + 3.26.3 + test + + + org.mockito + mockito-core + test + + + + \ No newline at end of file diff --git a/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/AiScorer.java b/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/AiScorer.java new file mode 100644 index 000000000..8b4898393 --- /dev/null +++ b/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/AiScorer.java @@ -0,0 +1,18 @@ +package io.quarkiverse.langchain4j.scorer.junit5; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.junit.jupiter.api.extension.ExtendWith; + +/** + * Annotation to enable the LangChain4J Scorer JUnit 5 extension. + * This is equivalent to adding the {@code @ExtendWith(ScorerExtension.class)} annotation to a test class. 
/**
 * Declares where the YAML file describing the evaluation samples is located.
 * <p>
 * The annotation value is the path to that YAML file. It is meant to be placed on a parameter of type
 * {@link io.quarkiverse.langchain4j.testing.scorer.Samples}.
 */
@Target({ ElementType.PARAMETER })
@Retention(RetentionPolicy.RUNTIME)
public @interface SampleLocation {

    /**
     * The location of the YAML file defining the samples.
     *
     * @return the path to the YAML file
     */
    String value();

}
+ * The target of this annotation should be a parameter or a field of type + * {@link io.quarkiverse.langchain4j.testing.scorer.Scorer}. + */ +@Retention(RUNTIME) +@Target({ FIELD, PARAMETER }) +public @interface ScorerConfiguration { + + /** + * @return the number of threads to use for the evaluation. + */ + int concurrency() default 1; + +} diff --git a/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerExtension.java b/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerExtension.java new file mode 100644 index 000000000..6173f3dee --- /dev/null +++ b/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerExtension.java @@ -0,0 +1,85 @@ +package io.quarkiverse.langchain4j.scorer.junit5; + +import java.lang.reflect.Field; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CopyOnWriteArrayList; + +import org.junit.jupiter.api.extension.AfterEachCallback; +import org.junit.jupiter.api.extension.BeforeEachCallback; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.ParameterContext; +import org.junit.jupiter.api.extension.ParameterResolutionException; +import org.junit.jupiter.api.extension.ParameterResolver; +import org.junit.platform.commons.support.HierarchyTraversalMode; +import org.junit.platform.commons.support.ReflectionSupport; + +import io.quarkiverse.langchain4j.testing.scorer.Samples; +import io.quarkiverse.langchain4j.testing.scorer.Scorer; +import io.quarkiverse.langchain4j.testing.scorer.YamlLoader; + +public class ScorerExtension implements BeforeEachCallback, AfterEachCallback, ParameterResolver { + private final List scorers = new CopyOnWriteArrayList<>(); + + @Override + public void beforeEach(ExtensionContext extensionContext) { + Optional> maybeClass = extensionContext.getTestClass(); + if (maybeClass.isPresent()) { + List fields = 
ReflectionSupport.findFields(maybeClass.get(), + field -> field.getType().isAssignableFrom(Scorer.class), HierarchyTraversalMode.TOP_DOWN); + for (Field field : fields) { + Scorer sc; + if (field.isAnnotationPresent(ScorerConfiguration.class)) { + ScorerConfiguration annotation = field.getAnnotation(ScorerConfiguration.class); + sc = new Scorer(annotation.concurrency()); + } else { + sc = new Scorer(); + } + scorers.add(sc); + inject(sc, extensionContext.getRequiredTestInstance(), field); + } + } + } + + private void inject(Scorer sc, Object instance, Field field) { + try { + field.setAccessible(true); + field.set(instance, sc); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + @Override + public void afterEach(ExtensionContext extensionContext) { + for (Scorer scorer : scorers) { + scorer.close(); + } + } + + @Override + public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) + throws ParameterResolutionException { + return (parameterContext.findAnnotation(SampleLocation.class).isPresent() + && parameterContext.getParameter().getType().isAssignableFrom(Samples.class)) + || parameterContext.getParameter().getType().isAssignableFrom(Scorer.class); + } + + @Override + public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) + throws ParameterResolutionException { + if (parameterContext.getParameter().getType().isAssignableFrom(Scorer.class)) { + if (parameterContext.getParameter().isAnnotationPresent(ScorerConfiguration.class)) { + ScorerConfiguration annotation = parameterContext.getParameter().getAnnotation(ScorerConfiguration.class); + return new Scorer(annotation.concurrency()); + } else { + return new Scorer(); + } + } else { + // List of data samples + String path = parameterContext.findAnnotation(SampleLocation.class).orElseThrow().value(); + return YamlLoader.load(path); + } + } + +} diff --git 
a/testing/scorer/scorer-junit5/src/test/java/io/quarkiverse/langchain4j/scorer/junit5/test/ScorerExtensionTest.java b/testing/scorer/scorer-junit5/src/test/java/io/quarkiverse/langchain4j/scorer/junit5/test/ScorerExtensionTest.java new file mode 100644 index 000000000..da9ab6d36 --- /dev/null +++ b/testing/scorer/scorer-junit5/src/test/java/io/quarkiverse/langchain4j/scorer/junit5/test/ScorerExtensionTest.java @@ -0,0 +1,50 @@ +package io.quarkiverse.langchain4j.scorer.junit5.test; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mockito; + +import io.quarkiverse.langchain4j.scorer.junit5.SampleLocation; +import io.quarkiverse.langchain4j.scorer.junit5.ScorerConfiguration; +import io.quarkiverse.langchain4j.scorer.junit5.ScorerExtension; +import io.quarkiverse.langchain4j.testing.scorer.Samples; +import io.quarkiverse.langchain4j.testing.scorer.Scorer; + +@ExtendWith(ScorerExtension.class) +class ScorerExtensionTest { + + @ScorerConfiguration(concurrency = 3) + private Scorer scorerWithConcurrency; + + private Scorer defaultScorer; + + @Test + void scorerFieldInjectionShouldWork() { + assertThat(scorerWithConcurrency).isNotNull(); + assertThat(scorerWithConcurrency).extracting("executor").isNotNull(); + assertThat(defaultScorer).isNotNull(); + assertThat(defaultScorer).extracting("executor").isNotNull(); + } + + @Test + void scorerParameterShouldBeResolved(@ScorerConfiguration(concurrency = 2) Scorer scorer) { + assertThat(scorer).isNotNull(); + assertThat(scorer).extracting("executor").isNotNull(); + } + + @Test + void samplesParameterShouldBeResolved(@SampleLocation("src/test/resources/test-samples.yaml") Samples samples) { + assertThat(samples).isNotNull(); + assertThat(samples).hasSizeGreaterThan(0); + assertThat(samples.get(0).name()).isEqualTo("Sample1"); // Assuming the YAML has this entry. 
+ } + + @Test + void scorerShouldBeClosedAfterTest() { + Scorer mockScorer = Mockito.mock(Scorer.class); + mockScorer.close(); + Mockito.verify(mockScorer).close(); + } +} diff --git a/testing/scorer/scorer-junit5/src/test/resources/test-samples.yaml b/testing/scorer/scorer-junit5/src/test/resources/test-samples.yaml new file mode 100644 index 000000000..78f7320b5 --- /dev/null +++ b/testing/scorer/scorer-junit5/src/test/resources/test-samples.yaml @@ -0,0 +1,14 @@ +--- +- name: "Sample1" + parameters: + - "What types of cars do you offer for rental?" + expected-output: | + We offer three categories of cars: + 1. Compact Commuter – Ideal for city driving, fuel-efficient, and budget-friendly. Example: Toyota Corolla, Honda Civic. + 2. Family Explorer SUV – Perfect for family trips with spacious seating for up to 7 passengers. Example: Toyota RAV4, Hyundai Santa Fe. + 3. Luxury Cruiser – Designed for traveling in style with premium features. Example: Mercedes-Benz E-Class, BMW 5 Series. +- name: "Sample2" + parameters: + - "Can I cancel my car rental booking at any time?" + expected-output: | + Our cancellation policy states that reservations can be canceled up to 11 days prior to the start of the booking period. If the booking period is less than 4 days, cancellations are not permitted. \ No newline at end of file diff --git a/testing/scorer/scorer-strategies/ai-judge/pom.xml b/testing/scorer/scorer-strategies/ai-judge/pom.xml new file mode 100644 index 000000000..6110a5641 --- /dev/null +++ b/testing/scorer/scorer-strategies/ai-judge/pom.xml @@ -0,0 +1,32 @@ + + + 4.0.0 + + + io.quarkiverse.langchain4j + quarkus-langchain4j-testing-scorer-parent + 999-SNAPSHOT + ../.. 
+ + + quarkus-langchain4j-testing-scorer-ai-judge + Quarkus LangChain4j - Testing - Scorer - Strategy - AI Judge + Ask an LLM to judge evaluate responses + + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core + 999-SNAPSHOT + + + io.quarkiverse.langchain4j + quarkus-langchain4j-testing-scorer-core + 999-SNAPSHOT + + + + + \ No newline at end of file diff --git a/testing/scorer/scorer-strategies/ai-judge/src/main/java/io/quarkiverse/langchain4j/testing/scorer/judge/AiJudgeStrategy.java b/testing/scorer/scorer-strategies/ai-judge/src/main/java/io/quarkiverse/langchain4j/testing/scorer/judge/AiJudgeStrategy.java new file mode 100644 index 000000000..764d70eeb --- /dev/null +++ b/testing/scorer/scorer-strategies/ai-judge/src/main/java/io/quarkiverse/langchain4j/testing/scorer/judge/AiJudgeStrategy.java @@ -0,0 +1,61 @@ +package io.quarkiverse.langchain4j.testing.scorer.judge; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import io.quarkiverse.langchain4j.testing.scorer.EvaluationSample; +import io.quarkiverse.langchain4j.testing.scorer.EvaluationStrategy; + +/** + * A strategy to evaluate the output of a model using an AI judge, _i.e._ another model verifying if the expected + * output and the actual response match. + */ +public class AiJudgeStrategy implements EvaluationStrategy { + + private final ChatLanguageModel model; + private final String prompt; + + /** + * Create a new instance of `AiJudgeStrategy`. + * + * @param model the LLM model (chat model) to use as a judge. + * @param prompt the prompt to use to evaluate the response. + * The prompt should contain the placeholders `{response}` and `{expected_output}`. + */ + public AiJudgeStrategy(ChatLanguageModel model, String prompt) { + this.model = model; + this.prompt = prompt; + } + + /** + * Create a new instance of `AiJudgeStrategy` using the default prompt. + * + * @param model the LLM model (chat model) to use as a judge. 
+ */ + public AiJudgeStrategy(ChatLanguageModel model) { + this(model, """ + You are an AI evaluating a response and the expected output. + You need to evaluate whether the model response is correct or not. + Return true if the response is correct, false otherwise. + + Response to evaluate: {response} + Expected output: {expected_output} + + """); + } + + /** + * Evaluate the output of a model. + * + * @param sample the sample to evaluate. + * @param output the output of the model. + * @return {@code true} if the output is correct, {@code false} otherwise. + */ + @Override + public boolean evaluate(EvaluationSample sample, String output) { + String expectedOutput = sample.expectedOutput(); + String prompt = this.prompt + .replace("{response}", output) + .replace("{expected_output}", expectedOutput); + var verdict = model.generate(prompt); + return Boolean.parseBoolean(verdict.trim()); + } +} diff --git a/testing/scorer/scorer-strategies/semantic-similarity/pom.xml b/testing/scorer/scorer-strategies/semantic-similarity/pom.xml new file mode 100644 index 000000000..467c4bd94 --- /dev/null +++ b/testing/scorer/scorer-strategies/semantic-similarity/pom.xml @@ -0,0 +1,32 @@ + + + 4.0.0 + + + io.quarkiverse.langchain4j + quarkus-langchain4j-testing-scorer-parent + 999-SNAPSHOT + ../../pom.xml + + + quarkus-langchain4j-testing-scorer-semantic-similarity + Quarkus LangChain4j - Testing - Scorer - Strategy - Semantic Similarity + Apply semantic similarity to check if the response matches the expected outcome + + + + dev.langchain4j + langchain4j-embeddings-bge-small-en-v15 + true + + + io.quarkiverse.langchain4j + quarkus-langchain4j-testing-scorer-core + 999-SNAPSHOT + + + + + \ No newline at end of file diff --git a/testing/scorer/scorer-strategies/semantic-similarity/src/main/java/io/quarkiverse/langchain4j/testing/scorer/similarity/SemanticSimilarityStrategy.java 
b/testing/scorer/scorer-strategies/semantic-similarity/src/main/java/io/quarkiverse/langchain4j/testing/scorer/similarity/SemanticSimilarityStrategy.java new file mode 100644 index 000000000..a251fcc2b --- /dev/null +++ b/testing/scorer/scorer-strategies/semantic-similarity/src/main/java/io/quarkiverse/langchain4j/testing/scorer/similarity/SemanticSimilarityStrategy.java @@ -0,0 +1,88 @@ +package io.quarkiverse.langchain4j.testing.scorer.similarity; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.bgesmallenv15.BgeSmallEnV15EmbeddingModel; +import dev.langchain4j.model.output.Response; +import io.quarkiverse.langchain4j.testing.scorer.EvaluationSample; +import io.quarkiverse.langchain4j.testing.scorer.EvaluationStrategy; + +/** + * A strategy to evaluate the output of a model using semantic similarity. + */ +public class SemanticSimilarityStrategy implements EvaluationStrategy { + + private final EmbeddingModel model; + private final double minSimilarity; + + /** + * Create a new instance of `SemanticSimilarityStrategy`. + * + * @param model the embedding model to use to calculate the similarity. + * @param minSimilarity the minimum similarity required to consider the output correct. + */ + public SemanticSimilarityStrategy(EmbeddingModel model, double minSimilarity) { + this.model = model; + this.minSimilarity = minSimilarity; + } + + /** + * Create a new instance of `SemanticSimilarityStrategy` using the default model (`BgeSmallEnV15`) and a default minimum + * similarity. + */ + public SemanticSimilarityStrategy() { + this(new BgeSmallEnV15EmbeddingModel(), 0.9); + } + + /** + * Create a new instance of `SemanticSimilarityStrategy` using the default model and a custom minimum similarity. + * + * @param minSimilarity the minimum similarity required to consider the output correct. 
+ */ + public SemanticSimilarityStrategy(double minSimilarity) { + this(new BgeSmallEnV15EmbeddingModel(), minSimilarity); + } + + /** + * Evaluate the output of a model. + * + * @param sample the sample to evaluate. + * @param output the output of the model. + * @return {@code true} if the output is correct, {@code false} otherwise. + */ + @Override + public boolean evaluate(EvaluationSample sample, String output) { + Response actual = model.embed(output); + Response expectation = model.embed(sample.expectedOutput()); + return calculateCosineSimilarity(expectation.content().vector(), + actual.content().vector()) > minSimilarity; + } + + public static double calculateCosineSimilarity(float[] vectorA, float[] vectorB) { + if (vectorA.length != vectorB.length) { + throw new IllegalArgumentException("Vectors must be of the same length"); + } + + double dotProduct = 0.0; + double magnitudeA = 0.0; + double magnitudeB = 0.0; + + // Calculate dot product and magnitudes + for (int i = 0; i < vectorA.length; i++) { + dotProduct += vectorA[i] * vectorB[i]; + magnitudeA += Math.pow(vectorA[i], 2); + magnitudeB += Math.pow(vectorB[i], 2); + } + + // Compute magnitudes + magnitudeA = Math.sqrt(magnitudeA); + magnitudeB = Math.sqrt(magnitudeB); + + // Avoid division by zero + if (magnitudeA == 0 || magnitudeB == 0) { + throw new IllegalArgumentException("Vector magnitude cannot be zero"); + } + + return dotProduct / (magnitudeA * magnitudeB); + } +}