From 4c93bafa3ebea716a1bdc3a2cd078258d5cc3ecd Mon Sep 17 00:00:00 2001
From: Diego Ramp <diego.ramp@mobi.ch>
Date: Fri, 10 Jan 2025 10:55:29 +0100
Subject: [PATCH] Provide Evaluations in same order as they were submitted

---
 .../testing/scorer/EvaluationReport.java      |  8 +--
 .../langchain4j/testing/scorer/Scorer.java    | 69 ++++++++++++-------
 .../testing/scorer/ScorerTest.java            | 50 ++++++++++++--
 3 files changed, 90 insertions(+), 37 deletions(-)

diff --git a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReport.java b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReport.java
index 10454fedc..b3ace2f30 100644
--- a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReport.java
+++ b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReport.java
@@ -8,9 +8,9 @@
 /**
  * Report of the evaluation of a set of samples.
  */
-public class EvaluationReport {
+public class EvaluationReport<T> {
 
-    private final List<Scorer.EvaluationResult<?>> evaluations;
+    private final List<Scorer.EvaluationResult<T>> evaluations;
     private final double score;
 
     /**
@@ -18,7 +18,7 @@ public class EvaluationReport {
      *
      * @param evaluations the evaluations, must not be {@code null}, must not be empty.
      */
-    public EvaluationReport(List<Scorer.EvaluationResult<?>> evaluations) {
+    public EvaluationReport(List<Scorer.EvaluationResult<T>> evaluations) {
         this.evaluations = evaluations;
         this.score = 100.0 * evaluations.stream().filter(Scorer.EvaluationResult::passed).count() / evaluations.size();
     }
@@ -33,7 +33,7 @@ public double score() {
     /**
      * @return the evaluations
      */
-    public List<Scorer.EvaluationResult<?>> evaluations() {
+    public List<Scorer.EvaluationResult<T>> evaluations() {
         return evaluations;
     }
 
diff --git a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java
index 2c6195abe..4bbf4517b 100644
--- a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java
+++ b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java
@@ -1,6 +1,7 @@
 package io.quarkiverse.langchain4j.testing.scorer;
 
 import java.io.Closeable;
+import java.util.Comparator;
 import java.util.List;
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.CountDownLatch;
@@ -28,50 +29,64 @@ public Scorer() {
     }
 
     @SuppressWarnings({ "unchecked" })
-    public <T> EvaluationReport evaluate(Samples<T> samples, Function<Parameters, T> function,
-            EvaluationStrategy<T>... strategies) {
-        List<EvaluationResult<?>> evaluations = new CopyOnWriteArrayList<>();
+    public <T> EvaluationReport<T> evaluate(
+            Samples<T> samples, Function<Parameters, T> function, EvaluationStrategy<T>... strategies) {
+        List<OrderedEvaluationResult<T>> evaluations = new CopyOnWriteArrayList<>();
         CountDownLatch latch = new CountDownLatch(samples.size());
+        var index = 0; // position of the next sample in submission order
         for (EvaluationSample<T> sample : samples) {
             // TODO Should we handle the context somehow.
-            executor.submit(() -> {
-                try {
-                    var response = execute(sample, function);
-                    LOG.infof("Evaluating sample `%s`", sample.name());
-                    for (EvaluationStrategy<T> strategy : strategies) {
-                        EvaluationResult<T> evaluation = EvaluationResult.fromCompletedEvaluation(sample,
-                                response, strategy.evaluate(sample, response));
-                        LOG.infof("Evaluation of sample `%s` with strategy `%s`: %s", sample.name(),
-                                strategy.getClass().getSimpleName(),
-                                evaluation.passed() ? "OK" : "KO");
-                        evaluations.add(evaluation);
-                    }
-                } catch (Throwable e) {
-                    LOG.errorf(e, "Failed to evaluate sample `%s`", sample.name());
-                    evaluations.add(EvaluationResult.fromEvaluationThrowable(sample, e));
-                } finally {
-                    latch.countDown();
-                }
-            });
+            var currentIndex = index++; // effectively-final copy captured by the task below
+            executor.submit(
+                    () -> {
+                        try {
+                            var response = execute(sample, function);
+                            LOG.infof("Evaluating sample `%s`", sample.name());
+                            for (EvaluationStrategy<T> strategy : strategies) {
+                                EvaluationResult<T> evaluation = EvaluationResult.fromCompletedEvaluation(
+                                        sample, response, strategy.evaluate(sample, response));
+                                LOG.infof(
+                                        "Evaluation of sample `%s` with strategy `%s`: %s",
+                                        sample.name(),
+                                        strategy.getClass().getSimpleName(),
+                                        evaluation.passed() ? "OK" : "KO");
+                                evaluations.add(new OrderedEvaluationResult<>(currentIndex, evaluation));
+                            }
+                        } catch (Throwable e) {
+                            LOG.errorf(e, "Failed to evaluate sample `%s`", sample.name());
+                            evaluations.add(
+                                    new OrderedEvaluationResult<>(
+                                            currentIndex, EvaluationResult.fromEvaluationThrowable(sample, e)));
+                        } finally {
+                            latch.countDown();
+                        }
+                    });
         }
         try {
             latch.await();
         } catch (InterruptedException e) {
             Thread.currentThread().interrupt();
         }
-        return new EvaluationReport(evaluations);
+        var orderedEvaluations = evaluations.stream()
+                .sorted(Comparator.comparingInt(OrderedEvaluationResult::index)) // stable sort
+                .map(OrderedEvaluationResult::evaluation)
+                .toList();
+        return new EvaluationReport<>(orderedEvaluations);
     }
 
     public void close() {
         executor.shutdown();
     }
 
-    public record EvaluationResult<T>(EvaluationSample<T> sample, T result, Throwable thrown, boolean passed) {
-        public static <T> EvaluationResult<T> fromCompletedEvaluation(EvaluationSample<T> sample, T result, boolean passed) {
+    public record EvaluationResult<T>(
+            EvaluationSample<T> sample, T result, Throwable thrown, boolean passed) {
+        public static <T> EvaluationResult<T> fromCompletedEvaluation(
+                EvaluationSample<T> sample, T result, boolean passed) {
             return new EvaluationResult<>(sample, result, null, passed);
         }
 
-        public static <T> EvaluationResult<T> fromEvaluationThrowable(EvaluationSample<T> sample, Throwable thrown) {
+        public static <T> EvaluationResult<T> fromEvaluationThrowable(
+                EvaluationSample<T> sample, Throwable thrown) {
             return new EvaluationResult<>(sample, null, thrown, false);
         }
     }
@@ -84,4 +99,6 @@ private <T> T execute(EvaluationSample<T> sample, Function<Parameters, T> functi
         }
     }
 
+    private record OrderedEvaluationResult<T>(int index, EvaluationResult<T> evaluation) {
+    }
 }
diff --git a/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java
index 02dd615d2..279e20887 100644
--- a/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java
+++ b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java
@@ -4,6 +4,7 @@
 
 import java.util.List;
 import java.util.function.Function;
+import java.util.stream.Stream;
 
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.Test;
@@ -40,20 +41,55 @@ void evaluateShouldReturnCorrectReport() {
         EvaluationStrategy<String> strategy = (sample, actual) -> actual.equals(sample.expectedOutput());
 
         Samples<String> samples = new Samples<>(sample1, sample2);
-        EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy);
+        EvaluationReport<String> report = scorer.evaluate(samples, mockFunction, strategy);
 
         assertThat(report).isNotNull();
         assertThat(report.score()).isEqualTo(50.0); // Only one sample should pass.
         assertThat(report.evaluations()).hasSize(2);
 
         var actualEvaluations = report.evaluations().stream()
-                .map(e -> "%s[%s;%s=%s]".formatted(e.sample().name(), e.sample().expectedOutput(), e.result(), e.passed()))
+                .map(
+                        e -> "%s[%s;%s=%s]"
+                                .formatted(
+                                        e.sample().name(), e.sample().expectedOutput(), e.result(), e.passed()))
                 .toList();
-        assertThat(actualEvaluations).containsExactlyInAnyOrder(
-                "Sample1[expected1:param1;expected1:param1=true]",
-                "Sample2[expected2;expected1:param1=false]");
+        assertThat(actualEvaluations)
+                .containsExactly(
+                        "Sample1[expected1:param1;expected1:param1=true]",
+                        "Sample2[expected2;expected1:param1=false]");
     }
 
+    @Test
+    @SuppressWarnings("unchecked")
+    void evaluateShouldReturnCorrectlyOrderedReport() {
+        scorer = new Scorer(2);
+        var sleeps = Stream.of(25L, 0L); // slow sample first, so it likely completes last
+        var samples = new Samples<>(
+                sleeps
+                        .map(
+                                sleep -> new EvaluationSample<>(
+                                        "%s".formatted(sleep),
+                                        new Parameters().add(new Parameter.UnnamedParameter(sleep)),
+                                        "irrelevant-for-this-test",
+                                        List.of()))
+                        .toList());
+
+        var actual = scorer.evaluate(samples, this::sleep, (sample, actualOutput) -> true);
+
+        var actualOrder = actual.evaluations().stream().map(e -> e.sample().name()).toList();
+        assertThat(actualOrder).containsExactly("25", "0"); // submission order, not completion order
+    }
+
+    private String sleep(Parameters params) { // evaluated function: sleeps, then reports the duration
+        long ms = params.get(0);
+        try {
+            Thread.sleep(ms);
+        } catch (InterruptedException e) {
+            throw new RuntimeException(e);
+        }
+        return "sleeped %s".formatted(ms);
+    }
+
     @Test
     @SuppressWarnings("unchecked")
     void evaluateShouldHandleExceptionsInFunction() {
@@ -71,7 +107,7 @@ void evaluateShouldHandleExceptionsInFunction() {
         EvaluationStrategy<String> strategy = (s, actual) -> false;
 
         Samples<String> samples = new Samples<>(sample);
-        EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy);
+        EvaluationReport<String> report = scorer.evaluate(samples, mockFunction, strategy);
 
         assertThat(report).isNotNull();
         assertThat(report.score()).isEqualTo(0.0); // All evaluations should fail.
@@ -96,7 +132,7 @@ void evaluateShouldHandleMultipleStrategies() {
         EvaluationStrategy<String> strategy2 = (s, actual) -> actual.length() > 3;
 
         Samples<String> samples = new Samples<>(sample);
-        EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy1, strategy2);
+        EvaluationReport<String> report = scorer.evaluate(samples, mockFunction, strategy1, strategy2);
 
         assertThat(report).isNotNull();
         assertThat(report.score()).isEqualTo(100.0); // Both strategies should pass for the sample.