From 535cf342a74ff27dee284aef4a385496484f8587 Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Thu, 30 Jan 2025 11:52:09 +0100 Subject: [PATCH 1/2] OPIK-904: Split find project endpoint into two endpoints find and stats --- .../main/java/com/comet/opik/api/Project.java | 11 ++- .../comet/opik/api/ProjectStatsSummary.java | 28 ++++++ .../resources/v1/priv/ProjectsResource.java | 33 ++++++- .../com/comet/opik/domain/ProjectService.java | 76 ++++++++++---- .../java/com/comet/opik/domain/SpanDAO.java | 10 +- .../java/com/comet/opik/domain/TraceDAO.java | 6 +- .../v1/priv/ProjectsResourceTest.java | 98 ++++++++++--------- 7 files changed, 183 insertions(+), 79 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/ProjectStatsSummary.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java index 8e39e14a98..b777c729dd 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java @@ -34,13 +34,13 @@ public record Project( @JsonView({ Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Instant lastUpdatedTraceAt, @JsonView({ - Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable List feedbackScores, + Project.View.Detailed.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable List feedbackScores, @JsonView({ - Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable PercentageValues duration, + Project.View.Detailed.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable PercentageValues duration, @JsonView({ - Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Double totalEstimatedCost, + Project.View.Detailed.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Double totalEstimatedCost, @JsonView({ - Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Map usage){ + Project.View.Detailed.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Map usage){ public static class View { public static class Write { @@ -48,6 +48,9 @@ public static class Write { public static class Public { } + + public static class Detailed extends Public { + } } public record ProjectPage( diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/ProjectStatsSummary.java b/apps/opik-backend/src/main/java/com/comet/opik/api/ProjectStatsSummary.java new file mode 100644 index 0000000000..312bf044f1 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/ProjectStatsSummary.java @@ -0,0 +1,28 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Builder; + +import java.util.List; +import java.util.Map; +import java.util.UUID; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record ProjectStatsSummary(List content) { + + @Builder(toBuilder = true) + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) + public record ProjectStatsSummaryItem( + UUID projectId, + List feedbackScores, + ProjectStats.PercentageValues duration, + Double totalEstimatedCost, + Map usage) { + } + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java index d319cdd0c9..dbec8d9d8b 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java @@ -7,6 +7,7 @@ import com.comet.opik.api.Project; import com.comet.opik.api.ProjectCriteria; import com.comet.opik.api.ProjectRetrieve; +import com.comet.opik.api.ProjectStatsSummary; import com.comet.opik.api.ProjectUpdate; import com.comet.opik.api.error.ErrorMessage; import com.comet.opik.api.metrics.ProjectMetricRequest; @@ -68,6 +69,7 @@ @Tag(name = "Projects", description = "Project related resources") public class ProjectsResource { + private static final String PAGE_SIZE = "10"; private final @NonNull ProjectService projectService; private final @NonNull Provider requestContext; private final @NonNull SortingFactoryProjects sortingFactory; @@ -81,7 +83,7 @@ public class ProjectsResource { @JsonView({Project.View.Public.class}) public Response find( @QueryParam("page") @Min(1) @DefaultValue("1") int page, - @QueryParam("size") @Min(1) @DefaultValue("10") int size, + @QueryParam("size") @Min(1) @DefaultValue(PAGE_SIZE) int size, @QueryParam("name") String name, @QueryParam("sorting") String sorting) { @@ -188,7 +190,7 @@ public Response deleteById(@PathParam("id") UUID id) { @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = ErrorMessage.class))) }) - @JsonView({Project.View.Public.class}) + @JsonView({Project.View.Detailed.class}) public Response retrieveProject( @RequestBody(content = @Content(schema = @Schema(implementation = ProjectRetrieve.class))) @Valid ProjectRetrieve retrieve) { String workspaceId = requestContext.get().getWorkspaceId(); @@ -268,4 +270,31 @@ private void validate(ProjectMetricRequest request) { throw new BadRequestException(ERR_START_BEFORE_END); } } + + @GET + @Path("/stats") + @Operation(operationId = "getProjectStats", summary = "Get Project Stats", description = "Get Project Stats", responses = { + @ApiResponse(responseCode = "200", description = "Project Stats", content = @Content(schema = @Schema(implementation = ProjectStatsSummary.class))), + }) + public Response getProjectStats( + @QueryParam("page") @Min(1) @DefaultValue("1") int page, + @QueryParam("size") @Min(1) @DefaultValue(PAGE_SIZE) int size, + @QueryParam("name") String name, + @QueryParam("sorting") String sorting) { + + var criteria = ProjectCriteria.builder() + .projectName(name) + .build(); + + String workspaceId = requestContext.get().getWorkspaceId(); + + List sortingFields = sortingFactory.newSorting(sorting); + + log.info("Find projects stats by '{}' on workspaceId '{}'", criteria, workspaceId); + ProjectStatsSummary projectStatisticsSummary = projectService.getStats(page, size, criteria, sortingFields); + log.info("Found projects stats by '{}', count '{}' on workspaceId '{}'", criteria, projectStatisticsSummary.content().size(), workspaceId); + + return Response.ok().entity(projectStatisticsSummary).build(); + } + } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java index 41c82ac613..d2adc5be8d 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java @@ -5,6 +5,7 @@ import com.comet.opik.api.Project.ProjectPage; import com.comet.opik.api.ProjectCriteria; import com.comet.opik.api.ProjectIdLastUpdated; +import com.comet.opik.api.ProjectStatsSummary; import com.comet.opik.api.ProjectUpdate; import com.comet.opik.api.error.EntityAlreadyExistsException; import com.comet.opik.api.error.ErrorMessage; @@ -44,6 +45,7 @@ import java.util.stream.Collectors; import static com.comet.opik.api.ProjectStats.ProjectStatItem; +import static com.comet.opik.api.ProjectStatsSummary.ProjectStatsSummaryItem; import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.READ_ONLY; import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE; import static java.util.Collections.reverseOrder; @@ -82,6 +84,8 @@ public interface ProjectService { Project retrieveByName(String projectName); void recordLastUpdatedTrace(String workspaceId, Collection lastUpdatedTraces); + + ProjectStatsSummary getStats(int page, int size, @NonNull ProjectCriteria criteria, @NonNull List sortingFields); } @Slf4j @@ -203,14 +207,36 @@ public Project get(@NonNull UUID id, @NonNull String workspaceId) { .nonTransaction(connection -> traceDAO.getLastUpdatedTraceAt(Set.of(id), workspaceId, connection)) .block(); - Map> projectStats = getProjectStats(List.of(id), workspaceId); + return project.toBuilder() + .lastUpdatedTraceAt(lastUpdatedTraceAt.get(project.id())) + .build(); + } + + @Override + public ProjectStatsSummary getStats(int page, int size, @NonNull ProjectCriteria criteria, + @NonNull List sortingFields) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + List projectIds = find(page, size, criteria, sortingFields) + .content() + .stream() + .map(Project::id) + .toList(); + + Map> projectStats = getProjectStats(projectIds, workspaceId); - return enhanceProject(project, lastUpdatedTraceAt.get(project.id()), projectStats.get(project.id())); + return ProjectStatsSummary.builder() + .content( + projectIds.stream() + .map(projectId -> getStats(projectId, projectStats.get(projectId))) + .toList()) + .build(); } - private Project enhanceProject(Project project, Instant lastUpdatedTraceAt, Map projectStats) { - return project.toBuilder() - .lastUpdatedTraceAt(lastUpdatedTraceAt) + private ProjectStatsSummaryItem getStats(UUID projectId, Map projectStats) { + return ProjectStatsSummaryItem.builder() + .projectId(projectId) .feedbackScores(StatsMapper.getStatsFeedbackScores(projectStats)) .duration(StatsMapper.getStatsDuration(projectStats)) .totalEstimatedCost(StatsMapper.getStatsTotalEstimatedCost(projectStats)) @@ -285,14 +311,14 @@ public Page find(int page, int size, @NonNull ProjectCriteria criteria, return traceDAO.getLastUpdatedTraceAt(projectIds, workspaceId, connection); }).block(); - List projectIds = projectRecordSet.content.stream().map(Project::id).toList(); - - Map> projectStats = getProjectStats(projectIds, workspaceId); - List projects = projectRecordSet.content() .stream() - .map(project -> enhanceProject(project, projectLastUpdatedTraceAtMap.get(project.id()), - projectStats.get(project.id()))) + .map(project -> { + Instant lastUpdatedTraceAt = projectLastUpdatedTraceAtMap.get(project.id()); + return project.toBuilder() + .lastUpdatedTraceAt(lastUpdatedTraceAt) + .build(); + }) .toList(); return new ProjectPage(page, projects.size(), projectRecordSet.total(), projects, @@ -303,8 +329,7 @@ private Map> getProjectStats(List projectIds, St return traceDAO.getStatsByProjectIds(projectIds, workspaceId) .map(stats -> stats.entrySet().stream() .map(entry -> { - Map statsMap = entry.getValue() - .stats() + Map statsMap = entry.getValue().stats() .stream() .collect(toMap(ProjectStatItem::getName, ProjectStatItem::getValue)); @@ -342,9 +367,11 @@ private Page findWithLastTraceSorting(int page, int size, @NonNull Proj // get last trace for each project id Set allProjectIds = allProjectIdsLastUpdated.stream().map(ProjectIdLastUpdated::id) .collect(toUnmodifiableSet()); + Map projectLastUpdatedTraceAtMap = transactionTemplateAsync .nonTransaction(connection -> traceDAO.getLastUpdatedTraceAt(allProjectIds, workspaceId, connection)) .block(); + if (projectLastUpdatedTraceAtMap == null) { return ProjectPage.empty(page); } @@ -365,12 +392,12 @@ private Page findWithLastTraceSorting(int page, int size, @NonNull Proj return repository.findByIds(new HashSet<>(finalIds), workspaceId); }).stream().collect(Collectors.toMap(Project::id, Function.identity())); - Map> projectStats = getProjectStats(finalIds, workspaceId); - // compose the final projects list by the correct order and add last trace to it - List projects = finalIds.stream().map(projectsById::get) - .map(project -> enhanceProject(project, projectLastUpdatedTraceAtMap.get(project.id()), - projectStats.get(project.id()))) + List projects = finalIds.stream() + .map(projectsById::get) + .map(project -> project.toBuilder() + .lastUpdatedTraceAt(projectLastUpdatedTraceAtMap.get(project.id())) + .build()) .toList(); return new ProjectPage(page, projects.size(), allProjectIdsLastUpdated.size(), projects, @@ -389,7 +416,9 @@ private List sortByLastTrace( ? reverseOrder(Map.Entry.comparingByValue()) : Map.Entry.comparingByValue(); - return projectLastUpdatedTraceAtMap.entrySet().stream().sorted(comparator) + return projectLastUpdatedTraceAtMap.entrySet() + .stream() + .sorted(comparator) .map(Map.Entry::getKey) .toList(); } @@ -453,8 +482,13 @@ public Project retrieveByName(@NonNull String projectName) { Map> projectStats = getProjectStats(List.of(project.id()), workspaceId); - return enhanceProject(project, projectLastUpdatedTraceAtMap.get(project.id()), - projectStats.get(project.id())); + return project.toBuilder() + .lastUpdatedTraceAt(projectLastUpdatedTraceAtMap.get(project.id())) + .feedbackScores(StatsMapper.getStatsFeedbackScores(projectStats.get(project.id()))) + .usage(StatsMapper.getStatsUsage(projectStats.get(project.id()))) + .duration(StatsMapper.getStatsDuration(projectStats.get(project.id()))) + .totalEstimatedCost(StatsMapper.getStatsTotalEstimatedCost(projectStats.get(project.id()))) + .build(); }) .orElseThrow(this::createNotFoundError); }); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java index 893faea911..13d3d52ba7 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java @@ -601,9 +601,15 @@ AND notEquals(s.start_time, toDateTime64('1970-01-01 00:00:00.000', 9)), (dateDiff('microsecond', s.start_time, s.end_time) / 1000.0), NULL) AS duration_millis, groupArray(tuple(c.*)) AS comments - FROM spans s + FROM ( + SELECT + *, + row_number() OVER (PARTITION BY id ORDER BY last_updated_at DESC) AS latest\s + FROM spans + WHERE id IN (SELECT id FROM span_ids) + ) AS s LEFT JOIN comments_final AS c ON s.id = c.entity_id - WHERE s.id IN (SELECT id FROM span_ids) + WHERE s.latest = 1 GROUP BY s.*, duration_millis diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java index 7462e92ff9..8f84555e5d 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java @@ -435,14 +435,14 @@ AND id in ( if(end_time IS NOT NULL AND start_time IS NOT NULL AND notEquals(start_time, toDateTime64('1970-01-01 00:00:00.000', 9)), (dateDiff('microsecond', start_time, end_time) / 1000.0), - NULL) AS duration_millis + NULL) AS duration_millis, + row_number() OVER (PARTITION BY id ORDER BY last_updated_at DESC) AS latest FROM traces WHERE id IN (SELECT id FROM traces_ids) - ORDER BY id DESC, last_updated_at DESC - LIMIT 1 BY id ) AS t LEFT JOIN span_usage AS s ON t.id = s.trace_id LEFT JOIN comments_final AS c ON t.id = c.entity_id + WHERE t.latest = 1 GROUP BY t.*, t.duration_millis diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java index bd662b48b5..c94a4062ba 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java @@ -8,6 +8,7 @@ import com.comet.opik.api.Project; import com.comet.opik.api.ProjectRetrieve; import com.comet.opik.api.ProjectStats; +import com.comet.opik.api.ProjectStatsSummary; import com.comet.opik.api.ProjectUpdate; import com.comet.opik.api.Span; import com.comet.opik.api.Trace; @@ -85,6 +86,7 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static com.comet.opik.api.ProjectStatsSummary.ProjectStatsSummaryItem; import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME; import static com.comet.opik.api.resources.utils.FeedbackScoreAssertionUtils.assertFeedbackScoreNames; import static com.comet.opik.api.resources.utils.MigrationUtils.CLICKHOUSE_CHANGELOG_FILE; @@ -1149,23 +1151,12 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsage__thenReturnPr mockTargetWorkspace(apiKey, workspaceName, workspaceId); - var projects = PodamFactoryUtils.manufacturePojoList(factory, Project.class) - .parallelStream() - .map(project -> project.toBuilder() - .id(createProject(project, apiKey, workspaceName)) - .totalEstimatedCost(null) - .usage(null) - .feedbackScores(null) - .duration(null) - .build()) - .toList(); + Comparator comparator = Comparator.comparing(Project::id).reversed(); - List expectedProjects = projects.parallelStream() - .map(project -> buildProjectStats(project, apiKey, workspaceName)) - .sorted(Comparator.comparing(Project::id).reversed()) - .toList(); + List expectedProjectStats = getProjectStatsSummaryItems(apiKey, workspaceName, comparator); var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .path("/stats") .request() .header(HttpHeaders.AUTHORIZATION, apiKey) .header(WORKSPACE_HEADER, workspaceName) @@ -1174,15 +1165,13 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsage__thenReturnPr var actualEntity = actualResponse.readEntity(Project.ProjectPage.class); assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(org.apache.http.HttpStatus.SC_OK); - assertThat(expectedProjects).hasSameSizeAs(actualEntity.content()); + assertThat(expectedProjectStats).hasSameSizeAs(actualEntity.content()); assertThat(actualEntity.content()) .usingRecursiveComparison() - .ignoringFields("createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt", "lastUpdatedTraceAt") - .ignoringCollectionOrder() .withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class) .withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost") - .isEqualTo(expectedProjects); + .isEqualTo(expectedProjectStats); } @Test @@ -1194,21 +1183,9 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsageSortedLastTrac mockTargetWorkspace(apiKey, workspaceName, workspaceId); - var projects = PodamFactoryUtils.manufacturePojoList(factory, Project.class) - .parallelStream() - .map(project -> project.toBuilder() - .id(createProject(project, apiKey, workspaceName)) - .totalEstimatedCost(null) - .usage(null) - .feedbackScores(null) - .duration(null) - .build()) - .toList(); + Comparator comparator = Comparator.comparing(Project::lastUpdatedTraceAt).reversed(); - List expectedProjects = projects.parallelStream() - .map(project -> buildProjectStats(project, apiKey, workspaceName)) - .sorted(Comparator.comparing(Project::id).reversed()) - .toList(); + List expectedProjectStats = getProjectStatsSummaryItems(apiKey, workspaceName, comparator); var sorting = List.of(SortingField.builder() .field(SortableFields.LAST_UPDATED_TRACE_AT) @@ -1218,23 +1195,49 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsageSortedLastTrac var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) .queryParam("sorting", URLEncoder.encode(JsonUtils.writeValueAsString(sorting), StandardCharsets.UTF_8)) + .path("/stats") .request() .header(HttpHeaders.AUTHORIZATION, apiKey) .header(WORKSPACE_HEADER, workspaceName) .get(); - var actualEntity = actualResponse.readEntity(Project.ProjectPage.class); - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(org.apache.http.HttpStatus.SC_OK); + var actualEntity = actualResponse.readEntity(ProjectStatsSummary.class); - assertThat(expectedProjects).hasSameSizeAs(actualEntity.content()); + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_OK); + + assertThat(expectedProjectStats).hasSameSizeAs(actualEntity.content()); assertThat(actualEntity.content()) .usingRecursiveComparison() - .ignoringFields("createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt", "lastUpdatedTraceAt") - .ignoringCollectionOrder() .withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class) .withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost") - .isEqualTo(expectedProjects); + .isEqualTo(expectedProjectStats); + } + + private List getProjectStatsSummaryItems(String apiKey, String workspaceName, Comparator comparing) { + var projects = PodamFactoryUtils.manufacturePojoList(factory, Project.class) + .parallelStream() + .map(project -> project.toBuilder() + .id(createProject(project, apiKey, workspaceName)) + .totalEstimatedCost(null) + .usage(null) + .feedbackScores(null) + .duration(null) + .build()) + .toList(); + + return projects + .parallelStream() + .map(project -> buildProjectStats(project, apiKey, workspaceName)) + .sorted(comparing) + .map(project -> ProjectStatsSummaryItem.builder() + .duration(project.duration()) + .totalEstimatedCost(project.totalEstimatedCost()) + .usage(project.usage()) + .feedbackScores(project.feedbackScores()) + .projectId(project.id()) + .build()) + .toList(); } @Test @@ -1257,34 +1260,34 @@ void getProjects__whenProjectsHasNoTracesSpansFeedbackScoresAndUsage__thenReturn .build()) .toList(); - List expectedProjects = projects.parallelStream() - .map(project -> project.toBuilder() + List expectedProjectStats = projects.parallelStream() + .map(project -> ProjectStatsSummaryItem.builder() .duration(null) .totalEstimatedCost(null) .usage(null) .feedbackScores(null) .build()) - .sorted(Comparator.comparing(Project::id).reversed()) + .sorted(Comparator.comparing(ProjectStatsSummaryItem::projectId).reversed()) .toList(); var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .path("/stats") .request() .header(HttpHeaders.AUTHORIZATION, apiKey) .header(WORKSPACE_HEADER, workspaceName) .get(); - var actualEntity = actualResponse.readEntity(Project.ProjectPage.class); - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(org.apache.http.HttpStatus.SC_OK); + var actualEntity = actualResponse.readEntity(ProjectStatsSummary.class); - assertThat(expectedProjects).hasSameSizeAs(actualEntity.content()); + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_OK); + + assertThat(expectedProjectStats).hasSameSizeAs(actualEntity.content()); assertThat(actualEntity.content()) .usingRecursiveComparison() - .ignoringFields("createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt", "lastUpdatedTraceAt") - .ignoringCollectionOrder() .withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class) .withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost") - .isEqualTo(expectedProjects); + .isEqualTo(expectedProjectStats); } private Project buildProjectStats(Project project, String apiKey, String workspaceName) { @@ -1356,6 +1359,7 @@ private Project buildProjectStats(Project project, String apiKey, String workspa .flatMap(usage -> usage.entrySet().stream()) .collect(groupingBy(Map.Entry::getKey, averagingDouble(Map.Entry::getValue)))) .feedbackScores(getScoreAverages(traces)) + .lastUpdatedTraceAt(traces.stream().map(Trace::lastUpdatedAt).max(Instant::compareTo).orElse(null)) .build(); } From 56099b210891095385debf638039b6b59a82281b Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Thu, 30 Jan 2025 12:23:44 +0100 Subject: [PATCH 2/2] OPIK-904: Fix tests --- .../resources/v1/priv/ProjectsResource.java | 3 +- .../com/comet/opik/domain/ProjectService.java | 6 +- .../java/com/comet/opik/domain/SpanDAO.java | 2 +- .../v1/priv/ProjectsResourceTest.java | 231 +++++++++--------- 4 files changed, 128 insertions(+), 114 deletions(-) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java index dbec8d9d8b..7d82a5d787 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java @@ -292,7 +292,8 @@ public Response getProjectStats( log.info("Find projects stats by '{}' on workspaceId '{}'", criteria, workspaceId); ProjectStatsSummary projectStatisticsSummary = projectService.getStats(page, size, criteria, sortingFields); - log.info("Found projects stats by '{}', count '{}' on workspaceId '{}'", criteria, projectStatisticsSummary.content().size(), workspaceId); + log.info("Found projects stats by '{}', count '{}' on workspaceId '{}'", criteria, + projectStatisticsSummary.content().size(), workspaceId); return Response.ok().entity(projectStatisticsSummary).build(); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java index d2adc5be8d..d5f68cd262 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java @@ -85,7 +85,8 @@ public interface ProjectService { void recordLastUpdatedTrace(String workspaceId, Collection lastUpdatedTraces); - ProjectStatsSummary getStats(int page, int size, @NonNull ProjectCriteria criteria, @NonNull List sortingFields); + ProjectStatsSummary getStats(int page, int size, @NonNull ProjectCriteria criteria, + @NonNull List sortingFields); } @Slf4j @@ -487,7 +488,8 @@ public Project retrieveByName(@NonNull String projectName) { .feedbackScores(StatsMapper.getStatsFeedbackScores(projectStats.get(project.id()))) .usage(StatsMapper.getStatsUsage(projectStats.get(project.id()))) .duration(StatsMapper.getStatsDuration(projectStats.get(project.id()))) - .totalEstimatedCost(StatsMapper.getStatsTotalEstimatedCost(projectStats.get(project.id()))) + .totalEstimatedCost( + StatsMapper.getStatsTotalEstimatedCost(projectStats.get(project.id()))) .build(); }) .orElseThrow(this::createNotFoundError); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java index 13d3d52ba7..e416a99a7e 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java @@ -604,7 +604,7 @@ AND notEquals(s.start_time, toDateTime64('1970-01-01 00:00:00.000', 9)), FROM ( SELECT *, - row_number() OVER (PARTITION BY id ORDER BY last_updated_at DESC) AS latest\s + row_number() OVER (PARTITION BY id ORDER BY last_updated_at DESC) AS latest FROM spans WHERE id IN (SELECT id FROM span_ids) ) AS s diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java index c94a4062ba..be9fb937d3 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java @@ -115,6 +115,8 @@ class ProjectsResourceTest { public static final String URL_TEMPLATE_TRACE = "%s/v1/private/traces"; public static final String[] IGNORED_FIELDS = {"createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt", "lastUpdatedTraceAt", "feedbackScores", "duration", "totalEstimatedCost", "usage"}; + public static final String[] IGNORED_FIELD_MIN = {"createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt", + "lastUpdatedTraceAt"}; private static final String API_KEY = UUID.randomUUID().toString(); private static final String USER = UUID.randomUUID().toString(); @@ -582,6 +584,8 @@ void getProjectById__whenProjectExists__thenReturnProject() { var id = createProject(project, apiKey, workspaceName); + project = buildProjectStats(project.toBuilder().id(id).build(), apiKey, workspaceName); + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) .path("retrieve") .request() @@ -589,16 +593,17 @@ void getProjectById__whenProjectExists__thenReturnProject() { .header(WORKSPACE_HEADER, workspaceName) .post(Entity.json(ProjectRetrieve.builder().name(project.name()).build()))) { - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_OK); assertThat(actualResponse.hasEntity()).isTrue(); var actualEntity = actualResponse.readEntity(Project.class); assertThat(actualEntity) .usingRecursiveComparison() - .ignoringFields(IGNORED_FIELDS) - .isEqualTo(project.toBuilder() - .id(id) - .build()); + .ignoringFields(IGNORED_FIELD_MIN) + .ignoringCollectionOrder() + .withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class) + .withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost") + .isEqualTo(project); } } @@ -1153,7 +1158,8 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsage__thenReturnPr Comparator comparator = Comparator.comparing(Project::id).reversed(); - List expectedProjectStats = getProjectStatsSummaryItems(apiKey, workspaceName, comparator); + List expectedProjectStats = getProjectStatsSummaryItems(apiKey, workspaceName, + comparator); var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) .path("/stats") @@ -1162,13 +1168,14 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsage__thenReturnPr .header(WORKSPACE_HEADER, workspaceName) .get(); - var actualEntity = actualResponse.readEntity(Project.ProjectPage.class); + var actualEntity = actualResponse.readEntity(ProjectStatsSummary.class); assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(org.apache.http.HttpStatus.SC_OK); assertThat(expectedProjectStats).hasSameSizeAs(actualEntity.content()); assertThat(actualEntity.content()) .usingRecursiveComparison() + .ignoringCollectionOrder() .withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class) .withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost") .isEqualTo(expectedProjectStats); @@ -1185,7 +1192,8 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsageSortedLastTrac Comparator comparator = Comparator.comparing(Project::lastUpdatedTraceAt).reversed(); - List expectedProjectStats = getProjectStatsSummaryItems(apiKey, workspaceName, comparator); + List expectedProjectStats = getProjectStatsSummaryItems(apiKey, workspaceName, + comparator); var sorting = List.of(SortingField.builder() .field(SortableFields.LAST_UPDATED_TRACE_AT) @@ -1209,12 +1217,14 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsageSortedLastTrac assertThat(actualEntity.content()) .usingRecursiveComparison() + .ignoringCollectionOrder() .withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class) .withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost") .isEqualTo(expectedProjectStats); } - private List getProjectStatsSummaryItems(String apiKey, String workspaceName, Comparator comparing) { + private List getProjectStatsSummaryItems(String apiKey, String workspaceName, + Comparator comparing) { var projects = PodamFactoryUtils.manufacturePojoList(factory, Project.class) .parallelStream() .map(project -> project.toBuilder() @@ -1266,6 +1276,7 @@ void getProjects__whenProjectsHasNoTracesSpansFeedbackScoresAndUsage__thenReturn .totalEstimatedCost(null) .usage(null) .feedbackScores(null) + .projectId(project.id()) .build()) .sorted(Comparator.comparing(ProjectStatsSummaryItem::projectId).reversed()) .toList(); @@ -1290,107 +1301,6 @@ void getProjects__whenProjectsHasNoTracesSpansFeedbackScoresAndUsage__thenReturn .isEqualTo(expectedProjectStats); } - private Project buildProjectStats(Project project, String apiKey, String workspaceName) { - var traces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream() - .map(trace -> { - Instant startTime = Instant.now(); - Instant endTime = startTime.plusMillis(PodamUtils.getIntegerInRange(1, 1000)); - return trace.toBuilder() - .projectName(project.name()) - .startTime(startTime) - .endTime(endTime) - .duration(DurationUtils.getDurationInMillisWithSubMilliPrecision(startTime, endTime)) - .build(); - }) - .toList(); - - traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); - - List scores = PodamFactoryUtils.manufacturePojoList(factory, - FeedbackScoreBatchItem.class); - - traces = traces.stream().map(trace -> { - List spans = PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() - .map(span -> span.toBuilder() - .usage(spanResourceClient.getTokenUsage()) - .model(spanResourceClient.randomModel().toString()) - .traceId(trace.id()) - .projectName(trace.projectName()) - .totalEstimatedCost(null) - .build()) - .toList(); - - spanResourceClient.batchCreateSpans(spans, apiKey, workspaceName); - - List feedbackScores = scores.stream() - .map(feedbackScore -> feedbackScore.toBuilder() - .projectId(project.id()) - .projectName(project.name()) - .id(trace.id()) - .build()) - .toList(); - - traceResourceClient.feedbackScores(feedbackScores, apiKey, workspaceName); - - return trace.toBuilder() - .feedbackScores( - feedbackScores.stream() - .map(score -> FeedbackScore.builder() - .value(score.value()) - .name(score.name()) - .build()) - .toList()) - .usage(StatsUtils.aggregateSpansUsage(spans)) - .totalEstimatedCost(StatsUtils.aggregateSpansCost(spans)) - .build(); - }).toList(); - - List durations = StatsUtils.calculateQuantiles( - traces.stream() - .map(Trace::duration) - .toList(), - List.of(0.5, 0.90, 0.99)); - - return project.toBuilder() - .duration(new ProjectStats.PercentageValues(durations.get(0), durations.get(1), durations.get(2))) - .totalEstimatedCost(getTotalEstimatedCost(traces)) - .usage(traces.stream() - .map(Trace::usage) - .flatMap(usage -> usage.entrySet().stream()) - .collect(groupingBy(Map.Entry::getKey, averagingDouble(Map.Entry::getValue)))) - .feedbackScores(getScoreAverages(traces)) - .lastUpdatedTraceAt(traces.stream().map(Trace::lastUpdatedAt).max(Instant::compareTo).orElse(null)) - .build(); - } - - private List getScoreAverages(List traces) { - return traces.stream() - .map(Trace::feedbackScores) - .flatMap(List::stream) - .collect(groupingBy(FeedbackScore::name, - BigDecimalCollectors.averagingBigDecimal(FeedbackScore::value))) - .entrySet() - .stream() - .map(entry -> FeedbackScoreAverage.builder() - .name(entry.getKey()) - .value(entry.getValue()) - .build()) - .toList(); - } - - private double getTotalEstimatedCost(List traces) { - long count = traces.stream() - .map(Trace::totalEstimatedCost) - .filter(Objects::nonNull) - .filter(cost -> cost.compareTo(BigDecimal.ZERO) > 0) - .count(); - - return traces.stream() - .map(Trace::totalEstimatedCost) - .reduce(BigDecimal.ZERO, BigDecimal::add) - .divide(BigDecimal.valueOf(count), ValidationUtils.SCALE, RoundingMode.HALF_UP).doubleValue(); - } - @Test @DisplayName("when projects is with traces created in batch, then return project with last updated trace at") void getProjects__whenProjectsHasTracesBatch__thenReturnProjectWithLastUpdatedTraceAt() { @@ -1507,6 +1417,107 @@ private void assertAllProjectsHavePersistedLastTraceAt(String workspaceId, List< } } + private Project buildProjectStats(Project project, String apiKey, String workspaceName) { + var traces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream() + .map(trace -> { + Instant startTime = Instant.now(); + Instant endTime = startTime.plusMillis(PodamUtils.getIntegerInRange(1, 1000)); + return trace.toBuilder() + .projectName(project.name()) + .startTime(startTime) + .endTime(endTime) + .duration(DurationUtils.getDurationInMillisWithSubMilliPrecision(startTime, endTime)) + .build(); + }) + .toList(); + + traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); + + List scores = PodamFactoryUtils.manufacturePojoList(factory, + FeedbackScoreBatchItem.class); + + traces = traces.stream().map(trace -> { + List spans = PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() + .map(span -> span.toBuilder() + .usage(spanResourceClient.getTokenUsage()) + .model(spanResourceClient.randomModel().toString()) + .traceId(trace.id()) + .projectName(trace.projectName()) + .totalEstimatedCost(null) + .build()) + .toList(); + + spanResourceClient.batchCreateSpans(spans, apiKey, workspaceName); + + List feedbackScores = scores.stream() + .map(feedbackScore -> feedbackScore.toBuilder() + .projectId(project.id()) + .projectName(project.name()) + .id(trace.id()) + .build()) + .toList(); + + traceResourceClient.feedbackScores(feedbackScores, apiKey, workspaceName); + + return trace.toBuilder() + .feedbackScores( + feedbackScores.stream() + .map(score -> FeedbackScore.builder() + .value(score.value()) + .name(score.name()) + .build()) + .toList()) + .usage(StatsUtils.aggregateSpansUsage(spans)) + .totalEstimatedCost(StatsUtils.aggregateSpansCost(spans)) + .build(); + }).toList(); + + List durations = StatsUtils.calculateQuantiles( + traces.stream() + .map(Trace::duration) + .toList(), + List.of(0.5, 0.90, 0.99)); + + return project.toBuilder() + .duration(new ProjectStats.PercentageValues(durations.get(0), durations.get(1), durations.get(2))) + .totalEstimatedCost(getTotalEstimatedCost(traces)) + .usage(traces.stream() + .map(Trace::usage) + .flatMap(usage -> usage.entrySet().stream()) + .collect(groupingBy(Map.Entry::getKey, averagingDouble(Map.Entry::getValue)))) + .feedbackScores(getScoreAverages(traces)) + .lastUpdatedTraceAt(traces.stream().map(Trace::lastUpdatedAt).max(Instant::compareTo).orElse(null)) + .build(); + } + + private List getScoreAverages(List traces) { + return traces.stream() + .map(Trace::feedbackScores) + .flatMap(List::stream) + .collect(groupingBy(FeedbackScore::name, + BigDecimalCollectors.averagingBigDecimal(FeedbackScore::value))) + .entrySet() + .stream() + .map(entry -> FeedbackScoreAverage.builder() + .name(entry.getKey()) + .value(entry.getValue()) + .build()) + .toList(); + } + + private double getTotalEstimatedCost(List traces) { + long count = traces.stream() + .map(Trace::totalEstimatedCost) + .filter(Objects::nonNull) + .filter(cost -> cost.compareTo(BigDecimal.ZERO) > 0) + .count(); + + return traces.stream() + .map(Trace::totalEstimatedCost) + .reduce(BigDecimal.ZERO, BigDecimal::add) + .divide(BigDecimal.valueOf(count), ValidationUtils.SCALE, RoundingMode.HALF_UP).doubleValue(); + } + @Nested @DisplayName("Get: {id}") @TestInstance(TestInstance.Lifecycle.PER_CLASS)