From 4c93bafa3ebea716a1bdc3a2cd078258d5cc3ecd Mon Sep 17 00:00:00 2001 From: Diego Ramp Date: Fri, 10 Jan 2025 10:55:29 +0100 Subject: [PATCH] Provide Evaluations in same order as they where submitted --- .../testing/scorer/EvaluationReport.java | 8 +-- .../langchain4j/testing/scorer/Scorer.java | 69 ++++++++++++------- .../testing/scorer/ScorerTest.java | 50 ++++++++++++-- 3 files changed, 90 insertions(+), 37 deletions(-) diff --git a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReport.java b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReport.java index 10454fedc..b3ace2f30 100644 --- a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReport.java +++ b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReport.java @@ -8,9 +8,9 @@ /** * Report of the evaluation of a set of samples. */ -public class EvaluationReport { +public class EvaluationReport { - private final List> evaluations; + private final List> evaluations; private final double score; /** @@ -18,7 +18,7 @@ public class EvaluationReport { * * @param evaluations the evaluations, must not be {@code null}, must not be empty. */ - public EvaluationReport(List> evaluations) { + public EvaluationReport(List> evaluations) { this.evaluations = evaluations; this.score = 100.0 * evaluations.stream().filter(Scorer.EvaluationResult::passed).count() / evaluations.size(); } @@ -33,7 +33,7 @@ public double score() { /** * @return the evaluations */ - public List> evaluations() { + public List> evaluations() { return evaluations; } diff --git a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java index 2c6195abe..4bbf4517b 100644 --- a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java +++ b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java @@ -1,6 +1,7 @@ package io.quarkiverse.langchain4j.testing.scorer; import java.io.Closeable; +import java.util.Comparator; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; @@ -28,50 +29,64 @@ public Scorer() { } @SuppressWarnings({ "unchecked" }) - public EvaluationReport evaluate(Samples samples, Function function, - EvaluationStrategy... strategies) { - List> evaluations = new CopyOnWriteArrayList<>(); + public EvaluationReport evaluate( + Samples samples, Function function, EvaluationStrategy... strategies) { + List> evaluations = new CopyOnWriteArrayList<>(); CountDownLatch latch = new CountDownLatch(samples.size()); + var index = 0; for (EvaluationSample sample : samples) { // TODO Should we handle the context somehow. - executor.submit(() -> { - try { - var response = execute(sample, function); - LOG.infof("Evaluating sample `%s`", sample.name()); - for (EvaluationStrategy strategy : strategies) { - EvaluationResult evaluation = EvaluationResult.fromCompletedEvaluation(sample, - response, strategy.evaluate(sample, response)); - LOG.infof("Evaluation of sample `%s` with strategy `%s`: %s", sample.name(), - strategy.getClass().getSimpleName(), - evaluation.passed() ? "OK" : "KO"); - evaluations.add(evaluation); - } - } catch (Throwable e) { - LOG.errorf(e, "Failed to evaluate sample `%s`", sample.name()); - evaluations.add(EvaluationResult.fromEvaluationThrowable(sample, e)); - } finally { - latch.countDown(); - } - }); + var currentIndex = index++; + executor.submit( + () -> { + try { + var response = execute(sample, function); + LOG.infof("Evaluating sample `%s`", sample.name()); + for (EvaluationStrategy strategy : strategies) { + EvaluationResult evaluation = EvaluationResult.fromCompletedEvaluation( + sample, response, strategy.evaluate(sample, response)); + LOG.infof( + "Evaluation of sample `%s` with strategy `%s`: %s", + sample.name(), + strategy.getClass().getSimpleName(), + evaluation.passed() ? "OK" : "KO"); + evaluations.add(new OrderedEvaluationResult(currentIndex, evaluation)); + } + } catch (Throwable e) { + LOG.errorf(e, "Failed to evaluate sample `%s`", sample.name()); + evaluations.add( + new OrderedEvaluationResult( + currentIndex, EvaluationResult.fromEvaluationThrowable(sample, e))); + } finally { + latch.countDown(); + } + }); } try { latch.await(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } - return new EvaluationReport(evaluations); + var orderedEvalutions = evaluations.stream() + .sorted(Comparator.comparing(OrderedEvaluationResult::index)) + .map(OrderedEvaluationResult::evaluation) + .toList(); + return new EvaluationReport<>(orderedEvalutions); } public void close() { executor.shutdown(); } - public record EvaluationResult(EvaluationSample sample, T result, Throwable thrown, boolean passed) { - public static EvaluationResult fromCompletedEvaluation(EvaluationSample sample, T result, boolean passed) { + public record EvaluationResult( + EvaluationSample sample, T result, Throwable thrown, boolean passed) { + public static EvaluationResult fromCompletedEvaluation( + EvaluationSample sample, T result, boolean passed) { return new EvaluationResult<>(sample, result, null, passed); } - public static EvaluationResult fromEvaluationThrowable(EvaluationSample sample, Throwable thrown) { + public static EvaluationResult fromEvaluationThrowable( + EvaluationSample sample, Throwable thrown) { return new EvaluationResult<>(sample, null, thrown, false); } } @@ -84,4 +99,6 @@ private T execute(EvaluationSample sample, Function functi } } + private record OrderedEvaluationResult(int index, EvaluationResult evaluation) { + } } diff --git a/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java index 02dd615d2..279e20887 100644 --- a/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java +++ b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java @@ -4,6 +4,7 @@ import java.util.List; import java.util.function.Function; +import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; @@ -40,20 +41,55 @@ void evaluateShouldReturnCorrectReport() { EvaluationStrategy strategy = (sample, actual) -> actual.equals(sample.expectedOutput()); Samples samples = new Samples<>(sample1, sample2); - EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy); + EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy); assertThat(report).isNotNull(); assertThat(report.score()).isEqualTo(50.0); // Only one sample should pass. assertThat(report.evaluations()).hasSize(2); var actualEvaluations = report.evaluations().stream() - .map(e -> "%s[%s;%s=%s]".formatted(e.sample().name(), e.sample().expectedOutput(), e.result(), e.passed())) + .map( + e -> "%s[%s;%s=%s]" + .formatted( + e.sample().name(), e.sample().expectedOutput(), e.result(), e.passed())) .toList(); - assertThat(actualEvaluations).containsExactlyInAnyOrder( - "Sample1[expected1:param1;expected1:param1=true]", - "Sample2[expected2;expected1:param1=false]"); + assertThat(actualEvaluations) + .containsExactly( + "Sample1[expected1:param1;expected1:param1=true]", + "Sample2[expected2;expected1:param1=false]"); } + @SuppressWarnings("unchecked") + @Test + void evaluateShouldReturnCorrectlyOrderedReport() { + scorer = new Scorer(2); + var sleeps = Stream.of(25l, 0l); + var samples = new Samples<>( + sleeps + .map( + sleep -> new EvaluationSample<>( + "%s".formatted(sleep), + new Parameters().add(new Parameter.UnnamedParameter(sleep)), + "irrelevant-for-this-test", + List.of())) + .toList()); + + var actual = scorer.evaluate(samples, this::sleep, (sample, actualOutput) -> true); + + var actualOrder = actual.evaluations().stream().map(e -> e.sample().name()).toList(); + assertThat(actualOrder).containsExactly("25", "0"); + } + + private String sleep(Parameters params) { + long ms = params.get(0); + try { + Thread.sleep(ms); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return "sleeped %s".formatted(ms); + }; + @Test @SuppressWarnings("unchecked") void evaluateShouldHandleExceptionsInFunction() { @@ -71,7 +107,7 @@ void evaluateShouldHandleExceptionsInFunction() { EvaluationStrategy strategy = (s, actual) -> false; Samples samples = new Samples<>(sample); - EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy); + EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy); assertThat(report).isNotNull(); assertThat(report.score()).isEqualTo(0.0); // All evaluations should fail. @@ -96,7 +132,7 @@ void evaluateShouldHandleMultipleStrategies() { EvaluationStrategy strategy2 = (s, actual) -> actual.length() > 3; Samples samples = new Samples<>(sample); - EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy1, strategy2); + EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy1, strategy2); assertThat(report).isNotNull(); assertThat(report.score()).isEqualTo(100.0); // Both strategies should pass for the sample.