diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java index dbec8d9d8b..7d82a5d787 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java @@ -292,7 +292,8 @@ public Response getProjectStats( log.info("Find projects stats by '{}' on workspaceId '{}'", criteria, workspaceId); ProjectStatsSummary projectStatisticsSummary = projectService.getStats(page, size, criteria, sortingFields); - log.info("Found projects stats by '{}', count '{}' on workspaceId '{}'", criteria, projectStatisticsSummary.content().size(), workspaceId); + log.info("Found projects stats by '{}', count '{}' on workspaceId '{}'", criteria, + projectStatisticsSummary.content().size(), workspaceId); return Response.ok().entity(projectStatisticsSummary).build(); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java index d2adc5be8d..d5f68cd262 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java @@ -85,7 +85,8 @@ public interface ProjectService { void recordLastUpdatedTrace(String workspaceId, Collection lastUpdatedTraces); - ProjectStatsSummary getStats(int page, int size, @NonNull ProjectCriteria criteria, @NonNull List sortingFields); + ProjectStatsSummary getStats(int page, int size, @NonNull ProjectCriteria criteria, + @NonNull List sortingFields); } @Slf4j @@ -487,7 +488,8 @@ public Project retrieveByName(@NonNull String projectName) { .feedbackScores(StatsMapper.getStatsFeedbackScores(projectStats.get(project.id()))) .usage(StatsMapper.getStatsUsage(projectStats.get(project.id()))) .duration(StatsMapper.getStatsDuration(projectStats.get(project.id()))) - .totalEstimatedCost(StatsMapper.getStatsTotalEstimatedCost(projectStats.get(project.id()))) + .totalEstimatedCost( + StatsMapper.getStatsTotalEstimatedCost(projectStats.get(project.id()))) .build(); }) .orElseThrow(this::createNotFoundError); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java index c94a4062ba..9b4988e6d5 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java @@ -582,6 +582,8 @@ void getProjectById__whenProjectExists__thenReturnProject() { var id = createProject(project, apiKey, workspaceName); + project = buildProjectStats(project.toBuilder().id(id).build(), apiKey, workspaceName); + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) .path("retrieve") .request() @@ -589,16 +591,18 @@ void getProjectById__whenProjectExists__thenReturnProject() { .header(WORKSPACE_HEADER, workspaceName) .post(Entity.json(ProjectRetrieve.builder().name(project.name()).build()))) { - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_OK); assertThat(actualResponse.hasEntity()).isTrue(); var actualEntity = actualResponse.readEntity(Project.class); assertThat(actualEntity) .usingRecursiveComparison() - .ignoringFields(IGNORED_FIELDS) - .isEqualTo(project.toBuilder() - .id(id) - .build()); + .ignoringFields("createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt", + "lastUpdatedTraceAt") + .ignoringCollectionOrder() + .withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class) + .withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost") + .isEqualTo(project); } } @@ -1153,7 +1157,8 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsage__thenReturnPr Comparator comparator = Comparator.comparing(Project::id).reversed(); - List expectedProjectStats = getProjectStatsSummaryItems(apiKey, workspaceName, comparator); + List expectedProjectStats = getProjectStatsSummaryItems(apiKey, workspaceName, + comparator); var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) .path("/stats") @@ -1162,13 +1167,14 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsage__thenReturnPr .header(WORKSPACE_HEADER, workspaceName) .get(); - var actualEntity = actualResponse.readEntity(Project.ProjectPage.class); + var actualEntity = actualResponse.readEntity(ProjectStatsSummary.class); assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(org.apache.http.HttpStatus.SC_OK); assertThat(expectedProjectStats).hasSameSizeAs(actualEntity.content()); assertThat(actualEntity.content()) .usingRecursiveComparison() + .ignoringCollectionOrder() .withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class) .withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost") .isEqualTo(expectedProjectStats); @@ -1185,7 +1191,8 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsageSortedLastTrac Comparator comparator = Comparator.comparing(Project::lastUpdatedTraceAt).reversed(); - List expectedProjectStats = getProjectStatsSummaryItems(apiKey, workspaceName, comparator); + List expectedProjectStats = getProjectStatsSummaryItems(apiKey, workspaceName, + comparator); var sorting = List.of(SortingField.builder() .field(SortableFields.LAST_UPDATED_TRACE_AT) @@ -1209,12 +1216,14 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsageSortedLastTrac assertThat(actualEntity.content()) .usingRecursiveComparison() + .ignoringCollectionOrder() .withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class) .withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost") .isEqualTo(expectedProjectStats); } - private List getProjectStatsSummaryItems(String apiKey, String workspaceName, Comparator comparing) { + private List getProjectStatsSummaryItems(String apiKey, String workspaceName, + Comparator comparing) { var projects = PodamFactoryUtils.manufacturePojoList(factory, Project.class) .parallelStream() .map(project -> project.toBuilder() @@ -1266,6 +1275,7 @@ void getProjects__whenProjectsHasNoTracesSpansFeedbackScoresAndUsage__thenReturn .totalEstimatedCost(null) .usage(null) .feedbackScores(null) + .projectId(project.id()) .build()) .sorted(Comparator.comparing(ProjectStatsSummaryItem::projectId).reversed()) .toList(); @@ -1290,107 +1300,6 @@ void getProjects__whenProjectsHasNoTracesSpansFeedbackScoresAndUsage__thenReturn .isEqualTo(expectedProjectStats); } - private Project buildProjectStats(Project project, String apiKey, String workspaceName) { - var traces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream() - .map(trace -> { - Instant startTime = Instant.now(); - Instant endTime = startTime.plusMillis(PodamUtils.getIntegerInRange(1, 1000)); - return trace.toBuilder() - .projectName(project.name()) - .startTime(startTime) - .endTime(endTime) - .duration(DurationUtils.getDurationInMillisWithSubMilliPrecision(startTime, endTime)) - .build(); - }) - .toList(); - - traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); - - List scores = PodamFactoryUtils.manufacturePojoList(factory, - FeedbackScoreBatchItem.class); - - traces = traces.stream().map(trace -> { - List spans = PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() - .map(span -> span.toBuilder() - .usage(spanResourceClient.getTokenUsage()) - .model(spanResourceClient.randomModel().toString()) - .traceId(trace.id()) - .projectName(trace.projectName()) - .totalEstimatedCost(null) - .build()) - .toList(); - - spanResourceClient.batchCreateSpans(spans, apiKey, workspaceName); - - List feedbackScores = scores.stream() - .map(feedbackScore -> feedbackScore.toBuilder() - .projectId(project.id()) - .projectName(project.name()) - .id(trace.id()) - .build()) - .toList(); - - traceResourceClient.feedbackScores(feedbackScores, apiKey, workspaceName); - - return trace.toBuilder() - .feedbackScores( - feedbackScores.stream() - .map(score -> FeedbackScore.builder() - .value(score.value()) - .name(score.name()) - .build()) - .toList()) - .usage(StatsUtils.aggregateSpansUsage(spans)) - .totalEstimatedCost(StatsUtils.aggregateSpansCost(spans)) - .build(); - }).toList(); - - List durations = StatsUtils.calculateQuantiles( - traces.stream() - .map(Trace::duration) - .toList(), - List.of(0.5, 0.90, 0.99)); - - return project.toBuilder() - .duration(new ProjectStats.PercentageValues(durations.get(0), durations.get(1), durations.get(2))) - .totalEstimatedCost(getTotalEstimatedCost(traces)) - .usage(traces.stream() - .map(Trace::usage) - .flatMap(usage -> usage.entrySet().stream()) - .collect(groupingBy(Map.Entry::getKey, averagingDouble(Map.Entry::getValue)))) - .feedbackScores(getScoreAverages(traces)) - .lastUpdatedTraceAt(traces.stream().map(Trace::lastUpdatedAt).max(Instant::compareTo).orElse(null)) - .build(); - } - - private List getScoreAverages(List traces) { - return traces.stream() - .map(Trace::feedbackScores) - .flatMap(List::stream) - .collect(groupingBy(FeedbackScore::name, - BigDecimalCollectors.averagingBigDecimal(FeedbackScore::value))) - .entrySet() - .stream() - .map(entry -> FeedbackScoreAverage.builder() - .name(entry.getKey()) - .value(entry.getValue()) - .build()) - .toList(); - } - - private double getTotalEstimatedCost(List traces) { - long count = traces.stream() - .map(Trace::totalEstimatedCost) - .filter(Objects::nonNull) - .filter(cost -> cost.compareTo(BigDecimal.ZERO) > 0) - .count(); - - return traces.stream() - .map(Trace::totalEstimatedCost) - .reduce(BigDecimal.ZERO, BigDecimal::add) - .divide(BigDecimal.valueOf(count), ValidationUtils.SCALE, RoundingMode.HALF_UP).doubleValue(); - } - @Test @DisplayName("when projects is with traces created in batch, then return project with last updated trace at") void getProjects__whenProjectsHasTracesBatch__thenReturnProjectWithLastUpdatedTraceAt() { @@ -1507,6 +1416,107 @@ private void assertAllProjectsHavePersistedLastTraceAt(String workspaceId, List< } } + private Project buildProjectStats(Project project, String apiKey, String workspaceName) { + var traces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream() + .map(trace -> { + Instant startTime = Instant.now(); + Instant endTime = startTime.plusMillis(PodamUtils.getIntegerInRange(1, 1000)); + return trace.toBuilder() + .projectName(project.name()) + .startTime(startTime) + .endTime(endTime) + .duration(DurationUtils.getDurationInMillisWithSubMilliPrecision(startTime, endTime)) + .build(); + }) + .toList(); + + traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); + + List scores = PodamFactoryUtils.manufacturePojoList(factory, + FeedbackScoreBatchItem.class); + + traces = traces.stream().map(trace -> { + List spans = PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() + .map(span -> span.toBuilder() + .usage(spanResourceClient.getTokenUsage()) + .model(spanResourceClient.randomModel().toString()) + .traceId(trace.id()) + .projectName(trace.projectName()) + .totalEstimatedCost(null) + .build()) + .toList(); + + spanResourceClient.batchCreateSpans(spans, apiKey, workspaceName); + + List feedbackScores = scores.stream() + .map(feedbackScore -> feedbackScore.toBuilder() + .projectId(project.id()) + .projectName(project.name()) + .id(trace.id()) + .build()) + .toList(); + + traceResourceClient.feedbackScores(feedbackScores, apiKey, workspaceName); + + return trace.toBuilder() + .feedbackScores( + feedbackScores.stream() + .map(score -> FeedbackScore.builder() + .value(score.value()) + .name(score.name()) + .build()) + .toList()) + .usage(StatsUtils.aggregateSpansUsage(spans)) + .totalEstimatedCost(StatsUtils.aggregateSpansCost(spans)) + .build(); + }).toList(); + + List durations = StatsUtils.calculateQuantiles( + traces.stream() + .map(Trace::duration) + .toList(), + List.of(0.5, 0.90, 0.99)); + + return project.toBuilder() + .duration(new ProjectStats.PercentageValues(durations.get(0), durations.get(1), durations.get(2))) + .totalEstimatedCost(getTotalEstimatedCost(traces)) + .usage(traces.stream() + .map(Trace::usage) + .flatMap(usage -> usage.entrySet().stream()) + .collect(groupingBy(Map.Entry::getKey, averagingDouble(Map.Entry::getValue)))) + .feedbackScores(getScoreAverages(traces)) + .lastUpdatedTraceAt(traces.stream().map(Trace::lastUpdatedAt).max(Instant::compareTo).orElse(null)) + .build(); + } + + private List getScoreAverages(List traces) { + return traces.stream() + .map(Trace::feedbackScores) + .flatMap(List::stream) + .collect(groupingBy(FeedbackScore::name, + BigDecimalCollectors.averagingBigDecimal(FeedbackScore::value))) + .entrySet() + .stream() + .map(entry -> FeedbackScoreAverage.builder() + .name(entry.getKey()) + .value(entry.getValue()) + .build()) + .toList(); + } + + private double getTotalEstimatedCost(List traces) { + long count = traces.stream() + .map(Trace::totalEstimatedCost) + .filter(Objects::nonNull) + .filter(cost -> cost.compareTo(BigDecimal.ZERO) > 0) + .count(); + + return traces.stream() + .map(Trace::totalEstimatedCost) + .reduce(BigDecimal.ZERO, BigDecimal::add) + .divide(BigDecimal.valueOf(count), ValidationUtils.SCALE, RoundingMode.HALF_UP).doubleValue(); + } + @Nested @DisplayName("Get: {id}") @TestInstance(TestInstance.Lifecycle.PER_CLASS)